Research code for the paper "Variational Gibbs inference for statistical estimation from incomplete data".

Overview

Variational Gibbs inference (VGI)

This repository contains the research code for

Simkus, V., Rhodes, B., Gutmann, M. U., 2021. Variational Gibbs inference for statistical model estimation from incomplete data.

The code is shared for reproducibility purposes and is not intended for production use. It should also serve as a reference implementation for anyone wanting to use VGI for model estimation from incomplete data.

Abstract

Statistical models are central to machine learning with broad applicability across a range of downstream tasks. The models are typically controlled by free parameters that are estimated from data by maximum-likelihood estimation. However, when faced with real-world datasets many of the models run into a critical issue: they are formulated in terms of fully-observed data, whereas in practice the datasets are plagued with missing data. The theory of statistical model estimation from incomplete data is conceptually similar to the estimation of latent-variable models, where powerful tools such as variational inference (VI) exist. However, in contrast to standard latent-variable models, parameter estimation with incomplete data often requires estimating exponentially-many conditional distributions of the missing variables, hence making standard VI methods intractable. We address this gap by introducing variational Gibbs inference (VGI), a new general-purpose method to estimate the parameters of statistical models from incomplete data.

VGI demo

We invite the readers of the paper to also see the Jupyter notebook, where we demonstrate VGI on two statistical models and animate the learning process to help better understand the method.

Below is an animation from the notebook of a Gaussian Mixture Model fitted from incomplete data using the VGI algorithm (left), and the variational Gibbs conditional approximations (right) throughout iterations.

demo_vgi_mog_fit.mp4

Dependencies

Install python dependencies from conda and the cdi project package with

conda env create -f environment.yml
conda activate cdi
python setup.py develop

If the dependencies in environment.yml change, update dependencies with

conda env update --file environment.yml

Summary of the repository structure

Data

All data used in the paper are stored in data directory and the corresponding data loaders can be found in cdi/data directory.

Method code

The main code to the various methods used in the paper can be found in cdi/trainers directory.

  • trainer_base.py implements the main data loading and preprocessing code.
  • variational_cdi.py and cdi.py implement the key code for variational Gibbs inference (VGI).
  • mcimp.py implements the code for variational block-Gibbs inference (VBGI) used in the VAE experiments.
  • The other scripts in cdi/trainers implement the comparison methods and variational conditional pre-training.

Statistical models

The code for the statistical (factor analysis, VAEs, and flows) and the variational models are located in cdi/models.

Configuration files

The experiment_configs directory contains the configuration files for all experiments. The config files include all the hyperparameter settings necessary to reproduce our results. The config files are in a json format. They are passed to the main running script as a command-line argument and values in them can be overriden with additional command-line arguments.

Run scripts

train.py is the main code we use to run the experiments, and test.py is the main script to produce analysis results presented in the paper.

Analysis code

The Jupyter notebooks in notebooks directory contain the code which was used to analysis the method and produce figures in the paper. You should also be able to use these notebooks to find the corresponding names of the config files for the experiments in the paper.

Running the code

Before running any code you'll need to activate the cdi conda environment (and make sure you've installed the dependencies)

conda activate cdi

Model fitting

To train a model use the train.py script, for example, to fit a rational-quadratic spline flow on 50% missing MiniBooNE dataset

python train.py --config=experiment_configs/flows_uci/learning_experiments/3/rqcspline_miniboone_chrqsvar_cdi_uncondgauss.json

Any parameters set in the config file can be overriden by passing additionals command-line arguments, e.g.

python train.py --config=experiment_configs/flows_uci/learning_experiments/3/rqcspline_miniboone_chrqsvar_cdi_uncondgauss.json --data.total_miss=0.33

Optional variational model warm-up

Some VGI experiments use variational model "warm-up", which pre-trains the variational model on observed data as probabilistic regressors. The experiment configurations for these runs will have var_pretrained_model set to the name of the pre-trained model. To run the corresponding pre-training script run, e.g.

python train.py --config=experiment_configs/flows_uci/learning_experiments/3/miniboone_chrqsvar_pretraining_uncondgauss.json

Running model evaluation

For model evaluation use test.py with the corresponding test config, e.g.

python test.py --test_config=experiment_configs/flows_uci/eval_loglik/3/rqcspline_miniboone_chrqsvar_cdi_uncondgauss.json

This will store all results in a file that we then analyse in the provided notebook.

For the VAE evaluation, where variational distribution fine-tuning is required for test log-likelihood evaluation use retrain_all_ckpts_on_test_and_run_test.py.

Using this codebase on your own task

While the main purpose of this repository is reproducibility of the research paper and a demonstration of the method, you should be able to adapt the code to fit your statistical models. We would advise you to first see the Jupyter notebook demo. The notebook provides an example of how to implement the target statistical model as well as the variational model of the conditionals, you can find further examples in cdi/models directory. If you intend to use a variational family that is different to ours you will also need to implement the corresponding sampling functions here.

Owner
Vaidotas Šimkus
PhD candidate in Data Science at the University of Edinburgh. Interested in deep generative models, variational inference, and the Bayesian principle.
Vaidotas Šimkus
A privacy-focused, intelligent security camera system.

Self-Hosted Home Security Camera System A privacy-focused, intelligent security camera system. Features: Multi-camera support w/ minimal configuration

Scott Barnes 175 Jan 01, 2023
Yolov3 pytorch implementation

YOLOV3 Pytorch实现 在bubbliiing大佬代码的基础上进行了修改,添加了部分注释。 预训练模型 预训练模型来源于bubbliiing。 链接:https://pan.baidu.com/s/1ncREw6Na9ycZptdxiVMApw 提取码:appk 训练自己的数据集 按照VO

4 Aug 27, 2022
Background-Click Supervision for Temporal Action Localization

Background-Click Supervision for Temporal Action Localization This repository is the official implementation of BackTAL. In this work, we study the te

LeYang 221 Oct 09, 2022
Development Kit for the SoccerNet Challenge

SoccerNetv2-DevKit Welcome to the SoccerNet-V2 Development Kit for the SoccerNet Benchmark and Challenge. This kit is meant as a help to get started w

Silvio Giancola 117 Dec 30, 2022
PyTorch Implementation of PortaSpeech: Portable and High-Quality Generative Text-to-Speech

PortaSpeech - PyTorch Implementation PyTorch Implementation of PortaSpeech: Portable and High-Quality Generative Text-to-Speech. Model Size Module Nor

Keon Lee 279 Jan 04, 2023
Quantized tflite models for ailia TFLite Runtime

ailia-models-tflite Quantized tflite models for ailia TFLite Runtime About ailia TFLite Runtime ailia TF Lite Runtime is a TensorFlow Lite compatible

ax Inc. 13 Dec 23, 2022
Regression Metrics Calculation Made easy for tensorflow2 and scikit-learn

Regression Metrics Installation To install the package from the PyPi repository you can execute the following command: pip install regressionmetrics I

Ashish Patel 11 Dec 16, 2022
BlockUnexpectedPackets - Preventing BungeeCord CPU overload due to Layer 7 DDoS attacks by scanning BungeeCord's logs

BlockUnexpectedPackets This script automatically blocks DDoS attacks that are sp

SparklyPower 3 Mar 31, 2022
A JAX implementation of Broaden Your Views for Self-Supervised Video Learning, or BraVe for short.

BraVe This is a JAX implementation of Broaden Your Views for Self-Supervised Video Learning, or BraVe for short. The model provided in this package wa

DeepMind 44 Nov 20, 2022
RIFE: Real-Time Intermediate Flow Estimation for Video Frame Interpolation

RIFE RIFE: Real-Time Intermediate Flow Estimation for Video Frame Interpolation Ported from https://github.com/hzwer/arXiv2020-RIFE Dependencies NumPy

49 Jan 07, 2023
A tool for making map images from OpenTTD save games

OpenTTD Surveyor A tool for making map images from OpenTTD save games. This is not part of the main OpenTTD codebase, nor is it ever intended to be pa

Aidan Randle-Conde 9 Feb 15, 2022
TensorFlow port of PyTorch Image Models (timm) - image models with pretrained weights.

TensorFlow-Image-Models Introduction Usage Models Profiling License Introduction TensorfFlow-Image-Models (tfimm) is a collection of image models with

Martins Bruveris 227 Dec 20, 2022
Sequence to Sequence (seq2seq) Recurrent Neural Network (RNN) for Time Series Forecasting

Sequence to Sequence (seq2seq) Recurrent Neural Network (RNN) for Time Series Forecasting Note: You can find here the accompanying seq2seq RNN forecas

Guillaume Chevalier 1k Dec 25, 2022
wgan, wgan2(improved, gp), infogan, and dcgan implementation in lasagne, keras, pytorch

Generative Adversarial Notebooks Collection of my Generative Adversarial Network implementations Most codes are for python3, most notebooks works on C

tjwei 1.5k Dec 16, 2022
Unbalanced Feature Transport for Exemplar-based Image Translation (CVPR 2021)

UNITE and UNITE+ Unbalanced Feature Transport for Exemplar-based Image Translation (CVPR 2021) Unbalanced Intrinsic Feature Transport for Exemplar-bas

Fangneng Zhan 183 Nov 09, 2022
NHS AI Lab Skunkworks project: Long Stayer Risk Stratification

NHS AI Lab Skunkworks project: Long Stayer Risk Stratification A pilot project for the NHS AI Lab Skunkworks team, Long Stayer Risk Stratification use

NHSX 21 Nov 14, 2022
Educational 2D SLAM implementation based on ICP and Pose Graph

slam-playground Educational 2D SLAM implementation based on ICP and Pose Graph How to use: Use keyboard arrow keys to navigate robot. Press 'r' to vie

Kirill 19 Dec 17, 2022
Classifies galaxy morphology with Bayesian CNN

Zoobot Zoobot classifies galaxy morphology with deep learning. This code will let you: Reproduce and improve the Galaxy Zoo DECaLS automated classific

Mike Walmsley 39 Dec 20, 2022
Tensorflow 2.x implementation of Vision-Transformer model

Vision Transformer Unofficial Tensorflow 2.x implementation of the Transformer based Image Classification model proposed by the paper AN IMAGE IS WORT

Soumik Rakshit 16 Jul 20, 2022
Lightweight mmm - Lightweight (Bayesian) Media Mix Model

Lightweight (Bayesian) Media Mix Model This is not an official Google product. L

Google 342 Jan 03, 2023