Spectral normalization (SN) is a widely-used technique for improving the stability and sample quality of Generative Adversarial Networks (GANs)

Overview

Why Spectral Normalization Stabilizes GANs: Analysis and Improvements

[paper (NeurIPS 2021)] [paper (arXiv)] [code]

Authors: Zinan Lin, Vyas Sekar, Giulia Fanti

Abstract: Spectral normalization (SN) is a widely-used technique for improving the stability and sample quality of Generative Adversarial Networks (GANs). However, there is currently limited understanding of why SN is effective. In this work, we show that SN controls two important failure modes of GAN training: exploding and vanishing gradients. Our proofs illustrate a (perhaps unintentional) connection with the successful LeCun initialization. This connection helps to explain why the most popular implementation of SN for GANs requires no hyper-parameter tuning, whereas stricter implementations of SN have poor empirical performance out-of-the-box. Unlike LeCun initialization which only controls gradient vanishing at the beginning of training, SN preserves this property throughout training. Building on this theoretical understanding, we propose a new spectral normalization technique: Bidirectional Scaled Spectral Normalization (BSSN), which incorporates insights from later improvements to LeCun initialization: Xavier initialization and Kaiming initialization. Theoretically, we show that BSSN gives better gradient control than SN. Empirically, we demonstrate that it outperforms SN in sample quality and training stability on several benchmark datasets.


This repo contains the codes for reproducing the experiments of our BSN and different SN variants in the paper. The codes were tested under Python 2.7.5, TensorFlow 1.14.0.

Preparing datasets

CIFAR10

Download cifar-10-python.tar.gz from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz (or from other sources).

STL10

Download stl10_binary.tar.gz from http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz (or from other sources), and put it in dataset_preprocess/STL10 folder. Then run python preprocess.py. This code will resize the images into 48x48x3 format, and save the images in stl10.npy.

CelebA

Download img_align_celeba.zip from https://www.kaggle.com/jessicali9530/celeba-dataset (or from other sources), and put it in dataset_preprocess/CelebA folder. Then run python preprocess.py. This code will crop and resize the images into 64x64x3 format, and save the images in celeba.npy.

ImageNet

Download ILSVRC2012_img_train.tar from http://www.image-net.org/ (or from other sources), and put it in dataset_preprocess/ImageNet folder. Then run python preprocess.py. This code will crop and resize the images into 128x128x3 format, and save the images in ILSVRC2012folder. Each subfolder in ILSVRC2012 folder corresponds to one class. Each npy file in the subfolders corresponds to an image.

Training BSN and SN variants

Prerequisites

The codes are based on GPUTaskScheduler library, which helps you automatically schedule the jobs among GPU nodes. Please install it first. You may need to change GPU configurations according to the devices you have. The configurations are set in config.py in each directory. Please refer to GPUTaskScheduler's GitHub page for the details of how to make proper configurations.

You can also run these codes without GPUTaskScheduler. Just run python gan.py in gan subfolders.

CIFAR10, STL10, CelebA

Preparation

Copy the preprocessed datasets from the previous steps into the following paths:

  • CIFAR10: /data/CIFAR10/cifar-10-python.tar.gz.
  • STL10: /data/STL10/cifar-10-stl10.npy.
  • CelebA: /data/CelebA/celeba.npy.

Here means

  • Vanilla SN and our proposed BSSN/SSN/BSN without gammas: no_gamma-CNN.
  • SN with the same gammas: same_gamma-CNN.
  • SN with different gammas: diff_gamma-CNN.

Alternatively, you can directly modify the dataset paths in /gan_task.py to the path of the preprocessed dataset folders.

Running codes

Now you can directly run python main.py in each to train the models.

All the configurable hyper-parameters can be set in config.py. The hyper-parameters in the file are already set for reproducing the results in the paper. Please refer to GPUTaskScheduler's GitHub page for the details of the grammar of this file.

ImageNet

Preparation

Copy the preprocessed folder ILSVRC2012 from the previous steps to /data/imagenet/ILSVRC2012, where means

  • Vanilla SN and our proposed BSSN/SSN/BSN without gammas: no_gamma-ResNet.

Alternatively, you can directly modify the dataset path in /gan_task.py to the path of the preprocessed folder ILSVRC2012.

Running codes

Now you can directly run python main.py in each to train the models.

All the configurable hyper-parameters can be set in config.py. The hyper-parameters in the file are already set for reproducing the results in the paper. Please refer to GPUTaskScheduler's GitHub page for the details of the grammar of this file.

The code supports multi-GPU training for speed-up, by separating each data batch equally among multiple GPUs. To do that, you only need to make minor modifications in config.py. For example, if you have two GPUs with IDs 0 and 1, then all you need to do is to (1) change "gpu": ["0"] to "gpu": [["0", "1"]], and (2) change "num_gpus": [1] to "num_gpus": [2]. Note that the number of GPUs might influence the results because in this implementation the batch normalization layers on different GPUs are independent. In our experiments, we were using only one GPU.

Results

The code generates the following result files/folders:

  • /results/ /worker.log : Standard output and error from the code.
  • /results/ /metrics.csv : Inception Score and FID during training.
  • /results/ /sample/*.png : Generated images during training.
  • /results/ /checkpoint/* : TensorFlow checkpoints.
  • /results/ /time.txt : Training iteration timestamps.
Owner
Zinan Lin
Ph.D. student at Electrical and Computer Engineering, Carnegie Mellon University
Zinan Lin
Abstractive opinion summarization system (SelSum) and the largest dataset of Amazon product summaries (AmaSum). EMNLP 2021 conference paper.

Learning Opinion Summarizers by Selecting Informative Reviews This repository contains the codebase and the dataset for the corresponding EMNLP 2021

Arthur Bražinskas 39 Jan 01, 2023
Improved Fitness Optimization Landscapes for Sequence Design

ReLSO Improved Fitness Optimization Landscapes for Sequence Design Description Citation How to run Training models Original data source Description In

Krishnaswamy Lab 44 Dec 20, 2022
STRIVE: Scene Text Replacement In Videos

STRIVE: Scene Text Replacement In Videos Dataset Types: RoboText SynthText RealWorld videos RoboText : Videos of texts collected using navigation robo

15 Jul 11, 2022
PAWS 🐾 Predicting View-Assignments with Support Samples

This repo provides a PyTorch implementation of PAWS (predicting view assignments with support samples), as described in the paper Semi-Supervised Learning of Visual Features by Non-Parametrically Pre

Facebook Research 437 Dec 23, 2022
Code for the paper "Attention Approximates Sparse Distributed Memory"

Attention Approximates Sparse Distributed Memory - Codebase This is all of the code used to run analyses in the paper "Attention Approximates Sparse D

Trenton Bricken 14 Dec 05, 2022
ML models implementation practice

Let's implement various ML algorithms with numpy/tf Vanilla Neural Network https://towardsdatascience.com/lets-code-a-neural-network-in-plain-numpy-ae

Jinsoo Heo 4 Jul 04, 2021
3D Human Pose Machines with Self-supervised Learning

3D Human Pose Machines with Self-supervised Learning Keze Wang, Liang Lin, Chenhan Jiang, Chen Qian, and Pengxu Wei, “3D Human Pose Machines with Self

Chenhan Jiang 398 Dec 20, 2022
Unofficial PyTorch implementation of Neural Additive Models (NAM) by Agarwal, et al.

nam-pytorch Unofficial PyTorch implementation of Neural Additive Models (NAM) by Agarwal, et al. [abs, pdf] Installation You can access nam-pytorch vi

Rishabh Anand 11 Mar 14, 2022
UnsupervisedR&R: Unsupervised Pointcloud Registration via Differentiable Rendering

UnsupervisedR&R: Unsupervised Pointcloud Registration via Differentiable Rendering This repository holds all the code and data for our recent work on

Mohamed El Banani 118 Dec 06, 2022
A large dataset of 100k Google Satellite and matching Map images, resembling pix2pix's Google Maps dataset.

Larger Google Sat2Map dataset This dataset extends the aerial ⟷ Maps dataset used in pix2pix (Isola et al., CVPR17). The provide script download_sat2m

34 Dec 28, 2022
GLaRA: Graph-based Labeling Rule Augmentation for Weakly Supervised Named Entity Recognition

GLaRA: Graph-based Labeling Rule Augmentation for Weakly Supervised Named Entity Recognition

Xinyan Zhao 29 Dec 26, 2022
Code used for the results in the paper "ClassMix: Segmentation-Based Data Augmentation for Semi-Supervised Learning"

Code used for the results in the paper "ClassMix: Segmentation-Based Data Augmentation for Semi-Supervised Learning" Getting started Prerequisites CUD

70 Dec 02, 2022
This application is the basic of automated online-class-joiner(for YıldızEdu) within the right time. Gets the ZOOM link by scheduled date and time.

This application is the basic of automated online-class-joiner(for YıldızEdu) within the right time. Gets the ZOOM link by scheduled date and time.

215355 1 Dec 16, 2021
Contra is a lightweight, production ready Tensorflow alternative for solving time series prediction challenges with AI

Contra AI Engine A lightweight, production ready Tensorflow alternative developed by Styvio styvio.com » How to Use · Report Bug · Request Feature Tab

styvio 14 May 25, 2022
Fast SHAP value computation for interpreting tree-based models

FastTreeSHAP FastTreeSHAP package is built based on the paper Fast TreeSHAP: Accelerating SHAP Value Computation for Trees published in NeurIPS 2021 X

LinkedIn 369 Jan 04, 2023
Project page of the paper 'Analyzing Perception-Distortion Tradeoff using Enhanced Perceptual Super-resolution Network' (ECCVW 2018)

EPSR (Enhanced Perceptual Super-resolution Network) paper This repo provides the test code, pretrained models, and results on benchmark datasets of ou

Subeesh Vasu 78 Nov 19, 2022
This is the official PyTorch implementation of our paper: "Artistic Style Transfer with Internal-external Learning and Contrastive Learning".

Artistic Style Transfer with Internal-external Learning and Contrastive Learning This is the official PyTorch implementation of our paper: "Artistic S

51 Dec 20, 2022
In this project we predict the forest cover type using the cartographic variables in the training/test datasets.

Kaggle Competition: Forest Cover Type Prediction In this project we predict the forest cover type (the predominant kind of tree cover) using the carto

Marianne Joy Leano 1 Mar 15, 2022
Pytorch Implementation of Various Point Transformers

Pytorch Implementation of Various Point Transformers Recently, various methods applied transformers to point clouds: PCT: Point Cloud Transformer (Men

Neil You 434 Dec 30, 2022
[ECCVW2020] Robust Long-Term Object Tracking via Improved Discriminative Model Prediction (RLT-DiMP)

Feel free to visit my homepage Robust Long-Term Object Tracking via Improved Discriminative Model Prediction (RLT-DIMP) [ECCVW2020 paper] Presentation

Seokeon Choi 35 Oct 26, 2022