Invert and perturb GAN images for test-time ensembling

Overview

GAN Ensembling

Project Page | Paper | Bibtex

Ensembling with Deep Generative Views.
Lucy Chai, Jun-Yan Zhu, Eli Shechtman, Phillip Isola, Richard Zhang
CVPR 2021

Prerequisites

  • Linux
  • Python 3
  • NVIDIA GPU + CUDA CuDNN

Table of Contents:

  1. Colab - run a limited demo version without local installation
  2. Setup - download required resources
  3. Quickstart - short demonstration code snippet
  4. Notebooks - jupyter notebooks for visualization
  5. Pipeline - details on full pipeline

We project an input image into the latent space of a pre-trained GAN and perturb it slightly to obtain modifications of the input image. These alternative views from the GAN are ensembled at test-time, together with the original image, in a downstream classification task.

To synthesize deep generative views, we first align (Aligned Input) and reconstruct an image by finding the corresponding latent code in StyleGAN2 (GAN Reconstruction). We then investigate different approaches to produce image variations using the GAN, such as style-mixing on fine layers (Style-mix Fine), which predominantly changes color, or coarse layers (Style-mix Coarse), which changes pose.

Colab

This Colab Notebook demonstrates the basic latent code perturbation and classification procedure in a simplified setting on the aligned cat dataset.

Setup

  • Clone this repo:
git clone https://github.com/chail/gan-ensembling.git
cd gan-ensembling

An example of the directory organization is below:

dataset/celebahq/
	images/images/
		000004.png
		000009.png
		000014.png
		...
	latents/
	latents_idinvert/
dataset/cars/
	devkit/
		cars_meta.mat
		cars_test_annos.mat
		cars_train_annos.mat
		...
	images/images/
		00001.jpg
		00002.jpg
		00003.jpg
		...
	latents/
dataset/catface/
	images/
	latents/
dataset/cifar10/
	cifar-10-batches-py/
	latents/

Quickstart

Once the datasets and precomputed resources are downloaded, the following code snippet demonstrates how to perturb GAN images. Additional examples are contained in notebooks/demo.ipynb.

import data
from networks import domain_generator

dataset_name = 'celebahq'
generator_name = 'stylegan2'
attribute_name = 'Smiling'
val_transform = data.get_transform(dataset_name, 'imval')
dset = data.get_dataset(dataset_name, 'val', attribute_name, load_w=True, transform=val_transform)
generator = domain_generator.define_generator(generator_name, dataset_name)

index = 100
original_image = dset[index][0][None].cuda()
latent = dset[index][1][None].cuda()
gan_reconstruction = generator.decode(latent)
mix_latent = generator.seed2w(n=4, seed=0)
perturbed_im = generator.perturb_stylemix(latent, 'fine', mix_latent, n=4)

Notebooks

Important: First, set up symlinks required for notebooks: bash notebooks/setup_notebooks.sh, and add the conda environment to jupyter kernels: python -m ipykernel install --user --name gan-ensembling.

The provided notebooks are:

  1. notebooks/demo.ipynb: basic usage example
  2. notebooks/evaluate_ensemble.ipynb: plot classification test accuracy as a function of ensemble weight
  3. notebooks/plot_precomputed_evaluations.ipynb: notebook to generate figures in paper

Full Pipeline

The full pipeline contains three main parts:

  1. optimize latent codes
  2. train classifiers
  3. evaluate the ensemble of GAN-generated images.

Examples for each step of the pipeline are contained in the following scripts:

bash scripts/optimize_latent/examples.sh
bash scripts/train_classifier/examples.sh
bash scripts/eval_ensemble/examples.sh

To add to the pipeline:

  • Data: in the data/ directory, add the dataset in data/__init__.py and create the dataset class and transformation functions. See data/data_*.py for examples.
  • Generator: modify networks/domain_generators.py to add the generator in domain_generators.define_generator. The perturbation ranges for each dataset and generator are specified in networks/perturb_settings.py.
  • Classifier: modify networks/domain_classifiers.py to add the classifier in domain_classifiers.define_classifier

Acknowledgements

We thank the authors of these repositories:

Citation

If you use this code for your research, please cite our paper:

@inproceedings{chai2021ensembling,
  title={Ensembling with Deep Generative Views.},
  author={Chai, Lucy and Zhu, Jun-Yan and Shechtman, Eli and Isola, Phillip and Zhang, Richard},
  booktitle={CVPR},
  year={2021}
 }
Owner
Lucy Chai
Lucy Chai
RID-Noise: Towards Robust Inverse Design under Noisy Environments

This is code of RID-Noise. Reproduce RID-Noise Results Toy tasks Please refer to the notebook ridnoise.ipynb to view experiments on three toy tasks. B

Thyrix 2 Nov 23, 2022
ARAE-Tensorflow for Discrete Sequences (Adversarially Regularized Autoencoder)

ARAE Tensorflow Code Code for the paper Adversarially Regularized Autoencoders for Generating Discrete Structures by Zhao, Kim, Zhang, Rush and LeCun

19 Nov 12, 2021
[NeurIPS 2019] Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss

Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss Kaidi Cao, Colin Wei, Adrien Gaidon, Nikos Arechiga, Tengyu Ma This is the offi

Kaidi Cao 528 Jan 01, 2023
The official code for PRIMER: Pyramid-based Masked Sentence Pre-training for Multi-document Summarization

PRIMER The official code for PRIMER: Pyramid-based Masked Sentence Pre-training for Multi-document Summarization. PRIMER is a pre-trained model for mu

AI2 114 Jan 06, 2023
Sudoku solver - A sudoku solver with python

sudoku_solver A sudoku solver What is Sudoku? Sudoku (Japanese: 数独, romanized: s

Sikai Lu 0 May 22, 2022
Repository for publicly available deep learning models developed in Rosetta community

trRosetta2 This package contains deep learning models and related scripts used by Baker group in CASP14. Installation Linux/Mac clone the package git

81 Dec 29, 2022
CAMPARI: Camera-Aware Decomposed Generative Neural Radiance Fields

CAMPARI: Camera-Aware Decomposed Generative Neural Radiance Fields Paper | Supplementary | Video | Poster If you find our code or paper useful, please

26 Nov 29, 2022
Implementation of the paper "Generating Symbolic Reasoning Problems with Transformer GANs"

Generating Symbolic Reasoning Problems with Transformer GANs This is the implementation of the paper Generating Symbolic Reasoning Problems with Trans

Reactive Systems Group 1 Apr 18, 2022
Deep Multi-Magnification Network for multi-class tissue segmentation of whole slide images

Deep Multi-Magnification Network This repository provides training and inference codes for Deep Multi-Magnification Network published here. Deep Multi

Computational Pathology 12 Aug 06, 2022
Tensorflow Repo for "DeepGCNs: Can GCNs Go as Deep as CNNs?"

DeepGCNs: Can GCNs Go as Deep as CNNs? In this work, we present new ways to successfully train very deep GCNs. We borrow concepts from CNNs, mainly re

Guohao Li 612 Nov 15, 2022
A pytorch implementation of Reading Wikipedia to Answer Open-Domain Questions.

DrQA A pytorch implementation of the ACL 2017 paper Reading Wikipedia to Answer Open-Domain Questions (DrQA). Reading comprehension is a task to produ

Runqi Yang 394 Nov 08, 2022
Theano is a Python library that allows you to define, optimize, and evaluate mathematical expressions involving multi-dimensional arrays efficiently. It can use GPUs and perform efficient symbolic differentiation.

============================================================================================================ `MILA will stop developing Theano https:

9.6k Jan 06, 2023
Official code for "Simpler is Better: Few-shot Semantic Segmentation with Classifier Weight Transformer. ICCV2021".

Simpler is Better: Few-shot Semantic Segmentation with Classifier Weight Transformer. ICCV2021. Introduction We proposed a novel model training paradi

Lucas 103 Dec 14, 2022
Implementation of the paper ''Implicit Feature Refinement for Instance Segmentation''.

Implicit Feature Refinement for Instance Segmentation This repository is an official implementation of the ACM Multimedia 2021 paper Implicit Feature

Lufan Ma 17 Dec 28, 2022
An efficient PyTorch library for Global Wheat Detection using YOLOv5. The project is based on this Kaggle competition Global Wheat Detection (2021).

Global-Wheat-Detection An efficient PyTorch library for Global Wheat Detection using YOLOv5. The project is based on this Kaggle competition Global Wh

Chuxin Wang 11 Sep 25, 2022
A toolkit for Lagrangian-based constrained optimization in Pytorch

Cooper About Cooper is a toolkit for Lagrangian-based constrained optimization in Pytorch. This library aims to encourage and facilitate the study of

Cooper 34 Jan 01, 2023
Deep Learning Training Scripts With Python

Deep Learning Training Scripts DNN Frameworks Caffe PyTorch Tensorflow CNN Models VGG ResNet DenseNet Inception Language Modeling GatedCNN-LM Attentio

Multicore Computing Research Lab 16 Dec 15, 2022
Repo for the ACMMM20 submission: "Personalized breath based biometric authentication with wearable multimodality".

personalized-breath Repo for the ACMMM20 submission: "Personalized breath based biometric authentication with wearable multimodality". Guideline To ex

Manh-Ha Bui 2 Nov 15, 2021
Simple cross-platform application for DaVinci surgical video frame annotation

About DaVid is a simple cross-platform GUI for annotating robotic and endoscopic surgical actions for use in deep-learning research. Features Simple a

Cyril Zakka 4 Oct 09, 2021
CM-NAS: Cross-Modality Neural Architecture Search for Visible-Infrared Person Re-Identification (ICCV2021)

CM-NAS Official Pytorch code of paper CM-NAS: Cross-Modality Neural Architecture Search for Visible-Infrared Person Re-Identification in ICCV2021. Vis

JDAI-CV 40 Nov 25, 2022