Overview

Official PyTorch implementation of "VAEBM: A Symbiosis between Variational Autoencoders and Energy-based Models" (ICLR 2021 Spotlight Paper)

Zhisheng Xiao · Karsten Kreis · Jan Kautz · Arash Vahdat


VAEBM trains an energy network to refine the data distribution learned by an NVAE, where the energy network and the VAE jointly define an energy-based model. The NVAE is pretrained before training the energy network; please refer to NVAE's implementation for more details about constructing and training the NVAE.
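
For intuition, VAEBM's unnormalized density is the product of the NVAE density and a Boltzmann factor from the energy network, i.e. log h(x) = log p_NVAE(x) - E(x) up to a constant. Below is a minimal, illustrative PyTorch sketch of that composition; ToyEnergyNet and vae_log_p are placeholders, not the architectures used in this repo.

import torch
import torch.nn as nn

class ToyEnergyNet(nn.Module):
    # Small CNN mapping an image batch to one scalar energy per sample
    # (illustrative only; the repo's energy network is different).
    def __init__(self, n_channel=64, im_channels=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(im_channels, n_channel, 3, stride=2, padding=1), nn.SiLU(),
            nn.Conv2d(n_channel, n_channel, 3, stride=2, padding=1), nn.SiLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(n_channel, 1),
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)  # shape: (batch,)

def unnormalized_log_density(x, energy_net, vae_log_p):
    # vae_log_p stands in for the log-density of the pretrained, frozen NVAE.
    return vae_log_p(x) - energy_net(x)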

Set up datasets

We train on several datasets: CIFAR-10, CelebA 64, LSUN Church 64, and CelebA HQ 256. For the larger datasets, we store the data in LMDB databases for I/O efficiency. Check here for information regarding dataset preparation.
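
If you prepare your own LMDB data, a generic reader looks roughly like the sketch below. It assumes the lmdb package and records keyed by integer strings holding encoded images; the layout produced by the preparation scripts linked above may differ, so adjust the decoding accordingly.

import io
import lmdb
from PIL import Image
from torch.utils.data import Dataset

class LMDBImageDataset(Dataset):
    def __init__(self, lmdb_path, transform=None):
        self.env = lmdb.open(lmdb_path, readonly=True, lock=False, readahead=False)
        with self.env.begin(write=False) as txn:
            self.length = txn.stat()['entries']
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        with self.env.begin(write=False) as txn:
            buf = txn.get(str(idx).encode())              # assumed key format
        img = Image.open(io.BytesIO(buf)).convert('RGB')  # assumed value format
        return self.transform(img) if self.transform else img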

Training NVAE

We use the following commands to train the NVAE backbone on each dataset. To train the NVAEs, please use NVAE's original codebase with the commands given here.

CIFAR-10 (8x 16-GB GPUs)

python train.py --data $DATA_DIR/cifar10 --root $CHECKPOINT_DIR --save $EXPR_ID --dataset cifar10 \
      --num_channels_enc 128 --num_channels_dec 128 --epochs 400 --num_postprocess_cells 2 --num_preprocess_cells 2 \
      --num_latent_scales 1 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
      --num_preprocess_blocks 1 --num_postprocess_blocks 1 --num_groups_per_scale 30 --batch_size 32 \
      --weight_decay_norm 1e-1 --num_nf 1 --num_mixture_dec 1 --fast_adamax  --arch_instance res_mbconv \
      --num_process_per_node 8 --use_se --res_dist

CelebA-64 (8x 16-GB GPUs)

python train.py --data  $DATA_DIR/celeba64_lmdb --root $CHECKPOINT_DIR --save $EXPR_ID --dataset celeba_64 \
      --num_channels_enc 48 --num_channels_dec 48 --epochs 50 --num_postprocess_cells 2 --num_preprocess_cells 2 \
      --num_latent_scales 3 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
      --num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-1 --num_groups_per_scale 5 \
      --batch_size 32 --num_nf 1 --num_mixture_dec 1 --fast_adamax  --warmup_epochs 1 --arch_instance res_mbconv \
      --num_process_per_node 8 --use_se --res_dist

CelebA-HQ-256 (8x 32-GB GPUs)

python train.py --data $DATA_DIR/celeba/celeba-lmdb --root $CHECKPOINT_DIR --save $EXPR_ID --dataset celeba_256 \
      --num_channels_enc 32 --num_channels_dec 32 --epochs 200 --num_postprocess_cells 2 --num_preprocess_cells 2 \
      --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 --num_preprocess_blocks 1 \
      --num_postprocess_blocks 1 --weight_decay_norm 1e-2 --num_x_bits 5 --num_latent_scales 5 --num_groups_per_scale 4 \
      --num_nf 2 --batch_size 8 --fast_adamax  --num_mixture_dec 1 \
      --weight_decay_norm_anneal  --weight_decay_norm_init 1e1 --learning_rate 6e-3 --arch_instance res_mbconv \
      --num_process_per_node 8 --use_se --res_dist

LSUN Churches Outdoor 64 (8x 16-GB GPUs)

python train.py --data $DATA_DIR/LSUN/ --root $CHECKPOINT_DIR --save $EXPR_ID --dataset lsun_church_64 \
      --num_channels_enc 48 --num_channels_dec 48 --epochs 60 --num_postprocess_cells 2 --num_preprocess_cells 2 \
      --num_latent_scales 3 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
      --num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-1 --num_groups_per_scale 5 \
      --batch_size 32 --num_nf 1 --num_mixture_dec 1 --fast_adamax  --warmup_epochs 1 --arch_instance res_mbconv \
      --num_process_per_node 8 --use_se --res_dist

Training VAEBM

We use the following commands to train VAEBM on each dataset. Note that you need to train the NVAE on the corresponding dataset before running the training commands here. After training the NVAE, pass the path of its checkpoint to the --checkpoint argument.

Note that VAEBM training will eventually explode (see Appendix E of our paper), so it is important to save checkpoints regularly. After training explodes, stop running the code and use the last few saved checkpoints for testing.
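
A rough sketch of what "save checkpoints regularly" can look like in practice; the names energy_net and optimizer and the file layout are illustrative, not the repo's exact saving code.

import torch

def maybe_save(energy_net, optimizer, iteration, exp_dir, save_every=500):
    # Checkpoint the energy network every save_every iterations so that
    # usable weights survive the eventual divergence.
    if iteration % save_every == 0:
        torch.save({'iter': iteration,
                    'ebm_state': energy_net.state_dict(),
                    'opt_state': optimizer.state_dict()},
                   f'{exp_dir}/EBM_{iteration}.pth')

def training_diverged(loss):
    # Once the loss becomes non-finite, stop and fall back to the last few
    # saved checkpoints for evaluation.
    return not torch.isfinite(loss).all()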

CIFAR-10

We train VAEBM on CIFAR-10 using one 32-GB V100 GPU.

python train_VAEBM.py --checkpoint ./checkpoints/cifar10/checkpoint.pt --experiment cifar10_exp1 \
      --dataset cifar10 --im_size 32 --data ./data/cifar10 --num_steps 10 \
      --wd 3e-5 --step_size 8e-5 --total_iter 30000 --alpha_s 0.2 --lr 4e-5 --max_p 0.6 \
      --anneal_step 5000 --batch_size 32 --n_channel 128

CelebA 64

We train VAEBM on CelebA 64 using one 32-GB V100 GPU.

python train_VAEBM.py --checkpoint ./checkpoints/celeba_64/checkpoint.pt --experiment celeba64_exp1 --dataset celeba_64 \
      --im_size 64 --lr 5e-5 --batch_size 32 --n_channel 64 --num_steps 10 --use_mu_cd --wd 3e-5 --step_size 5e-6 --total_iter 30000 \
      --alpha_s 0.2

LSUN Church 64

We train VAEBM on LSUN Church 64 using one 32-GB V100 GPU.

python train_VAEBM.py --checkpoint ./checkpoints/lsun_church/checkpoint.pt --experiment lsunchurch_exp1 --dataset lsun_church \
      --im_size 64 --batch_size 32 --n_channel 64 --num_steps 10 --use_mu_cd --wd 3e-5 --step_size 4e-6 --total_iter 30000 --alpha_s 0.2 --lr 4e-5 \
      --use_buffer --max_p 0.6 --anneal_step 5000

CelebA HQ 256

We train VAEBM on CelebA HQ 256 using four 32-GB V100 GPUs.

python train_VAEBM_distributed.py --checkpoint ./checkpoints/celeba_256/checkpoint.pt --experiment celeba256_exp1 --dataset celeba_256 \
      --num_process_per_node 4 --im_size 256 --batch_size 4 --n_channel 64 --num_steps 6 --use_mu_cd --wd 3e-5 --step_size 3e-6 \
      --total_iter 9000 --alpha_s 0.3 --lr 4e-5 --use_buffer --max_p 0.6 --anneal_step 3000 --buffer_size 2000
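
Some of the commands above enable a persistent sample buffer (--use_buffer, --buffer_size). The following is only a hedged sketch of what such a replay buffer typically does in EBM training; the details of this repo's buffer, including how --max_p and --anneal_step interact with it, may differ.

import random
import torch

class SampleBuffer:
    def __init__(self, buffer_size=2000):
        self.buffer_size = buffer_size
        self.samples = []

    def push(self, batch):
        # Keep at most buffer_size past negative samples.
        self.samples.extend(x for x in batch.detach().cpu())
        self.samples = self.samples[-self.buffer_size:]

    def sample(self, batch_size, fresh_init, reuse_p=0.6):
        # Restart most chains from past samples, the rest from a fresh initialization.
        out = [random.choice(self.samples)
               if self.samples and random.random() < reuse_p else fresh_init()
               for _ in range(batch_size)]
        return torch.stack(out)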

Sampling from VAEBM

To generate samples from VAEBM after training, run sample_VAEBM.py; it will generate 50,000 test images in the path you specify. When sampling, we typically use longer Langevin dynamics than during training for better sample quality; see Appendix E of the paper for the step sizes and number of steps we use to obtain test samples for each dataset. The other parameters required to correctly load the VAE and energy network are the same as in the training commands.
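
For reference, each Langevin step follows the usual update z ← z − (η/2)∇E(z) + √η ξ with Gaussian noise ξ. The sketch below is a generic Langevin sampler, not the repo's exact code, which runs the dynamics in the NVAE's reparametrized latent space.

import torch

def langevin_sample(energy_fn, z_init, num_steps=16, step_size=8e-5):
    # energy_fn returns the negative unnormalized log-density for a batch z.
    z = z_init.clone().requires_grad_(True)
    for _ in range(num_steps):
        grad, = torch.autograd.grad(energy_fn(z).sum(), z)
        with torch.no_grad():
            z = z - 0.5 * step_size * grad + (step_size ** 0.5) * torch.randn_like(z)
        z.requires_grad_(True)
    return z.detach()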

For example, the script used to sample CIFAR-10 is

python sample_VAEBM.py --checkpoint ./checkpoints/cifar_10/checkpoint.pt --ebm_checkpoint ./saved_models/cifar_10/cifar_exp1/EBM.pth \
      --dataset cifar10 --im_size 32 --batch_size 40 --n_channel 128 --num_steps 16 --step_size 8e-5

For CelebA 64,

python sample_VAEBM.py --checkpoint ./checkpoints/celeba_64/checkpoint.pt --ebm_checkpoint ./saved_models/celeba_64/celeba64_exp1/EBM.pth \
      --dataset celeba_64 --im_size 64 --batch_size 40 --n_channel 64 --num_steps 20 --step_size 5e-6

For LSUN Church 64,

python sample_VAEBM.py --checkpoint ./checkpoints/lsun_church/checkpoint.pt --ebm_checkpoint ./saved_models/lsun_church/lsunchurch_exp1/EBM.pth \
      --dataset lsun_church --im_size 64 --batch_size 40 --n_channel 64 --num_steps 20 --step_size 4e-6

For CelebA HQ 256,

python sample_VAEBM.py --checkpoint ./checkpoints/celeba_256/checkpoint.pt --ebm_checkpoint ./saved_models/celeba_256/celeba256_exp1/EBM.pth \
      --dataset celeba_256 --im_size 256 --batch_size 10 --n_channel 64 --num_steps 24 --step_size 3e-6

Evaluation

After sampling, use the TensorFlow or PyTorch implementation to compute the FID scores. For example, when using the TensorFlow implementation, you can obtain the FID score by saving the training images in /path/to/training_images and running the script:

python fid.py /path/to/training_images /path/to/sampled_images

For CIFAR-10, the training statistics can be downloaded from here, and the FID score can be computed by running

python fid.py /path/to/sampled_images /path/to/precalculated_stats.npz
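
For reference, the FID reported by these scripts is the Fréchet distance between Gaussians fitted to Inception activations of the two image sets. A minimal sketch of that formula, assuming the statistics mu1, sigma1, mu2, sigma2 have already been computed:

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    # ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Jitter the diagonals if the product is numerically singular.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)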

For the Inception Score, save the samples in a single numpy array with pixel values in the range [0, 255], and simply run

python ./thirdparty/inception_score.py --sample_dir /path/to/sampled_images

where the code for computing Inception Score is adapted from here.
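
If your sampler wrote individual image files, one way to pack them into the single array described above is sketched here; the output file name samples.npy and the exact format the script loads are assumptions, so check its loading code.

import glob
import numpy as np
from PIL import Image

files = sorted(glob.glob('/path/to/sampled_images/*.png'))
samples = np.stack([np.asarray(Image.open(f).convert('RGB')) for f in files])  # (N, H, W, 3), values in [0, 255]
np.save('/path/to/sampled_images/samples.npy', samples.astype(np.uint8))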

License

Please check the LICENSE file. VAEBM may be used non-commercially, meaning for research or evaluation purposes only. For business inquiries, please contact [email protected].

Bibtex

Cite our paper using the following BibTeX entry:

@inproceedings{xiao2021vaebm,
  title={VAEBM: A Symbiosis between Variational Autoencoders and Energy-based Models},
  author={Zhisheng Xiao and Karsten Kreis and Jan Kautz and Arash Vahdat},
  booktitle={International Conference on Learning Representations},
  year={2021}
}