Very deep VAEs in JAX/Flax

Overview

Very Deep VAEs in JAX/Flax

Implementation of the experiments in the paper Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images using JAX and Flax, ported from the official OpenAI PyTorch implementation.

I have tried to keep this implementation as close as possible to the original. I was able to re-use a large proportion of the code, including the data input pipeline, which still uses PyTorch. I recommend installing a CPU-only version of PyTorch for this.

Tested with JAX 0.2.10, Flax 0.3.0, PyTorch 1.7.1, NumPy 1.19.2. I also ran training to convergence on cifar10 and reproduced the test ELBO value of 2.87 from the paper, using --conv_precision=highest, see below. If anyone asks for trained checkpoints for cifar I will be happy to upload them.

From the paper, some model samples and a visualization of how it generates them:

image

Setup

As well as JAX, Flax, NumPy and PyTorch, this implementation depends on Pillow and scikit-learn:

pip install pillow
pip install sklearn

Also, you'll have to download the data, depending on which one you want to run:

./setup_cifar10.sh
./setup_imagenet.sh imagenet32
./setup_imagenet.sh imagenet64
./setup_ffhq256.sh
./setup_ffhq1024.sh  /path/to/images1024x1024  # this one depends on you first downloading the subfolder `images_1024x1024` from https://github.com/NVlabs/ffhq-dataset on your own & running `pip install torchvision`

Training models

Hyperparameters all reside in hps.py.

python train.py --hps cifar10
python train.py --hps imagenet32
python train.py --hps imagenet64
python train.py --hps ffhq256
python train.py --hps ffhq1024

TODOs

  • Implement support for 5 bit images which was used in the paper's FFHQ-256 experiments.

Known differences from the orignal

  • Instead of using the PyTorch default layer initializers we use the Flax defaults.
  • Renamed rate/distortion to kl/loglikelihood.
  • In multihost configurations, checkpoints are saved to disk on all hosts.
  • Slight changes to DMOL loss.

Things to watch out for

We tried to keep this implementation as close as possible to the author's original Pytorch implementation. There are two potentially confusing things which we chose to preserve. Firstly, the --n_batch command line argument specifies the per device batch size; on configurations with multiple GPUs/TPUs and multiple hosts this needs to be taken into account when comparing runs on different configurations. Secondly, some of the default hyperparameter settings in hps.py do not match the settings used for the paper's experiments, which are specified on page 15 of the paper.

In order to reproduce results from the paper on TPU, it may be necessary to set --conv_precision=highest, which simulates GPU-like float32 precision on the TPU. Note that this can result in slower runtime. In my experiments on cifar10 I've found that this setting has about a 1% effect on the final ELBO value and was necessary to reproduce the value 2.87 reported in the paper.

Acknowledgements

This code is very closely based on Rewon Child's implementation, thanks to him for writing that. Thanks to Julius Kunze for tidying the code and fixing some bugs.

Owner
Jamie Townsend
Jamie Townsend
Dilated Convolution with Learnable Spacings PyTorch

Dilated-Convolution-with-Learnable-Spacings-PyTorch Ismail Khalfaoui Hassani Dilated Convolution with Learnable Spacings (abbreviated to DCLS) is a no

15 Dec 09, 2022
Crowd-Kit is a powerful Python library that implements commonly-used aggregation methods for crowdsourced annotation and offers the relevant metrics and datasets

Crowd-Kit: Computational Quality Control for Crowdsourcing Documentation Crowd-Kit is a powerful Python library that implements commonly-used aggregat

Toloka 125 Dec 30, 2022
Video Frame Interpolation without Temporal Priors (a general method for blurry video interpolation)

Video Frame Interpolation without Temporal Priors (NeurIPS2020) [Paper] [video] How to run Prerequisites NVIDIA GPU + CUDA 9.0 + CuDNN 7.6.5 Pytorch 1

YoujianZhang 31 Sep 04, 2022
Deep Dual Consecutive Network for Human Pose Estimation (CVPR2021)

Beanie - is an asynchronous ODM for MongoDB, based on Motor and Pydantic. It uses an abstraction over Pydantic models and Motor collections to work wi

295 Dec 29, 2022
CROSS-LINGUAL ABILITY OF MULTILINGUAL BERT: AN EMPIRICAL STUDY

M-BERT-Study CROSS-LINGUAL ABILITY OF MULTILINGUAL BERT: AN EMPIRICAL STUDY Motivation Multilingual BERT (M-BERT) has shown surprising cross lingual a

CogComp 1 Feb 28, 2022
Semi-supervised Domain Adaptation via Minimax Entropy

Semi-supervised Domain Adaptation via Minimax Entropy (ICCV 2019) Install pip install -r requirements.txt The code is written for Pytorch 0.4.0, but s

Vision and Learning Group 243 Jan 09, 2023
A simple Rock-Paper-Scissors game using CV in python

ML18_Rock-Paper-Scissors-using-CV A simple Rock-Paper-Scissors game using CV in python For IITISOC-21 Rules and procedure to play the interactive game

Anirudha Bhagwat 3 Aug 08, 2021
Extracting and filtering paraphrases by bridging natural language inference and paraphrasing

nli2paraphrases Source code repository accompanying the preprint Extracting and filtering paraphrases by bridging natural language inference and parap

Matej Klemen 1 Mar 09, 2022
AbelNN: Deep Learning Python module from scratch

AbelNN: Deep Learning Python module from scratch I have implemented several neural networks from scratch using only Numpy. I have designed the module

Abel 2 Apr 12, 2022
NUANCED is a user-centric conversational recommendation dataset that contains 5.1k annotated dialogues and 26k high-quality user turns.

NUANCED: Natural Utterance Annotation for Nuanced Conversation with Estimated Distributions Overview NUANCED is a user-centric conversational recommen

Facebook Research 18 Dec 28, 2021
Consensus score for tripadvisor

ContripScore ContripScore is essentially a score that combines an Internet platform rating and a consensus rating from sentiment analysis (For instanc

Pepe 1 Jan 13, 2022
A Simple Framwork for CV Pre-training Model (SOCO, VirTex, BEiT)

A Simple Framwork for CV Pre-training Model (SOCO, VirTex, BEiT)

Sense-GVT 14 Jul 07, 2022
TICC is a python solver for efficiently segmenting and clustering a multivariate time series

TICC TICC is a python solver for efficiently segmenting and clustering a multivariate time series. It takes as input a T-by-n data matrix, a regulariz

406 Dec 12, 2022
[NeurIPS2021] Exploring Architectural Ingredients of Adversarially Robust Deep Neural Networks

Exploring Architectural Ingredients of Adversarially Robust Deep Neural Networks Code for NeurIPS 2021 Paper "Exploring Architectural Ingredients of A

Hanxun Huang 26 Dec 01, 2022
Learning Efficient Online 3D Bin Packing on Packing Configuration Trees

Learning Efficient Online 3D Bin Packing on Packing Configuration Trees This repository is being continuously updated, please stay tuned! Any code con

86 Dec 28, 2022
Hyperparameter Optimization for TensorFlow, Keras and PyTorch

Hyperparameter Optimization for Keras Talos • Key Features • Examples • Install • Support • Docs • Issues • License • Download Talos radically changes

Autonomio 1.6k Dec 15, 2022
Code for "3D Human Pose and Shape Regression with Pyramidal Mesh Alignment Feedback Loop"

PyMAF This repository contains the code for the following paper: 3D Human Pose and Shape Regression with Pyramidal Mesh Alignment Feedback Loop Hongwe

Hongwen Zhang 450 Dec 28, 2022
The ICS Chat System project for NYU Shanghai Fall 2021

ICS_Chat_System [Catenger] This is the ICS Chat System project for NYU Shanghai Fall 2021 Creators: Shavarsh Melikyan, Skyler Chen and Arghya Sarkar,

1 Dec 20, 2021
Companion code for the paper Theoretical characterization of uncertainty in high-dimensional linear classification

Companion code for the paper Theoretical characterization of uncertainty in high-dimensional linear classification Usage The required packages are lis

0 Feb 07, 2022
PyTorch implementation of Higher Order Recurrent Space-Time Transformer

Higher Order Recurrent Space-Time Transformer (HORST) This is the official PyTorch implementation of Higher Order Recurrent Space-Time Transformer. Th

13 Oct 18, 2022