TyXe: Pyro-based BNNs for Pytorch users

TyXe aims to simplify the process of turning Pytorch neural networks into Bayesian neural networks by leveraging the model definition and inference capabilities of Pyro. Our core design principle is to cleanly separate the construction of neural architecture, prior, inference distribution and likelihood, enabling a flexible workflow where each component can be exchanged independently. Defining a BNN in TyXe takes as little as 5 lines of code:

import torch.nn as nn
import pyro.distributions as dist
import tyxe

net = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 1))
prior = tyxe.priors.IIDPrior(dist.Normal(0, 1))
likelihood = tyxe.likelihoods.HomoskedasticGaussian(scale=0.1)
inference = tyxe.guides.AutoNormal
bnn = tyxe.VariationalBNN(net, prior, likelihood, inference)
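To make the separation of components concrete, here is a minimal NumPy sketch (not TyXe code) of the unnormalized log posterior that the example above defines. The architecture, the IID Normal(0, 1) prior and the homoskedastic Gaussian likelihood with scale 0.1 are independent functions, so each can be swapped without touching the others. All function names are hypothetical, and the dictionary keys only mimic the state_dict names PyTorch gives the layers of that nn.Sequential.

```python
import numpy as np

def init_params(rng, hidden=50):
    # Same shapes and state_dict-style names as
    # nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 1))
    return {
        "0.weight": rng.normal(size=(hidden, 1)),
        "0.bias": rng.normal(size=hidden),
        "2.weight": rng.normal(size=(1, hidden)),
        "2.bias": rng.normal(size=1),
    }

def net(params, x):
    # Architecture: Linear -> Tanh -> Linear, mapping x of shape (N, 1) to (N, 1)
    h = np.tanh(x @ params["0.weight"].T + params["0.bias"])
    return h @ params["2.weight"].T + params["2.bias"]

def normal_logpdf(value, loc, scale):
    return -0.5 * ((value - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)

def iid_normal_prior(params, loc=0.0, scale=1.0):
    # Prior: every weight and bias independently ~ Normal(loc, scale)
    return sum(normal_logpdf(p, loc, scale).sum() for p in params.values())

def homoskedastic_gaussian(pred, y, scale=0.1):
    # Likelihood: y ~ Normal(pred, scale) with the same scale for every data point
    return normal_logpdf(y, pred, scale).sum()

def log_joint(params, x, y, prior=iid_normal_prior, likelihood=homoskedastic_gaussian):
    # Unnormalized log posterior; prior and likelihood can be exchanged independently
    return prior(params) + likelihood(net(params, x), y)
```

Passing, for example, a different prior function to log_joint changes nothing about the network or the noise model, which mirrors the design principle described above.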

In the following, we assume that you (roughly) know what a BNN is mathematically.

Motivating example

Standard neural networks give us a single function that fits the data, but many different functions are typically plausible. With only a single fit, we cannot tell for which inputs the model is 'certain' (because there is training data nearby) and for which it is uncertain.

(Figure: Maximum likelihood fit; Posterior samples)

Obtaining a maximum likelihood fit takes only a few lines of Pytorch code, but training a BNN that gives a distribution over different fits is typically more complicated, and this is specifically what we aim to simplify.

Training

Constructing a BNN object has been shown in the example above. To fit the posterior approximation, we provide a high-level .fit method similar to those of libraries such as scikit-learn or Keras:

import pyro.optim

optim = pyro.optim.Adam({"lr": 1e-3})
bnn.fit(data_loader, optim, num_epochs)
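Variational fitting generally maximizes the evidence lower bound (ELBO), E_q[log p(y | w) + log p(w) - log q(w)]. The NumPy sketch below is a Monte Carlo estimator of that quantity for a Bayesian linear regression model with a diagonal Gaussian q, assuming the reparameterization w = loc + scale * eps. It only evaluates the objective, whereas an optimizer such as pyro.optim.Adam would follow its gradients. The model, the function names and the default scales are illustrative assumptions, not TyXe internals.

```python
import numpy as np

def log_normal(x, loc, scale):
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)

def elbo_estimate(x, y, q_loc, q_scale, rng, num_samples=100, prior_scale=1.0, noise_scale=0.1):
    """Monte Carlo ELBO for y ~ N(x @ w, noise_scale^2), w ~ N(0, prior_scale^2),
    and a factorized Gaussian q(w) = N(q_loc, q_scale^2).

    x: (N, D), y: (N,), q_loc and q_scale: (D,)
    """
    eps = rng.standard_normal((num_samples, q_loc.size))
    w = q_loc + q_scale * eps                                  # reparameterized samples, (S, D)
    log_lik = log_normal(y, w @ x.T, noise_scale).sum(axis=1)  # (S,)
    log_prior = log_normal(w, 0.0, prior_scale).sum(axis=1)    # (S,)
    log_q = log_normal(w, q_loc, q_scale).sum(axis=1)          # (S,)
    return np.mean(log_lik + log_prior - log_q)
```

With minibatches from a data_loader, the likelihood term is usually rescaled by the ratio of dataset size to batch size so that the estimate stays unbiased for the full dataset.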

Prediction & evaluation

Furthermore, we provide .predict and .evaluate methods, which make predictions based on multiple samples from the approximate posterior and average them according to the observation model; .evaluate also returns an error measure and the log likelihood of the test data:

predictions = bnn.predict(x_test, num_samples)
error, log_likelihood = bnn.evaluate(x_test, y_test, num_samples)
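A minimal NumPy sketch of the averaging step, assuming a Gaussian observation model with a fixed noise scale: predictions from S posterior samples are averaged into a point prediction, and the predictive log likelihood of each test point is log((1/S) * sum_s p(y | f_s(x))), computed with the log-sum-exp trick for numerical stability. The squared error is an illustrative error measure and the function names are hypothetical; TyXe's exact aggregation and error definitions may differ.

```python
import numpy as np

def log_normal(x, loc, scale):
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)

def predictive_summary(sample_preds, y, noise_scale=0.1):
    """sample_preds: (S, N) outputs of S networks drawn from the approximate posterior."""
    num_samples = sample_preds.shape[0]
    mean_pred = sample_preds.mean(axis=0)                  # average over posterior samples
    sq_error = np.mean((mean_pred - y) ** 2)                # illustrative error measure
    log_probs = log_normal(y, sample_preds, noise_scale)   # (S, N)
    # log of the sample average of p(y | f_s(x)) per data point, via log-sum-exp
    m = log_probs.max(axis=0)
    log_lik = m + np.log(np.exp(log_probs - m).sum(axis=0)) - np.log(num_samples)
    return mean_pred, sq_error, log_lik.sum()
```

Averaging probabilities rather than log probabilities means that a test point only needs to be explained well by some of the posterior samples to receive a reasonable log likelihood.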

Local reparameterization

We implement local reparameterization for factorized Gaussians as a poutine (a Pyro effect handler), which reduces the variance of the gradient estimates during training. This means it can be enabled or disabled during both training and prediction with a context manager:

with tyxe.poutine.local_reparameterization():
    bnn.fit(data_loader, optim, num_epochs)
    predictions = bnn.predict(x_test, num_samples)

At the moment, this poutine does not work with pyro's AutoNormal and AutoDiagonalNormal guides, because those draw the weights from a Delta distribution; use tyxe.guides.ParameterwiseDiagonalNormal as your guide instead.
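The idea behind local reparameterization can be sketched in NumPy for a single linear layer with independent Gaussian weights W ~ N(M, S^2) and biases b ~ N(m, s^2). For a fixed input x, each pre-activation x W + b is Gaussian with mean x M + m and variance x^2 S^2 + s^2, so we can sample the pre-activations directly, with independent noise for every data point, instead of sampling one weight matrix for the whole batch. This illustrates the technique and is not TyXe's poutine; the function name is hypothetical, and the (D_in, D_out) weight layout is an assumption that is transposed relative to the weight of PyTorch's nn.Linear.

```python
import numpy as np

def linear_local_reparam(x, w_loc, w_scale, b_loc, b_scale, rng):
    """Sample pre-activations of a linear layer with factorized Gaussian weights.

    x: (N, D_in); w_loc, w_scale: (D_in, D_out); b_loc, b_scale: (D_out,)
    """
    # Moments of x @ W + b when all weight and bias entries are independent Gaussians
    out_loc = x @ w_loc + b_loc
    out_var = (x ** 2) @ (w_scale ** 2) + b_scale ** 2
    # Separate noise for every data point and output unit
    eps = rng.standard_normal(out_loc.shape)
    return out_loc + np.sqrt(out_var) * eps
```

The weights themselves are never sampled, so the layer's cost is roughly two matrix products, and the per-example noise decorrelates the gradient contributions within a minibatch.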

MCMC

We provide a unified interface to pyro's MCMC implementations: simply use the tyxe.MCMC_BNN class and provide a kernel in place of the guide:

import pyro.infer.mcmc

kernel = pyro.infer.mcmc.NUTS
bnn = tyxe.MCMC_BNN(net, prior, likelihood, kernel)

Any keyword arguments that pyro's MCMC class accepts, such as num_samples and warmup_steps, can be passed through the keyword arguments of the .fit method.
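NUTS is fairly involved, so as a stand-in the sketch below uses a much simpler MCMC method, random-walk Metropolis, to draw samples from the posterior of a single regression weight. It only illustrates what an MCMC kernel produces (a chain of posterior samples after a warm-up phase) and is not how pyro or TyXe implement NUTS. The num_samples and warmup_steps names mirror the corresponding arguments of pyro's MCMC class; the model and the remaining names are assumptions.

```python
import numpy as np

def log_posterior(w, x, y, noise_scale=0.5, prior_scale=1.0):
    # Unnormalized log posterior of one weight: N(w; 0, prior) * prod_i N(y_i; w * x_i, noise)
    log_prior = -0.5 * (w / prior_scale) ** 2
    log_lik = -0.5 * np.sum(((y - w * x) / noise_scale) ** 2)
    return log_prior + log_lik

def random_walk_metropolis(log_prob, init, num_samples, warmup_steps, step_size, rng):
    current, current_lp = init, log_prob(init)
    samples = []
    for i in range(warmup_steps + num_samples):
        # Symmetric Gaussian proposal around the current state
        proposal = current + step_size * rng.standard_normal()
        proposal_lp = log_prob(proposal)
        # Accept with probability min(1, p(proposal) / p(current))
        if np.log(rng.uniform()) < proposal_lp - current_lp:
            current, current_lp = proposal, proposal_lp
        if i >= warmup_steps:  # discard the warm-up part of the chain
            samples.append(current)
    return np.array(samples)
```

For this conjugate model the exact posterior is Gaussian, which makes it easy to check that the sample mean and standard deviation of the chain approach the analytic values.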

Continual learning

Because our design cleanly separates the prior from the guide, architecture and likelihood, the prior is easy to update in a continual learning setting. For example, you can construct a tyxe.priors.DictPrior by extracting the distributions over all weights and biases from a ParameterwiseDiagonalNormal instance with its get_detached_distributions method, and pass it to bnn.update_prior to implement Variational Continual Learning in a few lines of code. See examples/vcl.py for a basic example on split-MNIST and split-CIFAR.
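The mechanics of this prior update can be sketched in NumPy, assuming both the posterior approximation and the prior are stored as dictionaries mapping parameter names to (loc, scale) arrays of a factorized Gaussian. After fitting task t, a detached copy of the posterior becomes the prior for task t + 1, and the variational objective then penalizes the KL divergence from the new posterior to that prior. kl_diag_normal, kl_to_prior and detach are hypothetical helpers, not TyXe functions.

```python
import numpy as np

def kl_diag_normal(q_loc, q_scale, p_loc, p_scale):
    # KL(q || p) between factorized Gaussians, summed over all entries
    return np.sum(np.log(p_scale / q_scale)
                  + (q_scale ** 2 + (q_loc - p_loc) ** 2) / (2 * p_scale ** 2) - 0.5)

def kl_to_prior(posterior, prior):
    # posterior, prior: dicts mapping parameter names to (loc, scale) arrays
    return sum(kl_diag_normal(*posterior[name], *prior[name]) for name in posterior)

def detach(posterior):
    # Copy the arrays so that further training on the next task leaves the new prior fixed
    return {name: (loc.copy(), scale.copy()) for name, (loc, scale) in posterior.items()}
```

For the first task, the prior would simply be a standard Normal for every entry, e.g. {name: (np.zeros_like(loc), np.ones_like(scale))}; from then on, prior = detach(posterior) after each task.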

Network architectures

We don't implement any layer classes. You construct your network in Pytorch and then turn it into a BNN, which makes it easy to apply the same prior and inference strategies to different neural networks.

Inference

For inference, tyxe.guides mainly provides an equivalent of pyro's AutoDiagonalNormal that is compatible with local reparameterization. This module also contains a few helper functions for initializing Gaussian mean parameters, e.g. to the values of a pre-trained network. It should be possible to use any of pyro's autoguides for variational inference. See examples/resnet.py for a few options, including initialization to pre-trained weights.
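Initializing the Gaussian means to a pre-trained network could look like the following NumPy sketch: every parameter gets a loc copied from the pre-trained state dict and a small shared scale, so that networks sampled from the initial approximation stay close to the pre-trained one. init_guide_from_pretrained, sample_weights and the init_scale default are hypothetical and are not the helpers in tyxe.guides.

```python
import numpy as np

def init_guide_from_pretrained(state_dict, init_scale=1e-3):
    # Hypothetical helper: one (loc, scale) pair per parameter, with loc copied
    # from the pre-trained values and the same small scale everywhere
    return {name: (value.copy(), np.full_like(value, init_scale, dtype=float))
            for name, value in state_dict.items()}

def sample_weights(guide_params, rng):
    # Draw one network's worth of weights from the factorized Gaussian
    return {name: loc + scale * rng.standard_normal(loc.shape)
            for name, (loc, scale) in guide_params.items()}
```

A small initial scale lets training start from the pre-trained solution and only widen the distribution where the data allows it.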

Priors

The priors can be found in tyxe.priors. We currently only support placing priors on the parameters. Through the expose and hide arguments of the constructor, you can specify the layers, types of layers and specific parameters over which you want to place a prior. This is useful, for example, for learning the parameters of BatchNorm layers deterministically.
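Selecting which parameters receive a prior amounts to filtering parameter names. The plain Python sketch below covers only the hiding direction (exposing is the complement) and assumes module names and module type names are given as strings; select_bayesian_params and its arguments are hypothetical and do not reproduce the exact semantics of TyXe's expose and hide arguments.

```python
def select_bayesian_params(param_names, module_types, hide_module_types=(), hide_params=()):
    """Return the names of the parameters that should receive a prior.

    param_names: parameter names such as "1.weight" (module name, dot, parameter name).
    module_types: dict mapping module names such as "1" to type names such as "BatchNorm1d".
    """
    selected = []
    for name in param_names:
        module_name, _, param_name = name.rpartition(".")
        if module_types.get(module_name) in hide_module_types:
            continue  # e.g. keep all BatchNorm parameters deterministic
        if name in hide_params or param_name in hide_params:
            continue  # hide a specific parameter, or every parameter with that name
        selected.append(name)
    return selected
```

For a network whose module "1" is a BatchNorm1d layer, hiding that type leaves only the Linear weights and biases under the prior, while the BatchNorm parameters are trained as ordinary point estimates.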

Likelihoods

tyxe.likelihoods contains classes that wrap the most common torch.distributions for specifying noise models of data to
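As an illustration of what such noise models compute, the sketch below gives NumPy log likelihoods for two common classification cases, a categorical distribution over K classes and a Bernoulli distribution, both parameterized by network logits. The function names are hypothetical; torch.distributions.Categorical and torch.distributions.Bernoulli constructed with logits compute the same log probabilities.

```python
import numpy as np

def categorical_loglik(logits, labels):
    # logits: (N, K) network outputs; labels: (N,) integer class indices
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return log_probs[np.arange(len(labels)), labels].sum()

def bernoulli_loglik(logits, targets):
    # logits, targets: (N,), targets in {0, 1}; log sigmoid computed stably via logaddexp
    return np.sum(targets * logits - np.logaddexp(0.0, logits))
```

Both functions take raw network outputs rather than probabilities, which avoids taking the logarithm of values that have underflowed to zero.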

Installation

We recommend installing TyXe using conda with the provided environment.yml, which also installs all the dependencies for the examples except for PyTorch3D, which needs to be added manually. The environment assumes that you are using CUDA 11.0; if this is not the case, change the cudatoolkit and dgl-cuda versions in environment.yml before running:

conda env create -f environment.yml
conda activate tyxe
pip install -e .

Citation

If you use TyXe, please consider citing:

@article{ritter2021tyxe,
  author    = {Hippolyt Ritter and
               Theofanis Karaletsos
               },
  title     = {TyXe: Pyro-based Bayesian neural nets for Pytorch},
  journal   = {International Conference on Probabilistic Programming (ProbProg)},
  volume    = {},
  pages     = {},
  year      = {2020},
  url       = {https://arxiv.org/abs/2110.00276}
}