Code for "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations"

Overview

Infinitely Deep Bayesian Neural Networks with SDEs

This library contains JAX and Pytorch implementations of neural ODEs and Bayesian layers for stochastic variational inference. A rudimentary JAX implementation of differentiable SDE solvers is also provided, refer to torchsde [2] for a full set of differentiable SDE solvers in Pytorch and similarly to torchdiffeq [3] for differentiable ODE solvers.

Continuous-depth hidden unit trajectories in Neural ODE vs uncertain posterior dynamics SDE-BNN.

Installation

This library runs on jax==0.1.77 and torch==1.6.0. To install all other requirements:

pip install -r requirements.txt

Note: Package versions may change, refer to official JAX installation instructions here.

JaxSDE: Differentiable SDE Solvers in JAX

The jaxsde library contains SDE solvers in the Ito and Stratonovich form. Solvers of different orders can be specified with the following method={euler_maruyama|milstein|euler_heun} (strong orders 0.5|1|0.5 and orders 1|1|1 in the case of an additive noise SDE). Stochastic adjoint (sdeint_ito) training mode does not work efficiently yet, use sdeint_ito_fixed_grid for now. Tradeoff solver speed for precision during training or inference by adjusting --nsteps <# steps>.

Usage

Default solver: Backpropagation through the solver.

from jaxsde.jaxsde.sdeint import sdeint_ito_fixed_grid

y1 = sdeint_ito_fixed_grid(f, g, y0, ts, rng, fw_params, method="euler_maruyama")

Stochastic adjoint: Using O(1) memory instead of solving an adjoint SDE in the backward pass.

from jaxsde.jaxsde.sdeint import sdeint_ito

y1 = sdeint_ito(f, g, y0, ts, rng, fw_params, method="milstein")

Brax: Bayesian SDE Framework in JAX

Implementation of composable Bayesian layers in the stax API. Our SDE Bayesian layers can be used with the SDEBNN block composed with multiple parameterizations of time-dependent layers in diffeq_layers. Sticking-the-landing (STL) trick can be enabled during training with --stl for improving convergence rate. Augment the inputs by a custom amount --aug <integer>, set the number of samples averaged over with --nsamples <integer>. If memory constraints pose a problem, train in gradient accumulation mode: --acc_grad and gradient checkpointing: --remat.

Samples from SDEBNN-learned predictive prior and posterior density distributions.

Usage

All examples can be swapped in with different vision datasets. For better readability, tensorboard logging has been excluded (see torchbnn instead).

Toy 1D regression to learn complex posteriors:

python examples/jax/sdebnn_toy1d.py --ds cos --activn swish --loss laplace --kl_scale 1. --diff_const 0.2 --driftw_scale 0.1 --aug_dim 2 --stl --prior_dw ou

Image Classification:

To train an SDEBNN model:

python examples/jax/sdebnn_classification.py --output <output directory> --model sdenet --aug 2 --nblocks 2-2-2 --diff_coef 0.2 --fx_dim 64 --fw_dims 2-64-2 --nsteps 20 --nsamples 1

To train a ResNet baseline, specify --model resnet and for a Bayesian ResNet baseline, specify --meanfield_sdebnn.

TorchBNN: SDE-BNN in Pytorch

A PyTorch implementation of the Brax framework powered by the torchsde backend.

Usage

All examples can be swapped in with different vision datasets and includes tensorboard logging for critical metrics.

Toy 1D regression to learn multi-modal posterior:

python examples/torch/sdebnn_toy1d.py --output_dir <dst_path>

Arbitrarily expression approximate posteriors from learning non-Gaussian marginals.

Image Classification:

All hyperparameters can be found in the training script. Train with adjoint for memory efficient backpropagation and adaptive mode for adaptive computation (and ensure --adjoint_adaptive True if training with adjoint and adaptive modes).

python examples/torch/sdebnn_classification.py --train-dir <output directory> --data cifar10 --dt 0.05 --method midpoint --adjoint True --adaptive True --adjoint_adaptive True --inhomogeneous True

References

[1] Winnie Xu, Ricky T. Q. Chen, Xuechen Li, David Duvenaud. "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations." Preprint 2021. [arxiv]

[2] Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud. "Scalable Gradients for Stochastic Differential Equations." AISTATS 2020. [arxiv]

[3] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." NeurIPS. 2018. [arxiv]


If you found this library useful in your research, please consider citing

@article{xu2021sdebnn,
  title={Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations},
  author={Xu, Winnie and Chen, Ricky T. Q. and Li, Xuechen and Duvenaud, David},
  archivePrefix = {arXiv},
  year={2021}
}
Owner
Winnie Xu
Undergrad in CS/Stats/Math '22 @ UToronto. Working on something secret @cohere-ai. Deep neural networks @for-ai @VectorInstitute. Prev. @google-research @NVIDIA
Winnie Xu
A custom DeepStack model for detecting 16 human actions.

DeepStack_ActionNET This repository provides a custom DeepStack model that has been trained and can be used for creating a new object detection API fo

MOSES OLAFENWA 16 Nov 11, 2022
Application of K-means algorithm on a music dataset after a dimensionality reduction with PCA

PCA for dimensionality reduction combined with Kmeans Goal The Goal of this notebook is to apply a dimensionality reduction on a big dataset in order

Arturo Ghinassi 0 Sep 17, 2022
Library for machine learning stacking generalization.

stacked_generalization Implemented machine learning *stacking technic[1]* as handy library in Python. Feature weighted linear stacking is also availab

114 Jul 19, 2022
Pipeline for employing a Lightweight deep learning models for LOW-power systems

PL-LOW A high-performance deep learning model lightweight pipeline that gradually lightens deep neural networks in order to utilize high-performance d

POSTECH Data Intelligence Lab 9 Aug 13, 2022
Source code of article "Towards Toxic and Narcotic Medication Detection with Rotated Object Detector"

Towards Toxic and Narcotic Medication Detection with Rotated Object Detector Introduction This is the source code of article: Towards Toxic and Narcot

Woody. Wang 3 Oct 29, 2022
Code for paper "Vocabulary Learning via Optimal Transport for Neural Machine Translation"

**Codebase and data are uploaded in progress. ** VOLT(-py) is a vocabulary learning codebase that allows researchers and developers to automaticaly ge

416 Jan 09, 2023
Author: Wenhao Yu ([email protected]). ACL 2022. Commonsense Reasoning on Knowledge Graph for Text Generation

Diversifying Commonsense Reasoning Generation on Knowledge Graph Introduction -- This is the pytorch implementation of our ACL 2022 paper "Diversifyin

DM2 Lab @ ND 61 Dec 30, 2022
Real-Time-Student-Attendence-System - Real Time Student Attendence System

Real-Time-Student-Attendence-System The Student Attendance Management System Pro

Rounak Das 1 Feb 15, 2022
A PyTorch Implementation of Neural IMage Assessment

NIMA: Neural IMage Assessment This is a PyTorch implementation of the paper NIMA: Neural IMage Assessment (accepted at IEEE Transactions on Image Proc

yunxiaos 418 Dec 29, 2022
A Simple Example for Imitation Learning with Dataset Aggregation (DAGGER) on Torcs Env

Imitation Learning with Dataset Aggregation (DAGGER) on Torcs Env This repository implements a simple algorithm for imitation learning: DAGGER. In thi

Hao 66 Nov 23, 2022
[ICML 2022] The official implementation of Graph Stochastic Attention (GSAT).

Graph Stochastic Attention (GSAT) The official implementation of GSAT for our paper: Interpretable and Generalizable Graph Learning via Stochastic Att

85 Nov 27, 2022
Pytorch implementation of Value Iteration Networks (NIPS 2016 best paper)

VIN: Value Iteration Networks A quick thank you A few others have released amazing related work which helped inspire and improve my own implementation

Kent Sommer 297 Dec 26, 2022
codes for Image Inpainting with External-internal Learning and Monochromic Bottleneck

Image Inpainting with External-internal Learning and Monochromic Bottleneck This repository is for the CVPR 2021 paper: 'Image Inpainting with Externa

97 Nov 29, 2022
MT3: Multi-Task Multitrack Music Transcription

MT3: Multi-Task Multitrack Music Transcription MT3 is a multi-instrument automatic music transcription model that uses the T5X framework. This is not

Magenta 867 Dec 29, 2022
3D Pose Estimation for Vehicles

3D Pose Estimation for Vehicles Introduction This work generates 4 key-points and 2 key-edges from vertices and edges of vehicles as ground truth. The

Jingyi Wang 1 Nov 01, 2021
Parameterising Simulated Annealing for the Travelling Salesman Problem

Parameterising Simulated Annealing for the Travelling Salesman Problem

Gary Sun 55 Jun 15, 2022
A Comprehensive Empirical Study of Vision-Language Pre-trained Model for Supervised Cross-Modal Retrieval

CLIP4CMR A Comprehensive Empirical Study of Vision-Language Pre-trained Model for Supervised Cross-Modal Retrieval The original data and pre-calculate

24 Dec 26, 2022
Technical experimentations to beat the stock market using deep learning :chart_with_upwards_trend:

DeepStock Technical experimentations to beat the stock market using deep learning. Experimentations Deep Learning Stock Prediction with Daily News Hea

Keon 449 Dec 29, 2022
Implementation of ToeplitzLDA for spatiotemporal stationary time series data.

Code for the ToeplitzLDA classifier proposed in here. The classifier conforms sklearn and can be used as a drop-in replacement for other LDA classifiers. For in-depth usage refer to the learning from

Jan Sosulski 5 Nov 07, 2022
2 Jul 19, 2022