The Official PyTorch Implementation of "VAEBM: A Symbiosis between Variational Autoencoders and Energy-based Models" (ICLR 2021 spotlight paper)

Related tags

Deep LearningVAEBM
Overview

Official PyTorch implementation of "VAEBM: A Symbiosis between Variational Autoencoders and Energy-based Models" (ICLR 2021 Spotlight Paper)

Zhisheng Xiao·Karsten Kreis·Jan Kautz·Arash Vahdat


VAEBM trains an energy network to refine the data distribution learned by an NVAE, where the enery network and the VAE jointly define an Energy-based model. The NVAE is pretrained before training the energy network, and please refer to NVAE's implementation for more details about constructing and training NVAE.

Set up datasets

We trained on several datasets, including CIFAR10, CelebA64, LSUN Church 64 and CelebA HQ 256. For large datasets, we store the data in LMDB datasets for I/O efficiency. Check here for information regarding dataset preparation.

Training NVAE

We use the following commands on each dataset for training the NVAE backbone. To train NVAEs, please use its original codebase with commands given here.

CIFAR-10 (8x 16-GB GPUs)

python train.py --data $DATA_DIR/cifar10 --root $CHECKPOINT_DIR --save $EXPR_ID --dataset cifar10 \
      --num_channels_enc 128 --num_channels_dec 128 --epochs 400 --num_postprocess_cells 2 --num_preprocess_cells 2 \
      --num_latent_scales 1 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
      --num_preprocess_blocks 1 --num_postprocess_blocks 1 --num_groups_per_scale 30 --batch_size 32 \
      --weight_decay_norm 1e-1 --num_nf 1 --num_mixture_dec 1 --fast_adamax  --arch_instance res_mbconv \
      --num_process_per_node 8 --use_se --res_dist

CelebA-64 (8x 16-GB GPUs)

python train.py --data  $DATA_DIR/celeba64_lmdb --root $CHECKPOINT_DIR --save $EXPR_ID --dataset celeba_64 \
      --num_channels_enc 48 --num_channels_dec 48 --epochs 50 --num_postprocess_cells 2 --num_preprocess_cells 2 \
      --num_latent_scales 3 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
      --num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-1 --num_groups_per_scale 5 \
      --batch_size 32 --num_nf 1 --num_mixture_dec 1 --fast_adamax  --warmup_epochs 1 --arch_instance res_mbconv \
      --num_process_per_node 8 --use_se --res_dist

CelebA-HQ-256 (8x 32-GB GPUs)

python train.py -data  $DATA_DIR/celeba/celeba-lmdb --root $CHECKPOINT_DIR --save $EXPR_ID --dataset celeba_256 \
      --num_channels_enc 32 --num_channels_dec 32 --epochs 200 --num_postprocess_cells 2 --num_preprocess_cells 2 \
      --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 --num_preprocess_blocks 1 \
      --num_postprocess_blocks 1 --weight_decay_norm 1e-2 --num_x_bits 5 --num_latent_scales 5 --num_groups_per_scale 4 \
      --num_nf 2 --batch_size 8 --fast_adamax  --num_mixture_dec 1 \
      --weight_decay_norm_anneal  --weight_decay_norm_init 1e1 --learning_rate 6e-3 --arch_instance res_mbconv \
      --num_process_per_node 8 --use_se --res_dist

LSUN Churches Outdoor 64 (8x 16-GB GPUs)

python train.py --data $DATA_DIR/LSUN/ --root $CHECKPOINT_DIR --save $EXPR_ID --dataset lsun_church_64 \
      --num_channels_enc 48 --num_channels_dec 48 --epochs 60 --num_postprocess_cells 2 --num_preprocess_cells 2 \
      --num_latent_scales 3 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
      --num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-1 --num_groups_per_scale 5 \
      --batch_size 32 --num_nf 1 --num_mixture_dec 1 --fast_adamax  --warmup_epochs 1 --arch_instance res_mbconv \
      --num_process_per_node 8 --use_se --res_dist

Training VAEBM

We use the following commands on each dataset for training VAEBM. Note that you need to train the NVAE on corresponding dataset before running the training command here. After training the NVAE, pass the path of the checkpoint to the --checkpoint argument.

Note that the training of VAEBM will eventually explode (See Appendix E of our paper), and therefore it is important to save checkpoint regularly. After the training explodes, stop running the code and use the last few saved checkpoints for testing.

CIFAR-10

We train VAEBM on CIFAR-10 using one 32-GB V100 GPU.

python train_VAEBM.py  --checkpoint ./checkpoints/cifar10/checkpoint.pt --experiment cifar10_exp1
--dataset cifar10 --im_size 32 --data ./data/cifar10 --num_steps 10 
--wd 3e-5 --step_size 8e-5 --total_iter 30000 --alpha_s 0.2 --lr 4e-5 --max_p 0.6 
--anneal_step 5000. --batch_size 32 --n_channel 128

CelebA 64

We train VAEBM on CelebA 64 using one 32-GB V100 GPU.

python train_VAEBM.py --checkpoint ./checkpoints/celeba_64/checkpoint.pt --experiment celeba64_exp1 --dataset celeba_64 
--im_size 64 --lr 5e-5 --batch_size 32 --n_channel 64 --num_steps 10 --use_mu_cd --wd 3e-5 --step_size 5e-6 --total_iter 30000 
--alpha_s 0.2 

LSUN Church 64

We train VAEBM on LSUN Church 64 using one 32-GB V100 GPU.

python train_VAEBM.py --checkpoint ./checkpoints/lsun_church/checkpoint.pt --experiment lsunchurch_exp1 --dataset lsun_church 
--im_size 64 --batch_size 32 --n_channel 64 --num_steps 10 --use_mu_cd --wd 3e-5 --step_size 4e-6 --total_iter 30000 --alpha_s 0.2 --lr 4e-5 
--use_buffer --max_p 0.6 --anneal_step 5000

CelebA HQ 256

We train VAEBM on CelebA HQ 256 using four 32-GB V100 GPUs.

python train_VAEBM_distributed.py --checkpoint ./checkpoints/celeba_256/checkpoint.pt --experiment celeba256_exp1 --dataset celeba_256
--num_process_per_node 4 --im_size 256 --batch_size 4 --n_channel 64 --num_steps 6 --use_mu_cd --wd 3e-5 --step_size 3e-6 
--total_iter 9000 --alpha_s 0.3 --lr 4e-5 --use_buffer --max_p 0.6 --anneal_step 3000 --buffer_size 2000

Sampling from VAEBM

To generate samples from VAEBM after training, run sample_VAEBM.py, and it will generate 50000 test images in your given path. When sampling, we typically use longer Langvin dynamics than training for better sample quality, see Appendix E of the paper for the step sizes and number of steps we use to obtain test samples for each dataset. Other parameters that ensure successfully loading the VAE and energy network are the same as in the training codes.

For example, the script used to sample CIFAR-10 is

python sample_VAEBM.py --checkpoint ./checkpoints/cifar_10/checkpoint.pt --ebm_checkpoint ./saved_models/cifar_10/cifar_exp1/EBM.pth 
--dataset cifar10 --im_size 32 --batch_size 40 --n_channel 128 --num_steps 16 --step_size 8e-5 

For CelebA 64,

python sample_VAEBM.py --checkpoint ./checkpoints/celeba_64/checkpoint.pt --ebm_checkpoint ./saved_models/celeba_64/celeba64_exp1/EBM.pth 
--dataset celeba_64 --im_size 64 --batch_size 40 --n_channel 64 --num_steps 20 --step_size 5e-6 

For LSUN Church 64,

python sample_VAEBM.py --checkpoint ./checkpoints/lsun_church/checkpoint.pt --ebm_checkpoint ./saved_models/lsun_chruch/lsunchurch_exp1/EBM.pth 
--dataset lsun_church --im_size 64 --batch_size 40 --n_channel 64 --num_steps 20 --step_size 4e-6 

For CelebA HQ 256,

python sample_VAEBM.py --checkpoint ./checkpoints/celeba_256/checkpoint.pt --ebm_checkpoint ./saved_models/celeba_256/celeba256_exp1/EBM.pth 
--dataset celeba_256 --im_size 256 --batch_size 10 --n_channel 64 --num_steps 24 --step_size 3e-6 

Evaluation

After sampling, use the Tensorflow or PyTorch implementation to compute the FID scores. For example, when using the Tensorflow implementation, you can obtain the FID score by saving the training images in /path/to/training_images and running the script:

python fid.py /path/to/training_images /path/to/sampled_images

For CIFAR-10, the training statistics can be downloaded from here, and the FID score can be computed by running

python fid.py /path/to/sampled_images /path/to/precalculated_stats.npz

For the Inception Score, save samples in a single numpy array with pixel values in range [0, 255] and simply run

python ./thirdparty/inception_score.py --sample_dir /path/to/sampled_images

where the code for computing Inception Score is adapted from here.

License

Please check the LICENSE file. VAEBM may be used non-commercially, meaning for research or evaluation purposes only. For business inquiries, please contact [email protected].

Bibtex

Cite our paper using the following bibtex item:

@inproceedings{
xiao2021vaebm,
title={VAEBM: A Symbiosis between Variational Autoencoders and Energy-based Models},
author={Zhisheng Xiao and Karsten Kreis and Jan Kautz and Arash Vahdat},
booktitle={International Conference on Learning Representations},
year={2021}
}
Turning SymPy expressions into PyTorch modules.

sympytorch A micro-library as a convenience for turning SymPy expressions into PyTorch Modules. All SymPy floats become trainable parameters. All SymP

Patrick Kidger 89 Dec 13, 2022
AMTML-KD: Adaptive Multi-teacher Multi-level Knowledge Distillation

AMTML-KD: Adaptive Multi-teacher Multi-level Knowledge Distillation

Frank Liu 26 Oct 13, 2022
How the Deep Q-learning method works and discuss the new ideas that makes the algorithm work

Deep Q-Learning Recommend papers The first step is to read and understand the method that you will implement. It was first introduced in a 2013 paper

1 Jan 25, 2022
PyTorch implementation of "Optimization Planning for 3D ConvNets"

Optimization-Planning-for-3D-ConvNets Code for the ICML 2021 paper: Optimization Planning for 3D ConvNets. Authors: Zhaofan Qiu, Ting Yao, Chong-Wah N

Zhaofan Qiu 2 Jan 12, 2022
Mapping Conditional Distributions for Domain Adaptation Under Generalized Target Shift

This repository contains the official code of OSTAR in "Mapping Conditional Distributions for Domain Adaptation Under Generalized Target Shift" (ICLR 2022).

Matthieu Kirchmeyer 5 Dec 06, 2022
The materials used in the SaxonJS tutorial presented at Declarative Amsterdam, 2021

SaxonJS-Tutorial-2021, version 1.0.4 Last updated on 4 November, 2021. Table of contents Background Prerequisites Starting a web server Running a Java

Saxonica 11 Oct 23, 2022
ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator

ONNX Runtime is a cross-platform inference and training machine-learning accelerator. ONNX Runtime inference can enable faster customer experiences an

Microsoft 8k Jan 04, 2023
Auto-Lama combines object detection and image inpainting to automate object removals

Auto-Lama Auto-Lama combines object detection and image inpainting to automate object removals. It is build on top of DE:TR from Facebook Research and

44 Dec 09, 2022
A simple code to perform canny edge contrast detection on images.

CECED-Canny-Edge-Contrast-Enhanced-Detection A simple code to perform canny edge contrast detection on images. A simple code to process images using c

Happy N. Monday 3 Feb 15, 2022
Python Library for learning (Structure and Parameter) and inference (Statistical and Causal) in Bayesian Networks.

pgmpy pgmpy is a python library for working with Probabilistic Graphical Models. Documentation and list of algorithms supported is at our official sit

pgmpy 2.2k Jan 03, 2023
A TensorFlow implementation of the Mnemonic Descent Method.

MDM A Tensorflow implementation of the Mnemonic Descent Method. Mnemonic Descent Method: A recurrent process applied for end-to-end face alignment G.

123 Oct 07, 2022
Implementation of "The Power of Scale for Parameter-Efficient Prompt Tuning"

Prompt-Tuning Implementation of "The Power of Scale for Parameter-Efficient Prompt Tuning" Currently, we support the following huggigface models: Bart

Andrew Zeng 36 Dec 19, 2022
Aiming at the common training datsets split, spectrum preprocessing, wavelength select and calibration models algorithm involved in the spectral analysis process

Aiming at the common training datsets split, spectrum preprocessing, wavelength select and calibration models algorithm involved in the spectral analysis process, a complete algorithm library is esta

Fu Pengyou 50 Jan 07, 2023
5 Jan 05, 2023
Code for the ICML 2021 paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision"

ViLT Code for the paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision" Install pip install -r requirements.txt pip

Wonjae Kim 922 Jan 01, 2023
Iris prediction model is used to classify iris species created julia's DecisionTree, DataFrames, JLD2, PlotlyJS and Statistics packages.

Iris Species Predictor Iris prediction is used to classify iris species using their sepal length, sepal width, petal length and petal width created us

Siva Prakash 2 Jan 06, 2022
GANTheftAuto is a fork of the Nvidia's GameGAN

Description GANTheftAuto is a fork of the Nvidia's GameGAN, which is research focused on emulating dynamic game environments. The early research done

Harrison 801 Dec 27, 2022
Implementation of various Vision Transformers I found interesting

Implementation of various Vision Transformers I found interesting

Kim Seonghyeon 78 Dec 06, 2022
Official Code Release for Container : Context Aggregation Network

Container: Context Aggregation Network Official Code Release for Container : Context Aggregation Network Comparion between CNN, MLP-Mixer and Transfor

peng gao 42 Nov 17, 2021
Qcover is an open source effort to help exploring combinatorial optimization problems in Noisy Intermediate-scale Quantum(NISQ) processor.

Qcover is an open source effort to help exploring combinatorial optimization problems in Noisy Intermediate-scale Quantum(NISQ) processor. It is devel

33 Nov 11, 2022