PyTorch implementation of the TTC algorithm

Overview

Trust-the-Critics

This repository is a PyTorch implementation of the TTC algorithm and the WGAN misalignment experiments presented in Trust the Critics: Generatorless and Multipurpose WGANs with Initial Convergence Guarantees.

How to run this code

  • Create a Python virtual environment with Python 3.8 installed.
  • Install the necessary Python packages listed in the requirements.txt file (this can be done through pip install -r /path/to/requirements.txt).

In the example_shell_scripts folder, we include sample shell scripts that we used to run our experiments. We note that training generative models is computationally demanding and thus requires adequate computational resources (running this code on a laptop is not recommended).

TTC algorithm

The various experiments we run with TTC are described in Section 5 and Appendix B of the paper. Illustrating the flexibility of the TTC algorithm, the image generation, denoising and translation experiments can all be run using the ttc.py script; only the source and target datasets need to change. Running TTC with a given source and target trains and saves several critic neural networks. These critics can then be used to push the source distribution towards the target distribution by applying the 'steptaker' function found in TTC_utils/steptaker.py once for each critic.
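The following is a minimal, self-contained sketch of what "applying the steptaker once for each critic" means. It is not the code in TTC_utils/steptaker.py. The names `take_step`, `push_samples` and `make_toy_critic_grad` are hypothetical. Samples are plain Python lists rather than tensors, and each critic is represented directly by its gradient function. We also assume the sign convention that samples move against the critic's gradient, scaled by the saved step size.

```python
from typing import Callable, List, Sequence

Vector = List[float]
GradFn = Callable[[Vector], Vector]


def take_step(x: Vector, critic_grad: GradFn, step_size: float) -> Vector:
    """One TTC-style update: move x against the critic's gradient (sign convention assumed)."""
    g = critic_grad(x)
    return [xi - step_size * gi for xi, gi in zip(x, g)]


def push_samples(samples: Sequence[Vector],
                 critic_grads: Sequence[GradFn],
                 step_sizes: Sequence[float]) -> List[Vector]:
    """Apply each (critic, step size) pair exactly once, in the order the critics were trained."""
    pushed = [list(x) for x in samples]
    for grad_fn, eta in zip(critic_grads, step_sizes):
        pushed = [take_step(x, grad_fn, eta) for x in pushed]
    return pushed


def make_toy_critic_grad(target: Vector) -> GradFn:
    """Hypothetical toy critic f(x) = 0.5 * ||x - target||^2, whose gradient is x - target."""
    return lambda x: [xi - ti for xi, ti in zip(x, target)]
```

With this toy critic and a step size of 0.5, every step moves each sample halfway to the target. In the repository, the gradients would instead come from autograd through the trained critic networks, and the step sizes would come from log.pkl.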

Necessary arguments for ttc.py are:

  • 'source' : The name of the distribution or dataset that is to be pushed towards the target (options are listed in ttc.py).
  • 'target' : The name of the target dataset (options are listed in ttc.py).
  • 'data' : The path of a directory where the necessary data is located. This includes the target dataset, stored in a format that can be read by the dataloader obtained from the corresponding function in dataloader.py. Such a dataloader is always an instance of torch.utils.data.DataLoader (e.g. if target=='mnist', the dataloader wraps an instance of torchvision.datasets.MNIST, and the MNIST dataset should be placed in 'data' in the layout that class expects). If the source is a dataset, it needs to be placed in 'data' as well. If source=='untrained_gen', the untrained generator used to create the source distribution needs to be saved under 'data/ugen.pth'.
  • 'temp_dir' : The path of a directory where the trained critics will be saved, along with a few other files (including the log.pkl file that contains the step sizes). Despite the name, this folder isn't necessarily temporary.
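Before launching a long training run, it can help to check the file layout described above. This is a hypothetical pre-flight helper, `check_ttc_inputs`, which is not part of the repository. It only checks the requirements stated in this README: 'data' must exist, 'data/ugen.pth' must exist when source=='untrained_gen', and 'temp_dir' must not clash with an existing file. It does not validate the dataset format itself.

```python
from pathlib import Path
from typing import List


def check_ttc_inputs(source: str, data: str, temp_dir: str) -> List[str]:
    """Return a list of problems found before running ttc.py (empty list if none)."""
    problems = []
    data_path = Path(data)
    if not data_path.is_dir():
        problems.append(f"data directory not found: {data_path}")
    # The untrained generator defining the source must live at data/ugen.pth.
    if source == "untrained_gen" and not (data_path / "ugen.pth").is_file():
        problems.append(f"missing untrained generator: {data_path / 'ugen.pth'}")
    # temp_dir will receive the critics and log.pkl, so it cannot be a regular file.
    temp_path = Path(temp_dir)
    if temp_path.exists() and not temp_path.is_dir():
        problems.append(f"temp_dir exists but is not a directory: {temp_path}")
    return problems
```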

Other optional arguments are described in a commented section at the top of the ttc.py script. Note that running ttc.py only trains the critics that the TTC algorithm uses to push the source distribution towards the target distribution; it does not actually push any samples from the source towards the target (as mentioned above, this is done using the steptaker function).

TTC image generation
For a generative experiment, run ttc.py with the source argument set to either 'noise' or 'untrained_gen' and the target of your choice. Then run ttc_eval.py, which uses the saved critics and step sizes to push noise inputs towards the target distribution according to the TTC algorithm (using the steptaker function). It can optionally evaluate generative performance with FID and/or MMD (the paper reports FID). The arguments 'source', 'target', 'data', 'temp_dir' and 'model' for ttc_eval.py should be set to the same values as when running ttc.py. If evaluating FID, the folder specified by 'temp_dir' should contain a subdirectory named '{target}test' (e.g. 'temp_dir/mnisttest' if target=='mnist') holding the test data from the target dataset saved as individual files. For instance, this folder could contain files of the form '00001.jpg', '00002.jpg', etc. (although extensions other than .jpg can be used).
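This is a sketch of one way to build the FID test folder described above. It assumes the test images already exist as individual image files elsewhere, for example after exporting them from the dataset. The helper names `fid_test_dir` and `export_numbered` are hypothetical. The five-digit zero padding mirrors the '00001.jpg' example, and each file keeps its original extension. File contents are copied as-is and are not validated as images.

```python
import shutil
from pathlib import Path
from typing import Iterable, List


def fid_test_dir(temp_dir: str, target: str) -> Path:
    """Folder expected for FID evaluation: temp_dir/{target}test, e.g. temp_dir/mnisttest."""
    return Path(temp_dir) / f"{target}test"


def export_numbered(image_files: Iterable[str], dest: Path) -> List[Path]:
    """Copy image files into dest as 00001.ext, 00002.ext, ... (one image per file)."""
    dest.mkdir(parents=True, exist_ok=True)
    written = []
    for i, src in enumerate(image_files, start=1):
        src_path = Path(src)
        out = dest / f"{i:05d}{src_path.suffix}"  # keep the original extension
        shutil.copyfile(src_path, out)
        written.append(out)
    return written
```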

TTC denoising
For a denoising experiment, run ttc.py with source=='noisybsds500' and target=='bsds500' (specifying a noise level with the 'sigma' argument). Then run denoise_eval.py (with the same 'temp_dir', 'data' and 'model' arguments), which adds noise to images, denoises them using the TTC algorithm and the saved critics, and evaluates PSNRs.
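For reference, this is a small sketch of the evaluation quantities involved: additive Gaussian noise of standard deviation sigma, and PSNR = 10 · log10(MAX² / MSE). It assumes images are flattened lists of pixel values with a peak value of 1.0 by default. The noise model, pixel range and any clipping used by denoise_eval.py may differ, and the function names here are hypothetical.

```python
import math
import random
from typing import List, Sequence


def add_gaussian_noise(pixels: Sequence[float], sigma: float, seed: int = 0) -> List[float]:
    """Corrupt pixel values with i.i.d. Gaussian noise of standard deviation sigma."""
    rng = random.Random(seed)
    return [p + rng.gauss(0.0, sigma) for p in pixels]


def psnr(clean: Sequence[float], estimate: Sequence[float], max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB: 10 * log10(max_val**2 / MSE)."""
    if len(clean) != len(estimate) or not clean:
        raise ValueError("images must be non-empty and have the same number of pixels")
    mse = sum((c - e) ** 2 for c, e in zip(clean, estimate)) / len(clean)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)
```

A higher PSNR of the denoised image against the clean image, compared to the noisy image against the clean image, indicates successful denoising.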

TTC Monet translation
For an image translation experiment, run ttc.py with source=='photo' and target=='monet'. Then run ttc_eval.py (with the same 'source', 'target', 'temp_dir', 'data' and 'model' arguments, and presumably with no FID or MMD evaluation), which samples realistic images from the source and makes them look like Monet paintings.
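In each TTC experiment above, the evaluation script must be called with the same argument values as the ttc.py run whose critics it loads. A small check can catch mismatches before a long evaluation starts. This sketch is hypothetical: `mismatched_args` compares two plain dictionaries of argument values and does not parse the scripts' actual command lines.

```python
from typing import Dict, Iterable, List

# Arguments that must match between a ttc.py run and its evaluation run.
SHARED_KEYS = ("source", "target", "data", "temp_dir", "model")


def mismatched_args(train_args: Dict[str, str], eval_args: Dict[str, str],
                    keys: Iterable[str] = SHARED_KEYS) -> List[str]:
    """Return the names of shared arguments whose values differ or are set in only one run."""
    missing = object()  # sentinel: a key absent from both runs is not reported
    return [k for k in keys if train_args.get(k, missing) != eval_args.get(k, missing)]
```

For denoise_eval.py, which only needs the same 'temp_dir', 'data' and 'model' values, pass keys=("temp_dir", "data", "model").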

WGAN misalignment

The WGAN misalignment experiments are described in Section 3 and Appendix B.1 of the paper, and are run using misalignments.py. This script trains a WGAN and, at selected iterations, measures how misaligned the movement of generated samples caused by a generator update is with respect to the critic's gradient. The generator's FID is also measured at the same iterations.
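To illustrate the kind of quantity involved, this sketch measures misalignment as the angle between a generated sample's displacement and the critic's gradient at that sample. The displacement is G_after(z) − G_before(z) for a fixed latent z. Both the choice of the angle as the measure and the helper name `misalignment_angle` are assumptions. misalignments.py may use a different statistic (for example cosine similarity averaged over a batch) or sign convention.

```python
import math
from typing import Sequence


def misalignment_angle(displacement: Sequence[float], critic_grad: Sequence[float]) -> float:
    """Angle in degrees between a sample's movement and the critic's gradient at that sample.

    0 means perfectly aligned, 90 orthogonal, 180 opposite.
    """
    dot = sum(d * g for d, g in zip(displacement, critic_grad))
    norm_d = math.sqrt(sum(d * d for d in displacement))
    norm_g = math.sqrt(sum(g * g for g in critic_grad))
    if norm_d == 0 or norm_g == 0:
        raise ValueError("angle is undefined for a zero vector")
    cos = max(-1.0, min(1.0, dot / (norm_d * norm_g)))  # clamp floating-point rounding
    return math.degrees(math.acos(cos))
```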

The required arguments for misalignments.py are:

  • 'target' : The dataset used to train the WGAN - can be either 'mnist' or 'fashion' (for Fashion-MNIST).
  • 'data' : The path of a folder where the MNIST (or Fashion-MNIST) dataset is located, in a format that can be accessed by an instance of the torchvision.datasets.MNIST class (resp. torchvision.datasets.FashionMNIST).
  • 'fid_data' : The path of a folder containing the test data from the target dataset (MNIST or Fashion-MNIST) saved as individual files. For instance, this folder could contain files of the form '00001.jpg', '00002.jpg', etc. (although extensions other than .jpg can be used).
  • 'checkpoints' : A string of integers separated by underscores. The integers specify the iterations at which misalignments and FID are computed, and training will continue until the largest iteration is reached.

Other optional arguments (including 'results_path' and 'temp_dir') are described in a commented section at the top of misalignments.py. The misalignment results reported in the paper (Tables 1 and 5, and Figure 3) correspond to using the default hyperparameters and setting the 'checkpoints' argument roughly equal to '10_25000_40000', with '10' corresponding to the early stage of training, '25000' to the mid stage, and '40000' to the late stage.
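The 'checkpoints' string encodes both the evaluation iterations and the training length. This sketch shows how such a string can be interpreted. The helper `parse_checkpoints` is hypothetical. Sorting and de-duplicating the iterations is our assumption, not necessarily what misalignments.py does.

```python
from typing import List


def parse_checkpoints(checkpoints: str) -> List[int]:
    """Turn a string like '10_25000_40000' into sorted, unique iteration numbers."""
    try:
        iters = sorted({int(part) for part in checkpoints.split("_")})
    except ValueError:
        raise ValueError(
            f"checkpoints must be integers separated by underscores: {checkpoints!r}"
        ) from None
    if iters[0] < 0:
        raise ValueError("checkpoint iterations must be non-negative")
    return iters
```

For '10_25000_40000' this gives [10, 25000, 40000], and training would run until iteration 40000, the last element.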

WGAN generation

For completeness, we include the code that was used to obtain the WGAN FID statistics in Table 3 of the paper, namely the wgan_gp.py and wgan_gp_eval.py scripts. The former trains a WGAN with the InfoGAN architecture on the dataset specified by the 'target' argument, saving generator model dictionaries in the folder specified by 'temp_dir' at ten equally spaced stages in training. The wgan_gp_eval.py script evaluates the performance of the generator using each of the model dictionaries saved in 'temp_dir'.
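"Ten equally spaced stages" can be made concrete as the end of each of ten equal segments of training. This sketch, `equally_spaced_saves`, is a hypothetical illustration. Whether wgan_gp.py uses this exact rule (for example, whether it also saves at iteration 0) is not stated here.

```python
from typing import List


def equally_spaced_saves(total_iters: int, stages: int = 10) -> List[int]:
    """Iterations at which to save the generator: the end of each of `stages` equal segments."""
    if stages < 1 or total_iters < stages:
        raise ValueError("need at least one stage and one iteration per stage")
    return [total_iters * k // stages for k in range(1, stages + 1)]
```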

The necessary arguments to run wgan_gp.py are:

  • 'target' : The name of the dataset to generate (can be either 'mnist', 'fashion' or 'cifar10').
  • 'data' : Folder where the dataset is located.
  • 'temp_dir' : Folder where the model dictionaries are saved.

Once wgan_gp.py has run, wgan_gp_eval.py should be called with the same arguments for 'target', 'data' and 'temp_dir', and setting the 'model' argument to 'infogan'. If evaluating FID, the 'temp_dir' folder needs to contain the test data from the target dataset saved as individual files. For instance, this folder could contain files of the form '00001.jpg', '00002.jpg', etc. (although extensions other than .jpg can be used).

Reproducibility

This repository contains two branches: 'main' and 'reproducible'. You are currently viewing the 'main' branch, which contains a clean version of the code that is meant to be easy to read and interpret, and to run more efficiently than the version on the 'reproducible' branch. The results obtained by running the code on this branch should be nearly (but not perfectly) identical to the results stated in the paper, the differences stemming from the randomness inherent to the experiments. The 'reproducible' branch allows one to replicate exactly the results stated in the paper for the TTC experiments (random seeds are specified).
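The seeding code on the 'reproducible' branch is not reproduced here. As a generic sketch (an assumption, not that branch's code), fixing seeds in a PyTorch project typically looks like the following. The function name `seed_everything` is hypothetical. NumPy and PyTorch are seeded only if they are installed.

```python
import random


def seed_everything(seed: int) -> None:
    """Seed Python's RNG, plus NumPy's and PyTorch's if those libraries are installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds the RNGs of all devices, including CUDA
        # Prefer deterministic cuDNN kernels over autotuned (possibly nondeterministic) ones.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass
```

Even with fixed seeds, exact agreement across different GPU models or library versions is not guaranteed, since some GPU operations remain nondeterministic.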

Computing architecture and running times

We ran different versions of the code presented here on Compute Canada (https://www.computecanada.ca/) clusters, always using a single NVIDIA V100 Volta or NVIDIA A100 Ampere GPU. Here are rough estimates of the running times for our experiments.

  • MNIST/Fashion-MNIST generation training run (TTC): 60-90 minutes.
  • MNIST/Fashion-MNIST generation training run (WGAN): 45-90 minutes (this includes misalignment computations).
  • CIFAR10 generation training run: 3-4 hours (TTC), 90 minutes (WGAN-GP).
  • Image translation training run: up to 20 hours.
  • Image denoising training run: 8-10 hours.

Assets

Portions of this code, as well as the datasets used to produce our experimental results, make use of existing assets. We provide here a list of all assets used, along with the licenses under which they are distributed, if specified by the originator:
