A new plug-and-play method for controlling an existing generative model with conditioning attributes and their compositions.

Controllable and Compositional Generation with Latent-Space Energy-Based Models

Python 3.8, PyTorch 1.7.1, torchdiffeq 0.2.1

Official PyTorch implementation of the NeurIPS 2021 paper:
Controllable and Compositional Generation with Latent-Space Energy-Based Models
Weili Nie, Arash Vahdat, Anima Anandkumar
https://nvlabs.github.io/LACE

Abstract: Controllable generation is one of the key requirements for successful adoption of deep generative models in real-world applications, but it still remains as a great challenge. In particular, the compositional ability to generate novel concept combinations is out of reach for most current models. In this work, we use energy-based models (EBMs) to handle compositional generation over a set of attributes. To make them scalable to high-resolution image generation, we introduce an EBM in the latent space of a pre-trained generative model such as StyleGAN. We propose a novel EBM formulation representing the joint distribution of data and attributes together, and we show how sampling from it is formulated as solving an ordinary differential equation (ODE). Given a pre-trained generator, all we need for controllable generation is to train an attribute classifier. Sampling with ODEs is done efficiently in the latent space and is robust to hyperparameters. Thus, our method is simple, fast to train, and efficient to sample. Experimental results show that our method outperforms the state-of-the-art in both conditional sampling and sequential editing. In compositional generation, our method excels at zero-shot generation of unseen attribute combinations. Also, by composing energy functions with logical operators, this work is the first to achieve such compositionality in generating photo-realistic images of resolution 1024x1024.

Requirements

  • Linux and Windows are supported, but we recommend Linux for performance and compatibility reasons.
  • 1 high-end NVIDIA GPU with at least 24 GB of memory. We have done all testing and development using a single NVIDIA V100 GPU with memory size 32 GB.
  • 64-bit Python 3.8.
  • CUDA 10.0 and Docker must be installed first.
  • Installation of the required library dependencies with Docker:
    docker build -f lace-cuda-10p0.Dockerfile --tag=lace-cuda-10-0:0.0.1 .
    docker run -it -d --gpus 0 --name lace --shm-size 8G -v $(pwd):/workspace -p 5001:6006 lace-cuda-10-0:0.0.1
    docker exec -it lace bash

Experiments on CIFAR-10

The CIFAR10 folder contains the codebase to get the main results on the CIFAR-10 dataset, where the scripts folder contains the necessary bash scripts to run the code.

Data preparation

Before running the code, you have to download the data (i.e., the latent code and label pairs) from here and unzip it to the CIFAR10 folder. Or you can go to the folder CIFAR10/prepare_data and follow the instructions to generate the data.

Training

To train the latent classifier, you can run:

bash scripts/run_clf.sh

In the script run_clf.sh, the variable x can be set to w or z to train the latent classifier in the w-space or z-space of StyleGAN, respectively.
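
To illustrate what a latent classifier does, the sketch below fits a softmax-regression classifier on synthetic (latent code, label) pairs using NumPy. It is a toy stand-in and not the repository's classifier: the linear model, the synthetic data, and all names (`train_latent_classifier`, `latents`, `labels`) are assumptions made for this example.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax with max-subtraction for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def train_latent_classifier(latents, labels, num_classes, lr=0.5, epochs=300):
    """Fit a linear softmax classifier on latent codes by full-batch gradient descent."""
    n, dim = latents.shape
    weights = np.zeros((dim, num_classes))
    bias = np.zeros(num_classes)
    one_hot = np.eye(num_classes)[labels]
    for _ in range(epochs):
        probs = softmax(latents @ weights + bias)
        grad_logits = (probs - one_hot) / n  # gradient of the mean cross-entropy
        weights -= lr * (latents.T @ grad_logits)
        bias -= lr * grad_logits.sum(axis=0)
    return weights, bias


# Synthetic stand-in for (latent code, label) pairs; these are not StyleGAN latents.
rng = np.random.default_rng(0)
latents = rng.standard_normal((600, 16))
true_w = rng.standard_normal((16, 3))
labels = (latents @ true_w).argmax(axis=1)

weights, bias = train_latent_classifier(latents, labels, num_classes=3)
train_acc = float(((latents @ weights + bias).argmax(axis=1) == labels).mean())
print(f"training accuracy: {train_acc:.3f}")
```

The same training loop applies whether the inputs are z-space or w-space codes; only the array passed as `latents` changes.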

Sampling

To get the conditional sampling results with the ODE or Langevin dynamics (LD) sampler, you can run:

# ODE
bash scripts/run_cond_ode_sample.sh

# LD
bash scripts/run_cond_ld_sample.sh

By default, x is set to w, meaning the w-space classifier is used, because we find our method works best in w-space. For comparison, you can change x to z or i to use the classifier in z-space or pixel space, respectively.
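
For intuition about the LD sampler, here is a minimal NumPy sketch of unadjusted Langevin dynamics on a toy latent-space energy. It assumes an energy of the form E(z) = 0.5 * ||z||^2 - log p(c | z), with a linear softmax classifier standing in for the trained latent classifier. The energy, the classifier, the step size, and every name below are illustrative assumptions, not the repository's implementation.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax with max-subtraction for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def energy_grad(z, weights, target):
    """Gradient of E(z) = 0.5 * ||z||^2 - log p(target | z) for a linear softmax classifier."""
    probs = softmax(z @ weights)              # (batch, classes)
    one_hot = np.zeros_like(probs)
    one_hot[:, target] = 1.0
    return z - (one_hot - probs) @ weights.T  # (batch, dim)


def langevin_sample(z0, weights, target, step_size=0.01, num_steps=200, rng=None):
    """Iterate z <- z - (step / 2) * grad E(z) + sqrt(step) * Gaussian noise."""
    rng = np.random.default_rng() if rng is None else rng
    z = z0.copy()
    for _ in range(num_steps):
        noise = rng.standard_normal(z.shape)
        z = z - 0.5 * step_size * energy_grad(z, weights, target) + np.sqrt(step_size) * noise
    return z


rng = np.random.default_rng(0)
weights = 2.0 * rng.standard_normal((8, 3))  # toy classifier weights (assumption)
z_init = rng.standard_normal((512, 8))       # chains start from the Gaussian prior
z_final = langevin_sample(z_init, weights, target=1, rng=rng)

prob_before = float(softmax(z_init @ weights)[:, 1].mean())
prob_after = float(softmax(z_final @ weights)[:, 1].mean())
print(f"mean p(target | z): {prob_before:.3f} -> {prob_after:.3f}")
```

The sampled codes should score higher under the target class than the prior samples they started from, which is the behavior conditional sampling relies on.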

To compute the conditional accuracy (ACC) and FID scores in conditional sampling with the ODE or LD sampler, you can run:

# ODE
bash scripts/run_cond_ode_score.sh

# LD
bash scripts/run_cond_ld_score.sh

Note that:

  1. For the ACC evaluation, you need a pre-trained image classifier, which can be downloaded as instructed here;

  2. For the FID evaluation, you need to have the FID reference statistics computed beforehand. You can go to the folder CIFAR10/prepare_data and follow the instructions to compute the FID reference statistics with real images sampled from CIFAR-10.
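
To show what the FID reference statistics are, the sketch below computes the mean and covariance of a feature array and the Fréchet distance between two such sets of statistics. It assumes image features have already been extracted into NumPy arrays; the function names are hypothetical, and the scripts in CIFAR10/prepare_data may organize this differently.

```python
import numpy as np


def feature_statistics(features):
    """Mean vector and covariance matrix of a (num_samples, dim) feature array."""
    return features.mean(axis=0), np.cov(features, rowvar=False)


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Tr(sqrt(sigma1 @ sigma2)) is the sum of the square roots of the product's
    # eigenvalues, which are real and non-negative for PSD inputs.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * trace_sqrt)


# Random features stand in for real and generated image features.
rng = np.random.default_rng(0)
mu_ref, sigma_ref = feature_statistics(rng.standard_normal((1000, 4)))
mu_gen, sigma_gen = feature_statistics(rng.standard_normal((1000, 4)) + 0.5)
print(frechet_distance(mu_ref, sigma_ref, mu_gen, sigma_gen))
```

Computing `mu_ref` and `sigma_ref` once and saving them avoids reprocessing the real images for every evaluation run.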

Experiments on FFHQ

The FFHQ folder contains the codebase for getting the main results on the FFHQ dataset, where the scripts folder contains the necessary bash scripts to run the code.

Data preparation

Before running the code, you have to download the data (i.e., 10k pairs of latent variables and labels) from here (originally from StyleFlow) and unzip it to the FFHQ folder.

Training

To train the latent classifier, you can run:

bash scripts/run_clf.sh

Note that each att_name (e.g., glasses) in run_clf.sh corresponds to a separate attribute classifier.

Sampling

First, you have to get the pre-trained StyleGAN2 (config-f) by following the instructions in Convert StyleGAN2 weight from official checkpoints.

Conditional sampling

To get the conditional sampling results with the ODE or LD sampler, you can run:

# ODE
bash scripts/run_cond_ode_sample.sh

# LD
bash scripts/run_cond_ld_sample.sh

To compute the conditional accuracy (ACC) and FID scores in conditional sampling with the ODE or LD sampler, you can run:

# ODE
bash scripts/run_cond_ode_score.sh

# LD
bash scripts/run_cond_ld_score.sh

Note that:

  1. For the ACC evaluation, you need to train an FFHQ image classifier, as instructed here;

  2. For the FID evaluation, you need to have the FID reference statistics computed beforehand. You can go to the folder FFHQ/prepare_models_data and follow the instructions to compute the FID reference statistics with the StyleGAN generated FFHQ images.
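
As an illustration of the ACC metric, the sketch below computes per-attribute conditional accuracy as the fraction of generated samples whose predicted attribute value matches the requested one. It assumes the image classifier's predictions are already available as integer arrays; the function name and the example attribute columns are hypothetical.

```python
import numpy as np


def conditional_accuracy(predicted, target):
    """Per-attribute fraction of samples whose prediction matches the requested value."""
    predicted = np.asarray(predicted)
    target = np.asarray(target)
    return (predicted == target).mean(axis=0)


# Rows are generated samples; columns are two hypothetical binary attributes.
target = np.array([[1, 0], [1, 0], [1, 0], [1, 0]])
predicted = np.array([[1, 0], [1, 1], [0, 0], [1, 0]])
acc_per_attribute = conditional_accuracy(predicted, target)
print(acc_per_attribute)  # [0.75 0.75]
```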

Sequential editing

To get the qualitative and quantitative results of sequential editing, you can run:

# User-specified sampling
bash scripts/run_seq_edit_sample.sh

# ACC and FID
bash scripts/run_seq_edit_score.sh

Note that:

  • As with conditional sampling, you first need to train an FFHQ image classifier (for ACC) and compute the FID reference statistics (for FID) by following the respective instructions.

  • To get the face identity preservation (ID) score, you first need to download the pre-trained ArcFace network, which is publicly available here, to the folder FFHQ/pretrained/metrics.
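
A common way to measure identity preservation is the cosine similarity between face-recognition embeddings of the original and edited images. The sketch below assumes the ID score follows that convention and that the embeddings have already been extracted into NumPy arrays; `identity_score` is a hypothetical name, and the repository's exact definition may differ.

```python
import numpy as np


def identity_score(emb_original, emb_edited):
    """Mean cosine similarity between paired embeddings of shape (num_images, dim)."""
    a = emb_original / np.linalg.norm(emb_original, axis=1, keepdims=True)
    b = emb_edited / np.linalg.norm(emb_edited, axis=1, keepdims=True)
    return float((a * b).sum(axis=1).mean())


# Random vectors stand in for face embeddings before and after an edit.
rng = np.random.default_rng(0)
emb_before = rng.standard_normal((5, 512))
emb_after = emb_before + 0.1 * rng.standard_normal((5, 512))
print(identity_score(emb_before, emb_after))
```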

Compositional Generation

To get the results of zero-shot generation on novel attribute combinations, you can run:

bash scripts/run_zero_shot.sh

To get the results of composing energy functions with logical operators, you can run:

bash scripts/run_combine_energy.sh
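
The idea of composing energy functions with logical operators can be sketched on scalar energies. The definitions below follow a convention that is common for energy-based models (conjunction as a product of densities, disjunction as a mixture, negation as an inverse density). They are an assumption made for illustration, and the exact formulation in run_combine_energy.sh may differ.

```python
import math


def energy_and(e1, e2):
    """Conjunction: densities multiply, so the energies add."""
    return e1 + e2


def energy_or(e1, e2):
    """Disjunction: densities add, giving -log(exp(-e1) + exp(-e2))."""
    m = min(e1, e2)  # factor out the smaller energy for numerical stability
    return m - math.log(math.exp(-(e1 - m)) + math.exp(-(e2 - m)))


def energy_not(e, alpha=1.0):
    """Negation: inverse density raised to the power alpha, so the energy flips sign."""
    return -alpha * e


# Hypothetical per-attribute energies for one latent code.
e_glasses, e_smile = 0.3, 2.0
composed = energy_and(e_glasses, energy_not(e_smile))  # glasses AND (NOT smile)
print(composed)
```

Because each operator returns another energy, the operators can be nested to express arbitrary combinations, and the result can be passed to the same sampler as a single-attribute energy.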

Experiments on MetFaces

The MetFaces folder contains the codebase for getting the main results on the MetFaces dataset, where the scripts folder contains the necessary bash scripts to run the code.

Data preparation

Before running the code, you have to download the data (i.e., 10k pairs of latent variables and labels) from here and unzip it to the MetFaces folder. Or you can go to the folder MetFaces/prepare_data and follow the instructions to generate the data.

Training

To train the latent classifier, you can run:

bash scripts/run_clf.sh

Note that each att_name (e.g., yaw) in run_clf.sh corresponds to a separate attribute classifier.

Sampling

To get the conditional sampling and sequential editing results, you can run:

# conditional sampling
bash scripts/run_cond_sample.sh

# sequential editing
bash scripts/run_seq_edit_sample.sh

Experiments on AFHQ-Cats

The AFHQ folder contains the codebase for getting the main results on the AFHQ-Cats dataset, where the scripts folder contains the necessary bash scripts to run the code.

Data preparation

Before running the code, you have to download the data (i.e., 10k pairs of latent variables and labels) from here and unzip it to the AFHQ folder. Or you can go to the folder AFHQ/prepare_data and follow the instructions to generate the data.

Training

To train the latent classifier, you can run:

bash scripts/run_clf.sh

Note that each att_name (e.g., breeds) in run_clf.sh corresponds to a separate attribute classifier.

Sampling

To get the conditional sampling and sequential editing results, you can run:

# conditional sampling
bash scripts/run_cond_sample.sh

# sequential editing
bash scripts/run_seq_edit_sample.sh

License

Please check the LICENSE file. This work may be used non-commercially, meaning for research or evaluation purposes only. For business inquiries, please contact [email protected].

Citation

Please cite our paper if you use this codebase:

@inproceedings{nie2021controllable,
  title={Controllable and compositional generation with latent-space energy-based models},
  author={Nie, Weili and Vahdat, Arash and Anandkumar, Anima},
  booktitle={Neural Information Processing Systems (NeurIPS)},
  year={2021}
}