Annotated, understandable, and visually interpretable PyTorch implementations of: VAE, BIRVAE, NSGAN, MMGAN, WGAN, WGANGP, LSGAN, DRAGAN, BEGAN, RaGAN, InfoGAN, fGAN, FisherGAN

Overview

Overview

PyTorch 0.4.1 | Python 3.6.5

Annotated implementations with comparative introductions for minimax, non-saturating, wasserstein, wasserstein gradient penalty, least squares, deep regret analytic, bounded equilibrium, relativistic, f-divergence, Fisher, and information generative adversarial networks (GANs), and standard, variational, and bounded information rate variational autoencoders (VAEs).

Paper links are supplied at the beginning of each file with a short summary of the paper. See src folder for files to run via terminal, or notebooks folder for Jupyter notebook visualizations via your local browser. The main file changes can be see in the train, train_D, and train_G of the Trainer class, although changes are not completely limited to only these two areas (e.g. Wasserstein GAN clamps weight in the train function, BEGAN gives multiple outputs from train_D, fGAN has a slight modification in viz_loss function to indicate method used in title).

All code in this repository operates in a generative, unsupervised manner on binary (black and white) MNIST. The architectures are compatible with a variety of datatypes (1D, 2D, square 3D images). Plotting functions work with binary/RGB images. If a GPU is detected, the models use it. Otherwise, they default to CPU. VAE Trainer classes contain methods to visualize latent space representations (see make_all function).

Usage

To initialize an environment:

python -m venv env  
. env/bin/activate  
pip install -r requirements.txt  

For playing around in Jupyer notebooks:

jupyter notebook

To run from Terminal:

cd src
python bir_vae.py

New Models

One of the primary purposes of this repository is to make implementing deep generative model (i.e., GAN/VAE) variants as easy as possible. This is possible because, typically but not always (e.g. BIRVAE), the proposed modifications only apply to the way loss is computed for backpropagation. Thus, the core training class is structured in such a way that most new implementations should only require edits to the train_D and train_G functions of GAN Trainer classes, and the compute_batch function of VAE Trainer classes.

Suppose we have a non-saturating GAN and we wanted to implement a least-squares GAN. To do this, all we have to do is change two lines:

Original (NSGAN)

def train_D(self, images):
  ...
  D_loss = -torch.mean(torch.log(DX_score + 1e-8) + torch.log(1 - DG_score + 1e-8))

  return D_loss
def train_G(self, images):
  ...
  G_loss = -torch.mean(torch.log(DG_score + 1e-8))

  return G_loss

New (LSGAN)

def train_D(self, images):
  ...
  D_loss = (0.50 * torch.mean((DX_score - 1.)**2)) + (0.50 * torch.mean((DG_score - 0.)**2))

  return D_loss
def train_G(self, images):
  ...
  G_loss = 0.50 * torch.mean((DG_score - 1.)**2)

  return G_loss

Model Architecture

The architecture chosen in these implementations for both the generator (G) and discriminator (D) consists of a simple, two-layer feedforward network. While this will give sensible output for MNIST, in practice it is recommended to use deep convolutional architectures (i.e. DCGANs) to get nicer outputs. This can be done by editing the Generator and Discriminator classes for GANs, or the Encoder and Decoder classes for VAEs.

Visualization

All models were trained for 25 epochs with hidden dimension 400, latent dimension 20. Other implementation specifics are as close to the respective original paper (linked) as possible.

Model Epoch 1 Epoch 25 Progress Loss
MMGAN
NSGAN
WGAN
WGPGAN
DRAGAN
BEGAN
LSGAN
RaNSGAN
FisherGAN
InfoGAN
f-TVGAN
f-PearsonGAN
f-JSGAN
f-ForwGAN
f-RevGAN
f-HellingerGAN
VAE
BIRVAE

To Do

Models: CVAE, denoising VAE, adversarial autoencoder | Bayesian GAN, Self-attention GAN, Primal-Dual Wasserstein GAN
Architectures: Add DCGAN option
Datasets: Beyond MNIST

Owner
Shayne O'Brien
NLP / Machine Learning / Network Science. Moved from MIT to Apple 06/2019
Shayne O'Brien
PyTorch/TorchScript compiler for NVIDIA GPUs using TensorRT

PyTorch/TorchScript compiler for NVIDIA GPUs using TensorRT

NVIDIA Corporation 1.8k Dec 30, 2022
Benchmarks for Object Detection in Aerial Images

Benchmarks for Object Detection in Aerial Images

Jian Ding 691 Dec 30, 2022
Codebase for the solution that won first place and was awarded the most human-like agent in the 2021 NeurIPS Competition MineRL BASALT Challenge.

KAIROS MineRL BASALT Codebase for the solution that won first place and was awarded the most human-like agent in the 2021 NeurIPS Competition MineRL B

Vinicius G. Goecks 37 Oct 30, 2022
BOOKSUM: A Collection of Datasets for Long-form Narrative Summarization

BOOKSUM: A Collection of Datasets for Long-form Narrative Summarization Authors: Wojciech Kryściński, Nazneen Rajani, Divyansh Agarwal, Caiming Xiong,

Salesforce 125 Dec 31, 2022
Implementation of PersonaGPT Dialog Model

PersonaGPT An open-domain conversational agent with many personalities PersonaGPT is an open-domain conversational agent cpable of decoding personaliz

ILLIDAN Lab 42 Jan 01, 2023
Fast, flexible and fun neural networks.

Brainstorm Discontinuation Notice Brainstorm is no longer being maintained, so we recommend using one of the many other,available frameworks, such as

IDSIA 1.3k Nov 21, 2022
python library for invisible image watermark (blind image watermark)

invisible-watermark invisible-watermark is a python library and command line tool for creating invisible watermark over image.(aka. blink image waterm

Shield Mountain 572 Jan 07, 2023
Scaling and Benchmarking Self-Supervised Visual Representation Learning

FAIR Self-Supervision Benchmark is deprecated. Please see VISSL, a ground-up rewrite of benchmark in PyTorch. FAIR Self-Supervision Benchmark This cod

Meta Research 584 Dec 31, 2022
TianyuQi 10 Dec 11, 2022
This is an unofficial implementation of the paper “Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection”.

This is an unofficial implementation of the paper “Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection”.

haifeng xia 32 Oct 26, 2022
Syed Waqas Zamir 906 Dec 30, 2022
《Train in Germany, Test in The USA: Making 3D Object Detectors Generalize》(CVPR 2020)

Train in Germany, Test in The USA: Making 3D Object Detectors Generalize This paper has been accpeted by Conference on Computer Vision and Pattern Rec

Xiangyu Chen 101 Jan 02, 2023
🚗 INGI Dakar 2K21 - Be the first one on the finish line ! 🚗

🚗 INGI Dakar 2K21 - Be the first one on the finish line ! 🚗 This year's first semester Club Info challenge will put you at the head of a car racing

ClubINFO INGI (UCLouvain) 6 Dec 10, 2021
Cognition-aware Cognate Detection

Cognition-aware Cognate Detection The repository which contains our code for our EACL 2021 paper titled, "Cognition-aware Cognate Detection". This wor

Prashant K. Sharma 1 Feb 01, 2022
Official implementation of Long-Short Transformer in PyTorch.

Long-Short Transformer (Transformer-LS) This repository hosts the code and models for the paper: Long-Short Transformer: Efficient Transformers for La

NVIDIA Corporation 198 Dec 29, 2022
IPATool-py: download ipa easily

IPATool-py Python version of IPATool! Installation pip3 install -r requirements.txt Usage Quickstart: download app with specific bundleId into DIR: p

159 Dec 30, 2022
A note taker for NVDA. Allows the user to create, edit, view, manage and export notes to different formats.

Quick Notetaker add-on for NVDA The Quick Notetaker add-on is a wonderful tool which allows writing notes quickly and easily anytime and from any app

5 Dec 06, 2022
Zero-shot Learning by Generating Task-specific Adapters

Code for "Zero-shot Learning by Generating Task-specific Adapters" This is the repository containing code for "Zero-shot Learning by Generating Task-s

INK Lab @ USC 11 Dec 17, 2021
RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching

RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching This repository contains the source code for our paper: RAFT-Stereo: Multilevel

Princeton Vision & Learning Lab 328 Jan 09, 2023