How to Train a GAN? Tips and tricks to make GANs work

Related tags

Deep Learningganhacks
Overview

(this list is no longer maintained, and I am not sure how relevant it is in 2020)

How to Train a GAN? Tips and tricks to make GANs work

While research in Generative Adversarial Networks (GANs) continues to improve the fundamental stability of these models, we use a bunch of tricks to train them and make them stable day to day.

Here are a summary of some of the tricks.

Here's a link to the authors of this document

If you find a trick that is particularly useful in practice, please open a Pull Request to add it to the document. If we find it to be reasonable and verified, we will merge it in.

1. Normalize the inputs

  • normalize the images between -1 and 1
  • Tanh as the last layer of the generator output

2: A modified loss function

In GAN papers, the loss function to optimize G is min (log 1-D), but in practice folks practically use max log D

  • because the first formulation has vanishing gradients early on
  • Goodfellow et. al (2014)

In practice, works well:

  • Flip labels when training generator: real = fake, fake = real

3: Use a spherical Z

  • Dont sample from a Uniform distribution

cube

  • Sample from a gaussian distribution

sphere

4: BatchNorm

  • Construct different mini-batches for real and fake, i.e. each mini-batch needs to contain only all real images or all generated images.
  • when batchnorm is not an option use instance normalization (for each sample, subtract mean and divide by standard deviation).

batchmix

5: Avoid Sparse Gradients: ReLU, MaxPool

  • the stability of the GAN game suffers if you have sparse gradients
  • LeakyReLU = good (in both G and D)
  • For Downsampling, use: Average Pooling, Conv2d + stride
  • For Upsampling, use: PixelShuffle, ConvTranspose2d + stride

6: Use Soft and Noisy Labels

  • Label Smoothing, i.e. if you have two target labels: Real=1 and Fake=0, then for each incoming sample, if it is real, then replace the label with a random number between 0.7 and 1.2, and if it is a fake sample, replace it with 0.0 and 0.3 (for example).
    • Salimans et. al. 2016
  • make the labels the noisy for the discriminator: occasionally flip the labels when training the discriminator

7: DCGAN / Hybrid Models

  • Use DCGAN when you can. It works!
  • if you cant use DCGANs and no model is stable, use a hybrid model : KL + GAN or VAE + GAN

8: Use stability tricks from RL

  • Experience Replay
    • Keep a replay buffer of past generations and occassionally show them
    • Keep checkpoints from the past of G and D and occassionaly swap them out for a few iterations
  • All stability tricks that work for deep deterministic policy gradients
  • See Pfau & Vinyals (2016)

9: Use the ADAM Optimizer

  • optim.Adam rules!
    • See Radford et. al. 2015
  • Use SGD for discriminator and ADAM for generator

10: Track failures early

  • D loss goes to 0: failure mode
  • check norms of gradients: if they are over 100 things are screwing up
  • when things are working, D loss has low variance and goes down over time vs having huge variance and spiking
  • if loss of generator steadily decreases, then it's fooling D with garbage (says martin)

11: Dont balance loss via statistics (unless you have a good reason to)

  • Dont try to find a (number of G / number of D) schedule to uncollapse training
  • It's hard and we've all tried it.
  • If you do try it, have a principled approach to it, rather than intuition

For example

while lossD > A:
  train D
while lossG > B:
  train G

12: If you have labels, use them

  • if you have labels available, training the discriminator to also classify the samples: auxillary GANs

13: Add noise to inputs, decay over time

14: [notsure] Train discriminator more (sometimes)

  • especially when you have noise
  • hard to find a schedule of number of D iterations vs G iterations

15: [notsure] Batch Discrimination

  • Mixed results

16: Discrete variables in Conditional GANs

  • Use an Embedding layer
  • Add as additional channels to images
  • Keep embedding dimensionality low and upsample to match image channel size

17: Use Dropouts in G in both train and test phase

Authors

  • Soumith Chintala
  • Emily Denton
  • Martin Arjovsky
  • Michael Mathieu
Owner
Soumith Chintala
/\︿╱\ _________________________________ \0_ 0 /╱\╱____________________________ \▁︹_/
Soumith Chintala
A simple, high level, easy-to-use open source Computer Vision library for Python.

ZoomVision : Slicing Aid Detection A simple, high level, easy-to-use open source Computer Vision library for Python. Installation Installing dependenc

Nurettin Sinanoğlu 2 Mar 04, 2022
Liver segmentation using MONAI and pytorch

Machine Learning use case in the field of Healthcare. In this project MONAI and pytorch frameworks are used for 3D Liver segmentation.

Abhishek Gajbhiye 2 May 30, 2022
LinkNet - This repository contains our Torch7 implementation of the network developed by us at e-Lab.

LinkNet This repository contains our Torch7 implementation of the network developed by us at e-Lab. You can go to our blogpost or read the article Lin

e-Lab 158 Nov 11, 2022
Cereal box identification in store shelves using computer vision and a single train image per model.

Product Recognition on Store Shelves Description You can read the task description here. Report You can read and download our report here. Step A - Mu

Nicholas Baraghini 1 Jan 21, 2022
Evaluating AlexNet features at various depths

Linear Separability Evaluation This repo provides the scripts to test a learned AlexNet's feature representation performance at the five different con

Yuki M. Asano 32 Dec 30, 2022
PiRapGenerator - Make anyone rap the digits of pi

PiRapGenerator Make anyone rap the digits of pi (sample files are of Ted Nivison

7 Oct 02, 2022
Build upon neural radiance fields to create a scene-specific implicit 3D semantic representation, Semantic-NeRF

Semantic-NeRF: Semantic Neural Radiance Fields Project Page | Video | Paper | Data In-Place Scene Labelling and Understanding with Implicit Scene Repr

Shuaifeng Zhi 243 Jan 07, 2023
This repository contains the code for TACL2021 paper: SummaC: Re-Visiting NLI-based Models for Inconsistency Detection in Summarization

SummaC: Summary Consistency Detection This repository contains the code for TACL2021 paper: SummaC: Re-Visiting NLI-based Models for Inconsistency Det

Philippe Laban 24 Jan 03, 2023
Pyeventbus: a publish/subscribe event bus

pyeventbus pyeventbus is a publish/subscribe event bus for Python 2.7. simplifies the communication between python classes decouples event senders and

15 Apr 21, 2022
🔥 TensorFlow Code for technical report: "YOLOv3: An Incremental Improvement"

🆕 Are you looking for a new YOLOv3 implemented by TF2.0 ? If you hate the fucking tensorflow1.x very much, no worries! I have implemented a new YOLOv

3.6k Dec 26, 2022
PyTorch implementation of paper "StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement" (ICCV 2021 Oral)

StarEnhancer StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement (ICCV 2021 Oral) Abstract: Image enhancement is a subjective process w

IDKiro 133 Dec 28, 2022
Official implementation for "Low-light Image Enhancement via Breaking Down the Darkness"

Low-light Image Enhancement via Breaking Down the Darkness by Qiming Hu, Xiaojie Guo. 1. Dependencies Python3 PyTorch=1.0 OpenCV-Python, TensorboardX

Qiming Hu 30 Jan 01, 2023
ONNX-PackNet-SfM: Python scripts for performing monocular depth estimation using the PackNet-SfM model in ONNX

Python scripts for performing monocular depth estimation using the PackNet-SfM model in ONNX

Ibai Gorordo 14 Dec 09, 2022
tree-math: mathematical operations for JAX pytrees

tree-math: mathematical operations for JAX pytrees tree-math makes it easy to implement numerical algorithms that work on JAX pytrees, such as iterati

Google 137 Dec 28, 2022
A library of extension and helper modules for Python's data analysis and machine learning libraries.

Mlxtend (machine learning extensions) is a Python library of useful tools for the day-to-day data science tasks. Sebastian Raschka 2014-2020 Links Doc

Sebastian Raschka 4.2k Jan 02, 2023
DumpSMBShare - A script to dump files and folders remotely from a Windows SMB share

DumpSMBShare A script to dump files and folders remotely from a Windows SMB shar

Podalirius 178 Jan 06, 2023
dyld_shared_cache processing / Single-Image loading for BinaryNinja

Dyld Shared Cache Parser Author: cynder (kat) Dyld Shared Cache Support for BinaryNinja Without any of the fuss of requiring manually loading several

cynder 76 Dec 28, 2022
Accelerated Multi-Modal MR Imaging with Transformers

Accelerated Multi-Modal MR Imaging with Transformers Dependencies numpy==1.18.5 scikit_image==0.16.2 torchvision==0.8.1 torch==1.7.0 runstats==1.8.0 p

54 Dec 16, 2022
Experimental solutions to selected exercises from the book [Advances in Financial Machine Learning by Marcos Lopez De Prado]

Advances in Financial Machine Learning Exercises Experimental solutions to selected exercises from the book Advances in Financial Machine Learning by

Brian 1.4k Jan 04, 2023
Replication Code for "Self-Supervised Bug Detection and Repair" NeurIPS 2021

Self-Supervised Bug Detection and Repair This is the reference code to replicate the research in Self-Supervised Bug Detection and Repair in NeurIPS 2

Microsoft 85 Dec 24, 2022