Natural Posterior Network: Deep Bayesian Predictive Uncertainty for Exponential Family Distributions

Overview


This repository provides the official implementation of the Natural Posterior Network (NatPN) and the Natural Posterior Ensemble (NatPE) as presented in the following paper:

Natural Posterior Network: Deep Bayesian Predictive Uncertainty for Exponential Family Distributions
Bertrand Charpentier*, Oliver Borchert*, Daniel Zügner, Simon Geisler, Stephan Günnemann
International Conference on Learning Representations, 2022
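
The paper describes NatPN's prediction as an input-dependent Bayesian update. An encoder maps the input to a latent code. A normalizing flow assigns that code a density, which is scaled by a certainty budget N_H into evidence n(x). A decoder outputs parameters chi(x), which are averaged with the prior parameters in proportion to their evidence. The following is a minimal pure-Python sketch of that update for classification, with the flow density and decoder output passed in as plain numbers. It is not code from this repository. The function name is hypothetical, and the flat Dirichlet prior (chi_prior = 1/C, n_prior = C) is an assumption of the sketch.

```python
from typing import List, Tuple


def natpn_posterior_update(
    class_probs: List[float],
    flow_density: float,
    certainty_budget: float,
) -> Tuple[List[float], float]:
    """Sketch of an input-dependent Bayesian update for classification (Dirichlet case).

    class_probs: decoder output chi(x), a probability vector over C classes.
    flow_density: density P(z(x) | omega) of the latent code under the flow.
    certainty_budget: constant N_H that scales the density into evidence.
    Returns the posterior Dirichlet concentrations and the posterior evidence.
    """
    num_classes = len(class_probs)
    # Assumed flat prior: chi_prior = 1/C per class and n_prior = C,
    # i.e. a Dirichlet prior whose concentrations are all equal to 1.
    n_prior = float(num_classes)
    chi_prior = [1.0 / num_classes] * num_classes
    # Evidence assigned to this particular input by the density model.
    n_x = certainty_budget * flow_density
    n_post = n_prior + n_x
    # chi_post is the evidence-weighted average of the prior and input parameters.
    chi_post = [
        (n_prior * cp + n_x * cx) / n_post for cp, cx in zip(chi_prior, class_probs)
    ]
    # For the Dirichlet, concentrations are recovered as alpha = n_post * chi_post.
    alphas = [n_post * c for c in chi_post]
    return alphas, n_post
```

With zero density, the posterior falls back to the flat prior (all concentrations equal to 1), which expresses maximal uncertainty for inputs far from the training data. A large density yields concentrations dominated by the decoder output.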

Features

The NatPN implementation in this repository provides the following features:

  • High-level estimator interface that makes NatPN as easy to use as Scikit-learn estimators
  • Simple bash script to train and evaluate NatPN
  • Ready-to-use PyTorch Lightning data modules with 8 of the 9 datasets used in the paper*

In addition, we provide a public Weights & Biases project. This project will be filled with training and evaluation runs that allow you (1) to inspect the performance of different NatPN models and (2) to download the model parameters. See the example notebook for instructions on how to use such a pretrained model.

*The Kin8nm dataset is not included because it is no longer available in the UCI Repository.

Installation

Prior to installation, you may want to install all dependencies (Python, CUDA, Poetry). If you are running on an AWS EC2 instance with Ubuntu 20.04, you can use the provided bash script:

sudo bash bin/setup-ec2.sh

To use the code, first clone the repository:

git clone git@github.com:borchero/natural-posterior-network.git natpn

Then, in the root of the repository, you can install all dependencies via Poetry:

poetry install

Quickstart

Shell Script

To train and evaluate NatPN on a particular dataset, you can use the train shell script. For example, to train and evaluate NatPN on the Sensorless Drive dataset, run the following command in the root of the repository:

poetry run train --dataset sensorless-drive

The dataset is downloaded automatically the first time this command is called. The performance metrics of the trained model are printed to the console, and the trained model is discarded. To track both the metrics and the model parameters via Weights & Biases, use the following command:

poetry run train --dataset sensorless-drive --experiment first-steps

To list all options of the shell script, simply run:

poetry run train --help

This command will also provide explanations for all the parameters that can be passed.
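
If you want to drive the train script from Python rather than the shell, a minimal sketch is to build the argument list for the two modes described above: untracked, where metrics are printed and the model is discarded, and tracked via --experiment. The helper name is hypothetical, and only the --dataset and --experiment options documented here are used.

```python
import shlex
from typing import List, Optional


def build_train_command(dataset: str, experiment: Optional[str] = None) -> List[str]:
    """Build the argument list for `poetry run train` (hypothetical helper).

    Without an experiment name, metrics are only printed and the model is discarded.
    With one, metrics and model parameters are tracked via Weights & Biases.
    """
    command = ["poetry", "run", "train", "--dataset", dataset]
    if experiment is not None:
        command += ["--experiment", experiment]
    return command


# Show the tracked command from the example above as a single shell string.
print(shlex.join(build_train_command("sensorless-drive", experiment="first-steps")))
```

Passing the resulting list to subprocess.run(command, check=True) from the root of the repository would launch the run. Building a list instead of a string avoids shell-quoting issues.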

Estimator

If you want to use NatPN from your code, the easiest way to get started is to use the Scikit-learn-like estimator:

from natpn import NaturalPosteriorNetwork

The documentation of the estimator's __init__ method provides a comprehensive overview of all the configuration options. For a simple example of using the estimator, refer to the example notebook.

Module

If you need even more customization, you can use natpn.nn.NaturalPosteriorNetworkModel directly. The natpn.nn package is extensively documented and allows you to configure every aspect of your NatPN model.

Further, the natpn.model package provides PyTorch Lightning modules which allow you to train, evaluate, and fine-tune models.

Running Hyperparameter Searches

If you want to run hyperparameter searches on a local Slurm cluster, you can use the files provided in the sweeps directory. To run a grid search, execute the corresponding file:

poetry run python sweeps/<file>

To make sure that your experiment is tracked correctly, also set the WANDB_PROJECT environment variable in a place that is read by the Slurm script (located in sweeps/slurm).

Feel free to adapt the scripts to your liking to run your own hyperparameter searches.
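
As a starting point for your own search, the core of a grid search can be sketched as expanding a dictionary of option values with itertools.product into one train command per combination. The sketch also prepares an environment with WANDB_PROJECT set so the runs are tracked. It is not the content of the files in sweeps. The function names, the second dataset name, and the project name are hypothetical placeholders. The grid keys should be options listed by poetry run train --help.

```python
import itertools
import os
from typing import Dict, Iterable, List


def grid_commands(grid: Dict[str, Iterable[str]], experiment: str) -> List[List[str]]:
    """Expand a parameter grid into one `poetry run train` command per combination."""
    keys = sorted(grid)  # fixed key order makes the generated commands reproducible
    commands = []
    for values in itertools.product(*(list(grid[k]) for k in keys)):
        command = ["poetry", "run", "train", "--experiment", experiment]
        for key, value in zip(keys, values):
            command += [f"--{key}", value]
        commands.append(command)
    return commands


def sweep_environment(project: str) -> Dict[str, str]:
    """Copy the current environment and set WANDB_PROJECT for the launched jobs."""
    env = dict(os.environ)
    env["WANDB_PROJECT"] = project
    return env


# "dataset" is documented above; "other-dataset" and the project name are placeholders.
commands = grid_commands({"dataset": ["sensorless-drive", "other-dataset"]}, "my-sweep")
env = sweep_environment("my-natpn-sweeps")
for command in commands:
    print(" ".join(command))
```

On a Slurm cluster, each command would be submitted through the script in sweeps/slurm rather than executed directly. How that script receives its arguments is not covered by this sketch.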

Citation

If you are using the model or the code in this repository, please cite the following paper:

@inproceedings{natpn,
    title={{Natural} {Posterior} {Network}: {Deep} {Bayesian} {Predictive} {Uncertainty} for {Exponential} {Family} {Distributions}},
    author={Charpentier, Bertrand and Borchert, Oliver and Z\"{u}gner, Daniel and Geisler, Simon and G\"{u}nnemann, Stephan},
    booktitle={International Conference on Learning Representations},
    year={2022}
}

Contact Us

If you have any questions regarding the code, please contact us via email.

License

The code in this repository is licensed under the MIT License.
