[ICLR'21] Counterfactual Generative Networks

Overview

Counterfactual Generative Networks

[Project] [PDF] [Blog] [Music Video] [Colab]

This repository contains the code for the ICLR 2021 paper "Counterfactual Generative Networks" by Axel Sauer and Andreas Geiger. If you want to take the CGN for a spin and generate counterfactual images, you can try out the Colab below.

CGN

If you find our code or paper useful, please cite

@inproceedings{Sauer2021ICLR,
 author =  {Axel Sauer, Andreas Geiger},
 title = {Counterfactual Generative Networks},
 booktitle = {International Conference on Learning Representations (ICLR)},
 year = {2021}}

Setup

Install anaconda (if you don't have it yet)

wget https://repo.anaconda.com/archive/Anaconda3-2020.11-Linux-x86_64.sh
bash Anaconda3-2020.11-Linux-x86_64.sh
source ~/.profile

Clone the repo and build the environment

git clone https://github.com/autonomousvision/counterfactual_generative_networks
cd counterfactual_generative_networks
conda env create -f environment.yml
conda activate cgn

Make all scripts executable: chmod +x scripts/*. Then, download the datasets (colored MNIST, Cue-Conflict, IN-9) and the pre-trained weights (CGN, U2-Net). Comment out the ones you don't need.

./scripts/download_data.sh
./scripts/download_weights.sh

MNISTs

The main functions of this sub-repo are:

  • Generating the MNIST variants
  • Training a CGN
  • Generating counterfactual datasets
  • Training a shape classifier

Train the CGN

We provide well-working configs and weights in mnists/experiments. To train a CGN on, e.g., Wildlife MNIST, run

python mnists/train_cgn.py --cfg mnists/experiments/cgn_wildlife_MNIST/cfg.yaml

For more info, add --help. Weights and samples will be saved in mnists/experiments/.

Generate Counterfactual Data

To generate the counterfactuals for, e.g., double-colored MNIST, run

python mnists/generate_data.py \
--weight_path mnists/experiments/cgn_double_colored_MNIST/weights/ckp.pth \
--dataset double_colored_MNIST --no_cfs 10 --dataset_size 100000

Make sure that you provide the right dataset together with the weights. You can adapt the weight-path to use your own weights. The command above generates ten counterfactuals per shape.

Train the Invariant Classifier

The classifier training uses Tensor datasets, so you need to save the non-counterfactual datasets as tensors. For DATASET = {colored_MNIST, double_colored_MNIST, wildlife_MNIST}, run

python mnists/generate_data.py --dataset DATASET

To train, e.g., a shape classifier (invariant to foreground and background) on wildlife MNIST, run,

python mnists/train_classifier.py --dataset wildlife_MNIST_counterfactual

Add --help for info on the available options and arguments. The hyperparameters are unchanged for all experiments.

ImageNet

The main functions of this sub-repo are:

  • Training a CGN
  • Generating data (samples, interpolations, or a whole dataset)
  • Training an invariant classifier ensemble

Train the CGN

Run

python imagenet/train_cgn.py --model_name MODEL_NAME

The default parameters should give you satisfactory results. You can change them in imagenet/config.yml. For more info, add --help. Weights and samples will be saved in imagenet/data/MODEL_NAME.

Generate Counterfactual Data

Samples. To generate a dataset of counterfactual images, run

python imagenet/generate_data.py --mode random --weights_path imagenet/weights/cgn.pth \
--n_data 100 --weights_path imagenet/weights/cgn.pth --run_name RUN_NAME \
--truncation 0.5 --batch_sz 1

The results will be saved in imagenet/data. For more info, add --help. If you want to save only masks, textures, etc., you need to change this directly in the code (see line 206).

The labels will be stored in a csv file. You can read them as follows:

import pandas as pd
df = pd.read_csv(path, index_col=0)
df = df.set_index('im_name')
shape_cls = df['shape_cls']['RUN_NAME_0000000.png']

Generating a dataset to train a classfier. Produce one dataset with --run_name train, the other one with --run_name val. If you have several GPUs available, you can index the name, e.g., --run_name train_GPU_NUM. The class ImagenetCounterfactual will glob all these datasets and generate a single, big training set. Make sure to set --batch_sz 1. With a larger batch size, a batch will be saved as a single png; this is useful for visualization, not for training.

Interpolations. To generate interpolation sheets, e.g., from a barn (425) to whale (147), run

python imagenet/generate_data.py --mode fixed_classes \
--n_data 1 --weights_path imagenet/weights/cgn.pth --run_name barn_to_whale \
--truncation 0.3 --interp all --classes 425 425 425 --interp_cls 147 --save_noise

You can also do counterfactual interpolations, i.e., interpolating only over, e.g., shape, by setting --interp shape.

Interpolation Gif. To generate a gif like in the teaser (Sample an image of class $1, than interpolate to shape $2, then background $3, then shape $4, and finally back to $1), run

./scripts/generate_teaser_gif.sh 992 293 147 330

The positional arguments are the classes, see imagenet labels for the available options.

Train the Invariant Classifier Ensemble

Training. First, you need to make sure that you have all datasets in imagenet/data/. Download Imagenet, e.g., from Kaggle, produce a counterfactual dataset (see above), and download the Cue-Conflict and BG-Challenge dataset (via the download script in scripts).

To train a classifier on a single GPU with a pre-trained Resnet-50 backbone, run

python imagenet/train_classifier.py -a resnet50 -b 32 --lr 0.001 -j 6 \
--epochs 45 --pretrained --cf_data CF_DATA_PATH --name RUN_NAME

Again, add --help for more information on the possible arguments.

Distributed Training. To switch to multi-GPU training, run echo $CUDA_VISIBLE_DEVICES to see if the GPUs are visible. In the case of a single node with several GPUs, you can run, e.g.,

python imagenet/train_classifier.py -a resnet50 -b 256 --lr 0.001 -j 6 \
--epochs 45 --pretrained --cf_data CF_DATA_PATH --name RUN_NAME \
--rank 0 --multiprocessing-distributed --dist-url tcp://127.0.0.1:8890 --world-size 1

If your setup differs, e.g., several GPU machines, you need to adapt the rank and world size.

Visualization. To visualize the Tensorboard outputs, run tensorboard --logdir=imagenet/runs and open the local address in your browser.

Acknowledgments

We like to acknowledge several repos of which we use parts of code, data, or models in our implementation:

Deep Reinforcement Learning based autonomous navigation for quadcopters using PPO algorithm.

PPO-based Autonomous Navigation for Quadcopters This repository contains an implementation of Proximal Policy Optimization (PPO) for autonomous naviga

Bilal Kabas 16 Nov 11, 2022
Official PyTorch implementation of the paper "TEMOS: Generating diverse human motions from textual descriptions"

TEMOS: TExt to MOtionS Generating diverse human motions from textual descriptions Description Official PyTorch implementation of the paper "TEMOS: Gen

Mathis Petrovich 187 Dec 27, 2022
This is a simple framework to make object detection dataset very quickly

FastAnnotation Table of contents General info Requirements Setup General info This is a simple framework to make object detection dataset very quickly

Serena Tetart 1 Jan 24, 2022
Lepard: Learning Partial point cloud matching in Rigid and Deformable scenes

Lepard: Learning Partial point cloud matching in Rigid and Deformable scenes [Paper] Method overview 4DMatch Benchmark 4DMatch is a benchmark for matc

103 Jan 06, 2023
Neural Lexicon Reader: Reduce Pronunciation Errors in End-to-end TTS by Leveraging External Textual Knowledge

Neural Lexicon Reader: Reduce Pronunciation Errors in End-to-end TTS by Leveraging External Textual Knowledge This is an implementation of the paper,

Mutian He 19 Oct 14, 2022
👨‍💻 run nanosaur in simulation with Gazebo/Ingnition

🦕 👨‍💻 nanosaur_gazebo nanosaur The smallest NVIDIA Jetson dinosaur robot, open-source, fully 3D printable, based on ROS2 & Isaac ROS. Designed & ma

nanosaur 9 Jul 19, 2022
The codebase for our paper "Generative Occupancy Fields for 3D Surface-Aware Image Synthesis" (NeurIPS 2021)

Generative Occupancy Fields for 3D Surface-Aware Image Synthesis (NeurIPS 2021) Project Page | Paper Xudong Xu, Xingang Pan, Dahua Lin and Bo Dai GOF

xuxudong 97 Nov 10, 2022
Repo for "Benchmarking Robustness of 3D Point Cloud Recognition against Common Corruptions" https://arxiv.org/abs/2201.12296

Benchmarking Robustness of 3D Point Cloud Recognition against Common Corruptions This repo contains the dataset and code for the paper Benchmarking Ro

Jiachen Sun 168 Dec 29, 2022
Implementation of momentum^2 teacher

Momentum^2 Teacher: Momentum Teacher with Momentum Statistics for Self-Supervised Learning Requirements All experiments are done with python3.6, torch

jemmy li 121 Sep 26, 2022
Assessing syntactic abilities of BERT

BERT-Syntax Assesing the syntactic abilities of BERT. What Evaluate Google's BERT-Base and BERT-Large models on the syntactic agreement datasets from

Yoav Goldberg 147 Aug 02, 2022
Tutorials, assignments, and competitions for MIT Deep Learning related courses.

MIT Deep Learning This repository is a collection of tutorials for MIT Deep Learning courses. More added as courses progress. Tutorial: Deep Learning

Lex Fridman 9.5k Jan 07, 2023
Polynomial-time Meta-Interpretive Learning

Louise - polynomial-time Program Learning Getting help with Louise Louise's author can be reached by email at Stassa Patsantzis 64 Dec 26, 2022

Python scripts form performing stereo depth estimation using the HITNET model in Tensorflow Lite.

TFLite-HITNET-Stereo-depth-estimation Python scripts form performing stereo depth estimation using the HITNET model in Tensorflow Lite. Stereo depth e

Ibai Gorordo 22 Oct 20, 2022
Code To Tune or Not To Tune? Zero-shot Models for Legal Case Entailment.

COLIEE 2021 - task 2: Legal Case Entailment This repository contains the code to reproduce NeuralMind's submissions to COLIEE 2021 presented in the pa

NeuralMind 13 Dec 16, 2022
RIFE: Real-Time Intermediate Flow Estimation for Video Frame Interpolation

RIFE RIFE: Real-Time Intermediate Flow Estimation for Video Frame Interpolation Ported from https://github.com/hzwer/arXiv2020-RIFE Dependencies NumPy

49 Jan 07, 2023
The trained model and denoising example for paper : Cardiopulmonary Auscultation Enhancement with a Two-Stage Noise Cancellation Approach

The trained model and denoising example for paper : Cardiopulmonary Auscultation Enhancement with a Two-Stage Noise Cancellation Approach

ycj_project 1 Jan 18, 2022
A complete end-to-end demonstration in which we collect training data in Unity and use that data to train a deep neural network to predict the pose of a cube. This model is then deployed in a simulated robotic pick-and-place task.

Object Pose Estimation Demo This tutorial will go through the steps necessary to perform pose estimation with a UR3 robotic arm in Unity. You’ll gain

Unity Technologies 187 Dec 24, 2022
DecoupledNet is semantic segmentation system which using heterogeneous annotations

DecoupledNet: Decoupled Deep Neural Network for Semi-supervised Semantic Segmentation Created by Seunghoon Hong, Hyeonwoo Noh and Bohyung Han at POSTE

Hyeonwoo Noh 74 Sep 22, 2021
[CVPR2021] Domain Consensus Clustering for Universal Domain Adaptation

[CVPR2021] Domain Consensus Clustering for Universal Domain Adaptation [Paper] Prerequisites To install requirements: pip install -r requirements.txt

Guangrui Li 84 Dec 26, 2022
Codes for "CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation"

CSDI This is the github repository for the NeurIPS 2021 paper "CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation

106 Jan 04, 2023