How to Train a GAN? Tips and tricks to make GANs work

Related tags

Deep Learningganhacks
Overview

(this list is no longer maintained, and I am not sure how relevant it is in 2020)

How to Train a GAN? Tips and tricks to make GANs work

While research in Generative Adversarial Networks (GANs) continues to improve the fundamental stability of these models, we use a bunch of tricks to train them and make them stable day to day.

Here are a summary of some of the tricks.

Here's a link to the authors of this document

If you find a trick that is particularly useful in practice, please open a Pull Request to add it to the document. If we find it to be reasonable and verified, we will merge it in.

1. Normalize the inputs

  • normalize the images between -1 and 1
  • Tanh as the last layer of the generator output

2: A modified loss function

In GAN papers, the loss function to optimize G is min (log 1-D), but in practice folks practically use max log D

  • because the first formulation has vanishing gradients early on
  • Goodfellow et. al (2014)

In practice, works well:

  • Flip labels when training generator: real = fake, fake = real

3: Use a spherical Z

  • Dont sample from a Uniform distribution

cube

  • Sample from a gaussian distribution

sphere

4: BatchNorm

  • Construct different mini-batches for real and fake, i.e. each mini-batch needs to contain only all real images or all generated images.
  • when batchnorm is not an option use instance normalization (for each sample, subtract mean and divide by standard deviation).

batchmix

5: Avoid Sparse Gradients: ReLU, MaxPool

  • the stability of the GAN game suffers if you have sparse gradients
  • LeakyReLU = good (in both G and D)
  • For Downsampling, use: Average Pooling, Conv2d + stride
  • For Upsampling, use: PixelShuffle, ConvTranspose2d + stride

6: Use Soft and Noisy Labels

  • Label Smoothing, i.e. if you have two target labels: Real=1 and Fake=0, then for each incoming sample, if it is real, then replace the label with a random number between 0.7 and 1.2, and if it is a fake sample, replace it with 0.0 and 0.3 (for example).
    • Salimans et. al. 2016
  • make the labels the noisy for the discriminator: occasionally flip the labels when training the discriminator

7: DCGAN / Hybrid Models

  • Use DCGAN when you can. It works!
  • if you cant use DCGANs and no model is stable, use a hybrid model : KL + GAN or VAE + GAN

8: Use stability tricks from RL

  • Experience Replay
    • Keep a replay buffer of past generations and occassionally show them
    • Keep checkpoints from the past of G and D and occassionaly swap them out for a few iterations
  • All stability tricks that work for deep deterministic policy gradients
  • See Pfau & Vinyals (2016)

9: Use the ADAM Optimizer

  • optim.Adam rules!
    • See Radford et. al. 2015
  • Use SGD for discriminator and ADAM for generator

10: Track failures early

  • D loss goes to 0: failure mode
  • check norms of gradients: if they are over 100 things are screwing up
  • when things are working, D loss has low variance and goes down over time vs having huge variance and spiking
  • if loss of generator steadily decreases, then it's fooling D with garbage (says martin)

11: Dont balance loss via statistics (unless you have a good reason to)

  • Dont try to find a (number of G / number of D) schedule to uncollapse training
  • It's hard and we've all tried it.
  • If you do try it, have a principled approach to it, rather than intuition

For example

while lossD > A:
  train D
while lossG > B:
  train G

12: If you have labels, use them

  • if you have labels available, training the discriminator to also classify the samples: auxillary GANs

13: Add noise to inputs, decay over time

14: [notsure] Train discriminator more (sometimes)

  • especially when you have noise
  • hard to find a schedule of number of D iterations vs G iterations

15: [notsure] Batch Discrimination

  • Mixed results

16: Discrete variables in Conditional GANs

  • Use an Embedding layer
  • Add as additional channels to images
  • Keep embedding dimensionality low and upsample to match image channel size

17: Use Dropouts in G in both train and test phase

Authors

  • Soumith Chintala
  • Emily Denton
  • Martin Arjovsky
  • Michael Mathieu
Owner
Soumith Chintala
/\︿╱\ _________________________________ \0_ 0 /╱\╱____________________________ \▁︹_/
Soumith Chintala
SE-MSCNN: A Lightweight Multi-scaled Fusion Network for Sleep Apnea Detection Using Single-Lead ECG Signals

SE-MSCNN: A Lightweight Multi-scaled Fusion Network for Sleep Apnea Detection Using Single-Lead ECG Signals Abstract Sleep apnea (SA) is a common slee

9 Dec 21, 2022
Codebase for the self-supervised goal reaching benchmark introduced in the LEXA paper

LEXA Benchmark Codebase for the self-supervised goal reaching benchmark introduced in the LEXA paper (Discovering and Achieving Goals via World Models

Oleg Rybkin 36 Dec 22, 2022
System Design course at HSE (2021)

System Design course at HSE (2021) Wiki-страница курса Структура репозитория: slides - директория с презентациями с занятий tasks - материалы для выпо

22 Dec 25, 2022
Autonomous racing with the Anki Overdrive

Anki Autonomous Racing Autonomous racing with the Anki Overdrive. Using the Overdrive-Python API (https://github.com/xerodotc/overdrive-python) develo

3 Dec 11, 2022
[NeurIPS 2021] "Drawing Robust Scratch Tickets: Subnetworks with Inborn Robustness Are Found within Randomly Initialized Networks" by Yonggan Fu, Qixuan Yu, Yang Zhang, Shang Wu, Xu Ouyang, David Cox, Yingyan Lin

Drawing Robust Scratch Tickets: Subnetworks with Inborn Robustness Are Found within Randomly Initialized Networks Yonggan Fu, Qixuan Yu, Yang Zhang, S

12 Dec 11, 2022
Restricted Boltzmann Machines in Python.

How to Use First, initialize an RBM with the desired number of visible and hidden units. rbm = RBM(num_visible = 6, num_hidden = 2) Next, train the m

Edwin Chen 928 Dec 30, 2022
The King is Naked: on the Notion of Robustness for Natural Language Processing

the-king-is-naked: on the notion of robustness for natural language processing AAAI2022 DISCLAIMER:This repo will be updated soon with instructions on

Iperboreo_ 1 Nov 24, 2022
Implementation of Pix2Seq in PyTorch

pix2seq-pytorch Implementation of Pix2Seq paper Different from the paper image input size 1280 bin size 1280 LambdaLR scheduler used instead of Linear

Tony Shin 9 Dec 15, 2022
🛠 All-in-one web-based IDE specialized for machine learning and data science.

All-in-one web-based development environment for machine learning Getting Started • Features & Screenshots • Support • Report a Bug • FAQ • Known Issu

Machine Learning Tooling 2.9k Jan 09, 2023
A very simple tool for situations where optimization with onnx-simplifier would exceed the Protocol Buffers upper file size limit of 2GB, or simply to separate onnx files to any size you want.

sne4onnx A very simple tool for situations where optimization with onnx-simplifier would exceed the Protocol Buffers upper file size limit of 2GB, or

Katsuya Hyodo 10 Aug 30, 2022
A Human-in-the-Loop workflow for creating HD images from text

A Human-in-the-Loop? workflow for creating HD images from text DALL·E Flow is an interactive workflow for generating high-definition images from text

Jina AI 2.5k Jan 02, 2023
Official PyTorch implementation of the paper Image-Based CLIP-Guided Essence Transfer.

TargetCLIP- official pytorch implementation of the paper Image-Based CLIP-Guided Essence Transfer This repository finds a global direction in StyleGAN

Hila Chefer 221 Dec 13, 2022
A general, feasible, and extensible framework for classification tasks.

Pytorch Classification A general, feasible and extensible framework for 2D image classification. Features Easy to configure (model, hyperparameters) T

Eugene 26 Nov 22, 2022
An open-source project for applying deep learning to medical scenarios

Auto Vaidya An open source solution for creating end-end web app for employing the power of deep learning in various clinical scenarios like implant d

Smaranjit Ghose 18 May 29, 2022
Zero-shot Learning by Generating Task-specific Adapters

Code for "Zero-shot Learning by Generating Task-specific Adapters" This is the repository containing code for "Zero-shot Learning by Generating Task-s

INK Lab @ USC 11 Dec 17, 2021
This is implementation of AlexNet(2012) with 3D Convolution on TensorFlow (AlexNet 3D).

AlexNet_3dConv TensorFlow implementation of AlexNet(2012) by Alex Krizhevsky, with 3D convolutiional layers. 3D AlexNet Network with a standart AlexNe

Denis Timonin 41 Jan 16, 2022
68 keypoint annotations for COFW test data

68 keypoint annotations for COFW test data This repository contains manually annotated 68 keypoints for COFW test data (original annotation of CFOW da

31 Dec 06, 2022
[ACL 2022] LinkBERT: A Knowledgeable Language Model 😎 Pretrained with Document Links

LinkBERT: A Knowledgeable Language Model Pretrained with Document Links This repo provides the model, code & data of our paper: LinkBERT: Pretraining

Michihiro Yasunaga 264 Jan 01, 2023
Fast convergence of detr with spatially modulated co-attention

Fast convergence of detr with spatially modulated co-attention Usage There are no extra compiled components in SMCA DETR and package dependencies are

peng gao 135 Dec 07, 2022
noisy labels; missing labels; semi-supervised learning; entropy; uncertainty; robustness and generalisation.

ProSelfLC: CVPR 2021 ProSelfLC: Progressive Self Label Correction for Training Robust Deep Neural Networks For any specific discussion or potential fu

amos_xwang 57 Dec 04, 2022