PyTorch implementation of the Flow Gaussian Mixture Model (FlowGMM) from our paper


Flow Gaussian Mixture Model (FlowGMM)

This repository contains a PyTorch implementation of the Flow Gaussian Mixture Model (FlowGMM) from our paper

Semi-Supervised Learning with Normalizing Flows

by Pavel Izmailov, Polina Kirichenko, Marc Finzi and Andrew Gordon Wilson.

Introduction

Normalizing flows transform a latent distribution through an invertible neural network for a flexible and pleasingly simple approach to generative modelling, while preserving an exact likelihood. In this paper, we introduce FlowGMM (Flow Gaussian Mixture Model), an approach to semi-supervised learning with normalizing flows, by modelling the density in the latent space as a Gaussian mixture, with each mixture component corresponding to a class represented in the labelled data. FlowGMM is distinct in its simplicity, unified treatment of labelled and unlabelled data with an exact likelihood, interpretability, and broad applicability beyond image data.

We show promising results on a wide range of semi-supervised classification problems, including AG-News and Yahoo Answers text data, UCI tabular data, and image datasets (MNIST, CIFAR-10 and SVHN).
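To make the latent-space mixture concrete, the sketch below computes the three quantities the introduction describes: the labeled likelihood p(x, y), the unlabeled likelihood p(x) obtained by summing the mixture over classes, and the classifier argmax_y p(y | x). It is a minimal NumPy illustration, not the repository's PyTorch code. It assumes identity covariances, equal class weights, and a hypothetical elementwise affine map (`ToyAffineFlow`) in place of a trained invertible network.

```python
import numpy as np


def log_normal(z, mu):
    """log N(z; mu, I) for each row of z, shape [n, d] -> [n]."""
    d = z.shape[1]
    return -0.5 * np.sum((z - mu) ** 2, axis=1) - 0.5 * d * np.log(2 * np.pi)


def logsumexp(a):
    """Numerically stable log(sum(exp(a))) along axis 1."""
    m = a.max(axis=1, keepdims=True)
    return m[:, 0] + np.log(np.exp(a - m).sum(axis=1))


class ToyAffineFlow:
    """Hypothetical stand-in for the invertible network: z = scale * x + shift."""

    def __init__(self, scale, shift):
        self.scale = np.asarray(scale, dtype=float)
        self.shift = np.asarray(shift, dtype=float)

    def forward(self, x):
        z = self.scale * x + self.shift
        # An elementwise affine map has the same log|det J| for every input.
        log_det = np.full(len(x), np.sum(np.log(np.abs(self.scale))))
        return z, log_det


def class_log_likelihoods(flow, x, means):
    """log p(x | y=k) for every class k via change of variables, shape [n, K]."""
    z, log_det = flow.forward(x)
    per_class = np.stack([log_normal(z, mu) for mu in means], axis=1)
    return per_class + log_det[:, None]


def labeled_log_lik(flow, x, y, means):
    """log p(x, y) = log p(x | y) + log(1/K) with equal class weights."""
    ll = class_log_likelihoods(flow, x, means)
    return ll[np.arange(len(x)), y] - np.log(len(means))


def unlabeled_log_lik(flow, x, means):
    """log p(x) = log sum_k (1/K) p(x | y=k), the mixture marginal."""
    ll = class_log_likelihoods(flow, x, means)
    return logsumexp(ll - np.log(len(means)))


def predict(flow, x, means):
    """argmax_k p(y=k | x); the log-det term and uniform prior cancel."""
    return np.argmax(class_log_likelihoods(flow, x, means), axis=1)


flow = ToyAffineFlow(scale=[2.0, 2.0], shift=[0.0, 0.0])
means = np.array([[-2.0, 0.0], [2.0, 0.0]])  # one latent mean per class
x = np.array([[-1.0, 0.3], [0.9, -0.2]])
print(predict(flow, x, means))  # [0 1]
```

Because all classes share the same flow, the log-determinant is identical across classes: it matters for the likelihoods used in training but not for the predicted label.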


Please cite our work if you find it useful:

@article{izmailov2019semi,
  title={Semi-Supervised Learning with Normalizing Flows},
  author={Izmailov, Pavel and Kirichenko, Polina and Finzi, Marc and Wilson, Andrew Gordon},
  journal={arXiv preprint arXiv:1912.13025},
  year={2019}
}

Installation

To run the scripts you will need to clone the repo and install it locally. You can use the commands below.

git clone https://github.com/izmailovpavel/flowgmm.git
cd flowgmm
pip install -e .

Dependencies

FlowGMM has the following dependencies, which must be installed before installing FlowGMM itself:

We provide the scripts and example commands to reproduce the experiments from the paper.

Synthetic Datasets

The experiments on synthetic data are implemented in this IPython notebook. We additionally provide another IPython notebook that applies FlowGMM to labeled data only.

Tabular Datasets

The tabular datasets will be downloaded and preprocessed automatically the first time they are needed. The commands below reproduce the results in the following table (classification accuracy, %):

Method     AGNEWS  YAHOO  HEPMASS  MINIBOONE
MLP        77.5    55.7   82.2     80.4
Pi Model   80.2    56.3   87.9     80.8
FlowGMM    82.1    57.9   88.5     81.9

Text Classification (Updated)

Train FlowGMM on AG-News (200 labeled examples):

python experiments/train_flows/flowgmm_tabular_new.py --trainer_config "{'unlab_weight':.6}" --net_config "{'k':1024,'coupling_layers':7,'nperlayer':1}" --network RealNVPTabularWPrior --trainer SemiFlow --num_epochs 100 --dataset AG_News --lr 3e-4 --train 200

Train FlowGMM on YAHOO Answers (800 labeled examples):

python experiments/train_flows/flowgmm_tabular_new.py --trainer_config "{'unlab_weight':.2}" --net_config "{'k':1024,'coupling_layers':7,'nperlayer':1}" --network RealNVPTabularWPrior --trainer SemiFlow --num_epochs 200 --dataset YAHOO --lr 3e-4 --train 800
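The --trainer_config and --net_config arguments are written as Python dict literals. The sketch below parses such a string with the standard library's `ast.literal_eval` (we do not claim this is how the script parses it) and shows one plausible reading of `unlab_weight`: a multiplier on the unlabeled negative log-likelihood that is added to the labeled one. `semi_supervised_nll` is a hypothetical helper; the SemiFlow trainer may normalize or combine the terms differently.

```python
import ast

import numpy as np


def parse_config(text):
    """Safely turn a dict-literal string such as "{'unlab_weight':.6}" into a dict."""
    config = ast.literal_eval(text)
    if not isinstance(config, dict):
        raise ValueError(f"expected a dict literal, got {type(config).__name__}")
    return config


def semi_supervised_nll(labeled_log_lik, unlabeled_log_lik, unlab_weight):
    """Hypothetical objective: labeled NLL plus unlab_weight times unlabeled NLL."""
    labeled_term = -np.mean(labeled_log_lik)      # uses log p(x, y)
    unlabeled_term = -np.mean(unlabeled_log_lik)  # uses log p(x)
    return labeled_term + unlab_weight * unlabeled_term


config = parse_config("{'unlab_weight':.6}")
print(config["unlab_weight"])  # 0.6
loss = semi_supervised_nll(np.array([-1.0, -3.0]), np.array([-2.0]), config["unlab_weight"])
print(loss)
```

Using `ast.literal_eval` rather than `eval` means only literals are accepted, so a malformed or malicious config string raises an error instead of executing code.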

UCI Data

Train FlowGMM on MINIBOONE (20 labeled examples):

python experiments/train_flows/flowgmm_tabular_new.py --trainer_config "{'unlab_weight':3.}" \
 --net_config "{'k':256,'coupling_layers':10,'nperlayer':1}" --network RealNVPTabularWPrior \
 --trainer SemiFlow --num_epochs 300 --dataset MINIBOONE --lr 3e-4

Train FlowGMM on HEPMASS (20 labeled examples):

python experiments/train_flows/flowgmm_tabular_new.py --trainer_config "{'unlab_weight':10}" \
 --net_config "{'k':256,'coupling_layers':10,'nperlayer':1}" \
 --network RealNVPTabularWPrior --trainer SemiFlow --num_epochs 15 --dataset HEPMASS
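The --net_config values size the RealNVP-style network (RealNVPTabularWPrior). Our reading, which we have not checked against the code, is that `coupling_layers` sets the number of coupling layers and `k` the hidden width of each layer's conditioner. As background, the sketch below implements one generic affine coupling layer in NumPy, with a tiny random MLP standing in for the learned conditioner. It shows why such a layer is exactly invertible and why its log-determinant is cheap to compute.

```python
import numpy as np


class AffineCoupling:
    """One generic affine coupling layer acting on [n, dim] inputs.

    The first half of the features passes through unchanged and conditions a
    scale and shift for the second half. A tiny random MLP with hidden width k
    stands in for the learned conditioner network.
    """

    def __init__(self, dim, k=16, seed=0):
        rng = np.random.default_rng(seed)
        self.half = dim // 2
        self.w1 = rng.normal(scale=0.5, size=(self.half, k))
        self.w2 = rng.normal(scale=0.5, size=(k, 2 * (dim - self.half)))

    def _scale_shift(self, x1):
        out = np.tanh(x1 @ self.w1) @ self.w2
        log_s, t = np.split(out, 2, axis=1)
        return np.tanh(log_s), t  # tanh keeps the log-scale bounded

    def forward(self, x):
        x1, x2 = x[:, : self.half], x[:, self.half :]
        log_s, t = self._scale_shift(x1)
        z2 = x2 * np.exp(log_s) + t
        # The Jacobian is triangular, so log|det J| is the sum of the log-scales.
        log_det = log_s.sum(axis=1)
        return np.concatenate([x1, z2], axis=1), log_det

    def inverse(self, z):
        z1, z2 = z[:, : self.half], z[:, self.half :]
        log_s, t = self._scale_shift(z1)  # z1 == x1, so scale/shift are recovered
        x2 = (z2 - t) * np.exp(-log_s)
        return np.concatenate([z1, x2], axis=1)


rng = np.random.default_rng(1)
x = rng.normal(size=(5, 4))
layer = AffineCoupling(dim=4, k=8)
z, log_det = layer.forward(x)
print(np.allclose(layer.inverse(z), x))  # True
```

Stacking several such layers, and alternating which half of the features is transformed, gives a flexible flow; this sketch always transforms the same half.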

Note that on the low-dimensional tabular data, FlowGMM models are quite sensitive to initialization. You may want to run the script a few times in case the model does not recover from a bad initialization.
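If you prefer to automate these reruns from Python rather than relaunching the command by hand, a small driver like the one below keeps the best of several initializations. `train_once` is a hypothetical wrapper you would write around your own training code (the script above is command-line only, and we do not assume it exposes a seed option). Selecting on a validation score rather than test accuracy avoids optimistic results.

```python
def best_of_restarts(train_once, seeds):
    """Train once per seed and keep the run with the highest score.

    train_once is a hypothetical callable: seed -> (model, score), where the
    score should come from held-out validation data, never the test set.
    """
    best_seed, best_model, best_score = None, None, float("-inf")
    for seed in seeds:
        model, score = train_once(seed)
        if score > best_score:
            best_seed, best_model, best_score = seed, model, score
    return best_seed, best_model, best_score


# Toy usage: pretend seed 1 escaped a bad initialization.
fake_scores = {0: 0.52, 1: 0.81, 2: 0.49}
print(best_of_restarts(lambda s: (f"model-{s}", fake_scores[s]), [0, 1, 2]))
# (1, 'model-1', 0.81)
```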

The training script for the UCI datasets automatically downloads the relevant MINIBOONE or HEPMASS data and unpacks it into ~/datasets/UCI/; for reference, the data come from here and here. Where sensible, we follow the preprocessing from Masked Autoregressive Flow for Density Estimation.
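Per-feature standardization is one step of that kind of preprocessing. The sketch below shows only that step, under the assumption that statistics are computed on the training split and reused for the other splits; it is not the repository's data loader, and the full MAF pipeline may include further dataset-specific steps.

```python
import numpy as np


def standardize(train, *other_splits):
    """Z-score each split using per-feature statistics of the training split only."""
    mean = train.mean(axis=0)
    std = train.std(axis=0)
    std = np.where(std > 0, std, 1.0)  # leave constant features unscaled
    return [(split - mean) / std for split in (train, *other_splits)]


rng = np.random.default_rng(0)
train = rng.normal(loc=5.0, scale=3.0, size=(1000, 4))
test = rng.normal(loc=5.0, scale=3.0, size=(200, 4))
train_z, test_z = standardize(train, test)
print(train_z.mean(axis=0).round(6), train_z.std(axis=0).round(6))
```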

Baselines

Training the 3-layer NN + Dropout baseline on:

YAHOO Answers: python experiments/train_flows/flowgmm_tabular_new.py --lr=1e-3 --dataset YAHOO --num_epochs 1000 --train 800

AG-NEWS: python experiments/train_flows/flowgmm_tabular_new.py --lr 1e-4 --dataset AG_News --num_epochs 1000 --train 200

MINIBOONE: python experiments/train_flows/flowgmm_tabular_new.py --lr 1e-4 --dataset MINIBOONE --num_epochs 500

HEPMASS: python experiments/train_flows/flowgmm_tabular_new.py --lr 1e-4 --dataset HEPMASS --num_epochs 500

Training the Pi Model baseline on:

YAHOO Answers: python experiments/train_flows/flowgmm_tabular_new.py --lr=1e-3 --dataset YAHOO --num_epochs 300 --train 800 --trainer PiModel --trainer_config "{'cons_weight':.3}"

AG-NEWS: python experiments/train_flows/flowgmm_tabular_new.py --lr 1e-3 --dataset AG_News --num_epochs 100 --train 200 --trainer PiModel --trainer_config "{'cons_weight':30}"

MINIBOONE: python experiments/train_flows/flowgmm_tabular_new.py --lr 3e-4 --dataset MINIBOONE --trainer PiModel --trainer_config "{'cons_weight':30}" --num_epochs 10

HEPMASS: python experiments/train_flows/flowgmm_tabular_new.py --trainer PiModel --num_epochs 10 --dataset HEPMASS --trainer_config "{'cons_weight':3}" --lr 1e-4
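The `cons_weight` entry sets the weight of the Pi Model's consistency term. As a rough illustration of that term in a standard Pi Model, the NumPy sketch below combines cross-entropy on labeled examples with the squared difference between class probabilities from two stochastic passes over the same batch. It assumes the consistency is computed on softmax outputs and averaged over classes and examples; the repository's PiModel trainer may use a different noise model, normalization, or ramp-up.

```python
import numpy as np


def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def pi_model_loss(logits_a, logits_b, labels, labeled_mask, cons_weight):
    """Pi Model objective for one batch.

    logits_a, logits_b: [n, C] outputs of two stochastic passes (different
    dropout masks or input noise) over the same inputs.
    labels: [n] integer class ids, ignored where labeled_mask is False.
    """
    probs_a, probs_b = softmax(logits_a), softmax(logits_b)
    idx = np.flatnonzero(labeled_mask)
    # Cross-entropy on the labeled subset only.
    ce = -np.mean(np.log(probs_a[idx, labels[idx]])) if idx.size else 0.0
    # Consistency: mean squared difference between the passes, on every example.
    consistency = np.mean((probs_a - probs_b) ** 2)
    return ce + cons_weight * consistency


logits_a = np.array([[2.0, 0.0], [0.5, 0.5]])
logits_b = np.array([[1.5, 0.2], [0.0, 1.0]])
print(pi_model_loss(logits_a, logits_b, np.array([0, 0]), np.array([True, False]), cons_weight=30.0))
```

Because the consistency term needs no labels, it is computed on every example, which is how the unlabeled data enters the objective.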

The notebook here can be used to run the kNN, Logistic Regression, and Label Spreading baselines once the data has been downloaded, either by the scripts above or manually.
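For readers unfamiliar with the Label Spreading baseline, the sketch below implements the closed-form label spreading of Zhou et al. (Learning with Local and Global Consistency): build an RBF affinity graph, normalize it as S = D^{-1/2} W D^{-1/2}, and solve (I - alpha S) F = Y, where Y one-hot encodes the labeled points. The (1 - alpha) scaling of the closed form is dropped because it does not change the argmax. The parameter names and defaults are illustrative assumptions, not the notebook's settings.

```python
import numpy as np


def label_spreading(X, y, alpha=0.9, gamma=1.0):
    """Predict a class for every row of X.

    y holds class ids for labeled points and -1 for unlabeled points.
    """
    sq_dists = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    W = np.exp(-gamma * sq_dists)  # RBF affinities
    np.fill_diagonal(W, 0.0)       # no self-loops
    d = W.sum(axis=1)
    S = W / np.sqrt(np.outer(d, d))  # D^{-1/2} W D^{-1/2}
    classes = np.unique(y[y >= 0])
    Y = (y[:, None] == classes[None, :]).astype(float)  # all-zero rows when unlabeled
    F = np.linalg.solve(np.eye(len(X)) - alpha * S, Y)
    return classes[np.argmax(F, axis=1)]


# Two well-separated clusters with a single labeled point each.
X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
              [10.0, 0.0], [10.1, 0.0], [10.0, 0.1]])
y = np.array([0, -1, -1, 1, -1, -1])
print(label_spreading(X, y))  # [0 0 0 1 1 1]
```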

Image Classification

To run experiments with FlowGMM on image classification problems you first need to download and prepare the data. To do so, run the following scripts:

./data/bin/prepare_cifar10.sh
./data/bin/prepare_mnist.sh
./data/bin/prepare_svhn.sh

To run FlowGMM, use the following command:

python3 experiments/train_flows/train_semisup_cons.py \
  --dataset=<DATASET> \
  --data_path=<DATAPATH> \
  --label_path=<LABELPATH> \
  --logdir=<LOGDIR> \
  --ckptdir=<CKPTDIR> \
  --save_freq=<SAVEFREQ> \
  --num_epochs=<EPOCHS> \
  --label_weight=<LABELWEIGHT> \
  --consistency_weight=<CONSISTENCYWEIGHT> \
  --consistency_rampup=<CONSISTENCYRAMPUP> \
  --lr=<LR> \
  --eval_freq=<EVALFREQ>

Parameters:

  • DATASET — dataset name [MNIST/CIFAR10/SVHN]
  • DATAPATH — path to the directory containing data; if you used the data preparation scripts, you can use e.g. data/images/mnist as DATAPATH
  • LABELPATH — path to the label split generated by the data preparation scripts; this can be e.g. data/labels/mnist/1000_balanced_labels/10.npz or data/labels/cifar10/1000_balanced_labels/10.txt.
  • LOGDIR — directory where tensorboard logs will be stored
  • CKPTDIR — directory where checkpoints will be stored
  • SAVEFREQ — frequency of saving checkpoints in epochs
  • EPOCHS — number of training epochs (passes through labeled data)
  • LABELWEIGHT — weight of cross-entropy loss term (default: 1.)
  • CONSISTENCYWEIGHT — weight of consistency loss term (default: 1.)
  • CONSISTENCYRAMPUP — length of the consistency ramp-up period in epochs (default: 1); the consistency weight increases linearly from 0 to CONSISTENCYWEIGHT over the first CONSISTENCYRAMPUP epochs of training
  • LR — learning rate (default: 1e-3)
  • EVALFREQ — number of epochs between evaluations (default: 1)
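The linear ramp-up described for CONSISTENCYRAMPUP can be written as a one-line schedule. This is a sketch of the formula as stated in the parameter list, assuming epochs are counted from 0; the training script may index epochs or update the weight at a different granularity.

```python
def consistency_weight(epoch, max_weight=1.0, rampup=1):
    """Weight that grows linearly from 0 to max_weight over the first `rampup` epochs."""
    if rampup <= 0:
        return max_weight
    return max_weight * min(1.0, epoch / rampup)


# MNIST example settings: --consistency_weight=1. --consistency_rampup=1000
for epoch in (0, 250, 1000, 5000):
    print(epoch, consistency_weight(epoch, max_weight=1.0, rampup=1000))
# 0 0.0, 250 0.25, 1000 1.0, 5000 1.0
```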

Examples:

# MNIST, 100 labeled datapoints
python3 experiments/train_flows/train_semisup_cons.py --dataset=MNIST --data_path=data/images/mnist/ \
  --label_path=data/labels/mnist/100_balanced_labels/10.npz --logdir=<LOGDIR> --ckptdir=<CKPTDIR> \
  --save_freq=5000 --num_epochs=30001 --label_weight=3 --consistency_weight=1. --consistency_rampup=1000 \
  --lr=1e-5 --eval_freq=100

# CIFAR-10, 4000 labeled datapoints
python3 experiments/train_flows/train_semisup_cons.py --dataset=CIFAR10 --data_path=data/images/cifar/cifar10/by-image/ \
  --label_path=data/labels/cifar10/4000_balanced_labels/10.txt --logdir=<LOGDIR> --ckptdir=<CKPTDIR> \
  --save_freq=500 --num_epochs=1501 --label_weight=3 --consistency_weight=1. --consistency_rampup=100 \
  --lr=1e-4 --eval_freq=50
