Ladder Variational Autoencoders (LVAE) in PyTorch

Overview

Ladder Variational Autoencoders (LVAE)

PyTorch implementation of Ladder Variational Autoencoders (LVAE) [1]:

                 LVAE equation

where the variational distributions q at each layer are multivariate Normal with diagonal covariance.

Significant differences from [1] include:

  • skip connections in the generative path: conditioning on all layers above rather than only on the layer above (see for example [2])
  • spatial (convolutional) latent variables
  • free bits [3] instead of beta annealing [4]

Install requirements and run MNIST example

pip install -r requirements.txt
CUDA_VISIBLE_DEVICES=0 python main.py --zdims 32 32 32 --downsample 1 1 1 --nonlin elu --skip --blocks-per-layer 4 --gated --freebits 0.5 --learn-top-prior --data-dep-init --seed 42 --dataset static_mnist

Dependencies include boilr (a framework for PyTorch) and multiobject (which provides multi-object datasets with PyTorch dataloaders).

Likelihood results

Log likelihood bounds on the test set (average over 4 random seeds).

dataset num layers -ELBO - log p(x)
[100 iws]
- log p(x)
[1000 iws]
binarized MNIST 3 82.14 79.47 79.24
binarized MNIST 6 80.74 78.65 78.52
binarized MNIST 12 80.50 78.50 78.30
multi-dSprites (0-2) 12 26.9 23.2
SVHN 15 4012 (1.88) 3973 (1.87)
CIFAR10 3 7651 (3.59) 7591 (3.56)
CIFAR10 6 7321 (3.44) 7268 (3.41)
CIFAR10 15 7128 (3.35) 7068 (3.32)
CelebA 20 20026 (2.35) 19913 (2.34)

Note:

  • Bits per dimension in brackets.
  • 'iws' stands for importance weighted samples. More samples means tighter log likelihood lower bound. The bound converges to the actual log likelihood as the number of samples goes to infinity [5]. Note that the model is always trained with the ELBO (1 sample).
  • Each pixel in the images is modeled independently. The likelihood is Bernoulli for binary images, and discretized mixture of logistics with 10 components [6] otherwise.
  • One day I'll get around to evaluating the IW bound on all datasets with 10000 samples.

Supported datasets

  • Statically binarized MNIST [7], see Hugo Larochelle's website http://www.cs.toronto.edu/~larocheh/public/datasets/
  • SVHN
  • CIFAR10
  • CelebA rescaled and cropped to 64x64 – see code for details. The path in experiment.data.DatasetLoader has to be modified
  • binary multi-dSprites: 64x64 RGB shapes (0 to 2) in each image

Samples

Binarized MNIST

MNIST samples

Multi-dSprites

multi-dSprites samples

SVHN

SVHN samples

CIFAR

CIFAR samples

CelebA

CelebA samples

Hierarchical representations

Here we try to visualize the representations learned by individual layers. We can get a rough idea of what's going on at layer i as follows:

  • Sample latent variables from all layers above layer i (Eq. 1).

  • With these variables fixed, take S conditional samples at layer i (Eq. 2). Note that they are all conditioned on the same samples. These correspond to one row in the images below.

  • For each of these samples (each small image in the images below), pick the mode/mean of the conditional distribution of each layer below (Eq. 3).

  • Finally, sample an image x given the latent variables (Eq. 4).

Formally:

                

where s = 1, ..., S denotes the sample index.

The equations above yield S sample images conditioned on the same values of z for layers i+1 to L. These S samples are shown in one row of the images below. Notice that samples from each row are almost identical when the variability comes from a low-level layer, as such layers mostly model local structure and details. Higher layers on the other hand model global structure, and we observe more and more variability in each row as we move to higher layers. When the sampling happens in the top layer (i = L), all samples are completely independent, even within a row.

Binarized MNIST: layers 4, 8, 10, and 12 (top layer)

MNIST layers 4   MNIST layers 8

MNIST layers 10   MNIST layers 12

SVHN: layers 4, 10, 13, and 15 (top layer)

SVHN layers 4   SVHN layers 10

SVHN layers 13   SVHN layers 15

CIFAR: layers 3, 7, 10, and 15 (top layer)

CIFAR layers 3   CIFAR layers 7

CIFAR layers 10   CIFAR layers 15

CelebA: layers 6, 11, 16, and 20 (top layer)

CelebA layers 6

CelebA layers 11

CelebA layers 16

CelebA layers 20

Multi-dSprites: layers 3, 7, 10, and 12 (top layer)

MNIST layers 4   MNIST layers 8

MNIST layers 10   MNIST layers 12

Experimental details

I did not perform an extensive hyperparameter search, but this worked pretty well:

  • Downsampling by a factor of 2 in the beginning of inference. After that, activations are downsampled 4 times for 64x64 images (CelebA and multi-dSprites), and 3 times otherwise. The spatial size of the final feature map is always 2x2. Between these downsampling steps there is approximately the same number of stochastic layers.
  • 4 residual blocks between stochastic layers. Haven't tried with more than 4 though, as models become quite big and we get diminishing returns.
  • The deterministic parts of bottom-up and top-down architecture are (almost) perfectly mirrored for simplicity.
  • Stochastic layers have spatial random variables, and the number of rvs per "location" (i.e. number of channels of the feature map after sampling from a layer) is 32 in all layers.
  • All other feature maps in deterministic paths have 64 channels.
  • Skip connections in the generative model (--skip).
  • Gated residual blocks (--gated).
  • Learned prior of the top layer (--learn-top-prior).
  • A form of data-dependent initialization of weights (--data-dep-init). See code for details.
  • freebits=1.0 in experiments with more than 6 stochastic layers, and 0.5 for smaller models.
  • For everything else, see _add_args() in experiment/experiment_manager.py.

With these settings, the number of parameters is roughly 1M per stochastic layer. I tried to control for this by experimenting e.g. with half the number of layers but twice the number of residual blocks, but it looks like the number of stochastic layers is what matters the most.

References

[1] CK Sønderby, T Raiko, L Maaløe, SK Sønderby, O Winther. Ladder Variational Autoencoders, NIPS 2016

[2] L Maaløe, M Fraccaro, V Liévin, O Winther. BIVA: A Very Deep Hierarchy of Latent Variables for Generative Modeling, NeurIPS 2019

[3] DP Kingma, T Salimans, R Jozefowicz, X Chen, I Sutskever, M Welling. Improved Variational Inference with Inverse Autoregressive Flow, NIPS 2016

[4] I Higgins, L Matthey, A Pal, C Burgess, X Glorot, M Botvinick, S Mohamed, A Lerchner. beta-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework, ICLR 2017

[5] Y Burda, RB Grosse, R Salakhutdinov. Importance Weighted Autoencoders, ICLR 2016

[6] T Salimans, A Karpathy, X Chen, DP Kingma. PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications, ICLR 2017

[7] H Larochelle, I Murray. The neural autoregressive distribution estimator, AISTATS 2011

Owner
Andrea Dittadi
PhD student at DTU Compute | representation learning, deep generative models
Andrea Dittadi
RCDNet: A Model-driven Deep Neural Network for Single Image Rain Removal (CVPR2020)

RCDNet: A Model-driven Deep Neural Network for Single Image Rain Removal (CVPR2020) Hong Wang, Qi Xie, Qian Zhao, and Deyu Meng [PDF] [Supplementary M

Hong Wang 6 Sep 27, 2022
A library for Deep Learning Implementations and utils

deeply A Deep Learning library Table of Contents Features Quick Start Usage License Features Python 2.7+ and Python 3.4+ compatible. Quick Start $ pip

Achilles Rasquinha 1 Dec 12, 2022
Alignment Attention Fusion framework for Few-Shot Object Detection

AAF framework Framework generalities This repository contains the code of the AAF framework proposed in this paper. The main idea behind this work is

Pierre Le Jeune 20 Dec 16, 2022
Preprocessed Datasets for our Multimodal NER paper

Unified Multimodal Transformer (UMT) for Multimodal Named Entity Recognition (MNER) Two MNER Datasets and Codes for our ACL'2020 paper: Improving Mult

76 Dec 21, 2022
HAT: Hierarchical Aggregation Transformers for Person Re-identification

HAT: Hierarchical Aggregation Transformers for Person Re-identification

11 Sep 05, 2022
Deep Learning Specialization by Andrew Ng, deeplearning.ai.

Deep Learning Specialization on Coursera Master Deep Learning, and Break into AI This is my personal projects for the course. The course covers deep l

Engen 1.5k Jan 07, 2023
SMPLpix: Neural Avatars from 3D Human Models

subject0_validation_poses.mp4 Left: SMPL-X human mesh registered with SMPLify-X, middle: SMPLpix render, right: ground truth video. SMPLpix: Neural Av

Sergey Prokudin 292 Dec 30, 2022
Official Implementation of PCT

Official Implementation of PCT Prerequisites python == 3.8.5 Please make sure you have the following libraries installed: numpy torch=1.4.0 torchvisi

32 Nov 21, 2022
Code for the paper "Learning-Augmented Algorithms for Online Steiner Tree"

Learning-Augmented Algorithms for Online Steiner Tree This is the code for the paper "Learning-Augmented Algorithms for Online Steiner Tree". Requirem

0 Dec 09, 2021
Learning Energy-Based Models by Diffusion Recovery Likelihood

Learning Energy-Based Models by Diffusion Recovery Likelihood Ruiqi Gao, Yang Song, Ben Poole, Ying Nian Wu, Diederik P. Kingma Paper: https://arxiv.o

Ruiqi Gao 41 Nov 22, 2022
Permeability Prediction Via Multi Scale 3D CNN

Permeability-Prediction-Via-Multi-Scale-3D-CNN Data: The raw CT rock cores are obtained from the Imperial Colloge portal. The CT rock cores are sub-sa

Mohamed Elmorsy 2 Jul 06, 2022
SimBERT升级版(SimBERTv2)!

RoFormer-Sim RoFormer-Sim,又称SimBERTv2,是我们之前发布的SimBERT模型的升级版。 介绍 https://kexue.fm/archives/8454 训练 tensorflow 1.14 + keras 2.3.1 + bert4keras 0.10.6 下载

318 Dec 31, 2022
Code for "Hierarchical Skills for Efficient Exploration" HSD-3 Algorithm and Baselines

Hierarchical Skills for Efficient Exploration This is the source code release for the paper Hierarchical Skills for Efficient Exploration. It contains

Facebook Research 38 Dec 06, 2022
Code for the paper titled "Prabhupadavani: A Code-mixed Speech Translation Data for 25 languages"

Prabhupadavani: A Code-mixed Speech Translation Data for 25 languages Code for the paper titled "Prabhupadavani: A Code-mixed Speech Translation Data

Ayush Daksh 12 Dec 01, 2022
Classic Papers for Beginners and Impact Scope for Authors.

There have been billions of academic papers around the world. However, maybe only 0.0...01% among them are valuable or are worth reading. Since our limited life has never been forever, TopPaper provi

Qiulin Zhang 228 Dec 18, 2022
The World of an Octopus: How Reporting Bias Influences a Language Model's Perception of Color

The World of an Octopus: How Reporting Bias Influences a Language Model's Perception of Color Overview Code and dataset for The World of an Octopus: H

1 Nov 13, 2021
EMNLP'2021: Simple Entity-centric Questions Challenge Dense Retrievers

EntityQuestions This repository contains the EntityQuestions dataset as well as code to evaluate retrieval results from the the paper Simple Entity-ce

Princeton Natural Language Processing 119 Sep 28, 2022
Fortuitous Forgetting in Connectionist Networks

Fortuitous Forgetting in Connectionist Networks Introduction This repository includes reference code for the paper Fortuitous Forgetting in Connection

Hattie Zhou 14 Nov 26, 2022
Codes for “A Deeply Supervised Attention Metric-Based Network and an Open Aerial Image Dataset for Remote Sensing Change Detection”

DSAMNet The pytorch implementation for "A Deeply-supervised Attention Metric-based Network and an Open Aerial Image Dataset for Remote Sensing Change

Mengxi Liu 41 Dec 14, 2022
mlpack: a scalable C++ machine learning library --

a fast, flexible machine learning library Home | Documentation | Doxygen | Community | Help | IRC Chat Download: current stable version (3.4.2) mlpack

mlpack 4.2k Jan 09, 2023