Code for "On the Effects of Batch and Weight Normalization in Generative Adversarial Networks"

Overview

Note: this repo has been discontinued. Please check here for the code for the newer version of the paper.

Weight Normalized GAN


About the code

Two versions are provided: one for torch (Lua) and one for PyTorch.

The code used for the experiments in the paper was written in torch and was a bit messy. It had hand-written backward passes for the weight-normalized layers, plus other stuff used to test various ideas about GANs that are unrelated to the paper. So we decided to clean up the code and port it to PyTorch (read: autograd). However, we were not able to exactly reproduce the results in the paper with the PyTorch code, so we had to port it back to torch to investigate the difference.

We did find and fix a mathematical bug (ouch!) in the gradient computation of our weight normalization implementation. This means the code used for the paper was incorrect, and you might not be able to exactly reproduce the results in the paper with the current code. We need to redo some experiments to make sure everything still works. So far, the learning rate trade-offs look like this:
- 0.00002 gives very good samples, but training is slow at the beginning.
- 0.0001 speeds up training even more than in the paper, but gives worse samples.
- 0.00005 balances the two, and also gives a lower reconstruction loss than in the paper.

The example below uses 0.00002.
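The bug was in the gradient of weight normalization, so a reference backward pass may be useful. Below is a minimal sketch, not the repo's code, of the standard weight normalization backward pass. It assumes the Salimans and Kingma parameterization w = g * v / ||v|| for a single weight vector, and the layers in the paper may differ in details. The analytic gradients are checked against central finite differences on a toy least-squares loss. The helpers `weight_norm` and `weight_norm_grads` and the toy problem are hypothetical.

```python
import numpy as np


def weight_norm(v, g):
    """Standard weight normalization: w = g * v / ||v||."""
    return g * v / np.linalg.norm(v)


def weight_norm_grads(v, g, grad_w):
    """Backpropagate dL/dw through w = g * v / ||v|| to get dL/dg and dL/dv."""
    norm_v = np.linalg.norm(v)
    grad_g = grad_w @ v / norm_v
    # (g / ||v||) times grad_w with its component along v removed.
    grad_v = (g / norm_v) * grad_w - (g * grad_g / norm_v**2) * v
    return grad_g, grad_v


# Hypothetical toy problem: L(w) = 0.5 * ||A w - y||^2.
rng = np.random.default_rng(0)
A = rng.normal(size=(5, 4))
y = rng.normal(size=5)
v = rng.normal(size=4)
g = 1.7


def loss(v, g):
    r = A @ weight_norm(v, g) - y
    return 0.5 * r @ r


# Analytic gradients via the chain rule.
grad_w = A.T @ (A @ weight_norm(v, g) - y)
grad_g, grad_v = weight_norm_grads(v, g, grad_w)

# Central finite differences for comparison.
eps = 1e-6
num_grad_v = np.array([(loss(v + eps * e, g) - loss(v - eps * e, g)) / (2 * eps)
                       for e in np.eye(len(v))])
num_grad_g = (loss(v, g + eps) - loss(v, g - eps)) / (2 * eps)

print("max |grad_v - numeric|:", np.max(np.abs(grad_v - num_grad_v)))
print("|grad_g - numeric|:", abs(grad_g - num_grad_g))
print("v . grad_v:", v @ grad_v)  # should be ~0 up to rounding
```

A finite-difference check like this is cheap and would catch most sign or factor errors in a hand-written backward pass.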

That being said, we still see some differences between the samples generated by the two versions of the code. We think the torch version is better, so we advise using it for training. You should still definitely read the PyTorch version to get a better idea of how our method works.

This time we checked that, in the torch code, the computed gradients with respect to the weight vectors are indeed orthogonal to the weight vectors. So hopefully the difference is not caused by another mathematical bug. It could be a numerical issue, since the gradients are not computed in exactly the same way in the two versions. Or I might have made silly mistakes, as I have been doing machine learning for only half a year. We are still investigating.
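The orthogonality check mentioned above follows from the chain rule. Assume the standard weight normalization parameterization w = g v / ||v|| (the paper's layers may differ in details). Then the gradients with respect to g and v are as follows, and the dot product of v with its gradient vanishes:

```latex
\begin{aligned}
\nabla_g L &= \frac{\nabla_w L \cdot v}{\|v\|}, \qquad
\nabla_v L = \frac{g}{\|v\|}\,\nabla_w L - \frac{g\,\nabla_g L}{\|v\|^2}\,v, \\
v \cdot \nabla_v L &= \frac{g}{\|v\|}\,(v \cdot \nabla_w L) - \frac{g\,\nabla_g L}{\|v\|^2}\,\|v\|^2
  = \frac{g}{\|v\|}\,(v \cdot \nabla_w L) - \frac{g}{\|v\|}\,(\nabla_w L \cdot v) = 0.
\end{aligned}
```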

Usage

The two versions accept exactly the same set of arguments. The only exception is that the torch version has an additional option to set the ID of the GPU to use.

Before training, you need to prepare the data. For torch, you need lmdb.torch for LSUN and cifar.torch for CIFAR-10. Split the dataset into training data and test data with split_data.lua/py. Use --running and --final to set the number of test samples for the running test and the final test, respectively.
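As an illustration only, here is a hypothetical Python sketch of such a split. It assumes the running-test and final-test sets are disjoint random subsets, with all remaining samples used for training. split_data.lua/py may do this differently, and `split_indices` is not part of the repo.

```python
import random


def split_indices(num_samples, running, final, seed=0):
    """Hypothetical split into disjoint train / running-test / final-test index lists."""
    if running + final > num_samples:
        raise ValueError("not enough samples for the requested test sets")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # reproducible shuffle
    final_test = indices[:final]
    running_test = indices[final:final + running]
    train = indices[final + running:]
    return train, running_test, final_test


train, running_test, final_test = split_indices(1000, running=100, final=200)
print(len(train), len(running_test), len(final_test))  # 700 100 200
```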

The LSUN loader creates a cache if there isn't one, which takes some time. The loader for a custom dataset from an image folder requires the images of each class to be in their own subfolder. So if you use, say, CelebA, where there are no classes, you need to manually create a dummy class.
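For a dataset without classes, such as CelebA, you can create a dummy class by moving all images into a single subfolder. Below is a minimal sketch that assumes the images sit directly in the dataset folder. The function `make_dummy_class`, the folder name `dummy` and the extension list are hypothetical choices.

```python
from pathlib import Path
import shutil

# Hypothetical set of extensions treated as images.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def make_dummy_class(root, class_name="dummy"):
    """Move every image file directly under `root` into `root/class_name`."""
    root = Path(root)
    class_dir = root / class_name
    class_dir.mkdir(exist_ok=True)
    moved = 0
    # Snapshot the listing first so the new subfolder is not iterated over.
    for path in list(root.iterdir()):
        if path.is_file() and path.suffix.lower() in IMAGE_EXTS:
            shutil.move(str(path), str(class_dir / path.name))
            moved += 1
    return moved
```

After `make_dummy_class("/path/to/img_align_celeba")`, the folder contains a single class subfolder. --dataroot would then presumably still point at /path/to/img_align_celeba, as in the example below.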

To train, run main.lua/py. The only arguments you must specify are --dataset, --dataroot, --save_path and --image_size. By default it trains a vanilla model; use --norm batch or --norm weight to try the different normalizations.

The width and height of the images do not have to be equal, nor do they have to be powers of two. They only have to both be even numbers. Image size settings work as follows:
1. If --crop_size is specified, or if both --crop_width and --crop_height are specified, the training samples are first center-cropped to that size.
2. Then, if --width and --height are both specified, the training samples are resized to that size.
3. Otherwise, they are resized so that the aspect ratio is kept and the shorter edge equals --image_size, and then cropped to a square.
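The rules above can be summarized as a small planning function that only tracks image shapes. This is a hypothetical sketch; `plan_preprocessing` is not in the repo. It rests on three assumptions:
- --crop_size takes precedence over --crop_width/--crop_height.
- The final square crop is a center crop.
- Resizing rounds to the nearest pixel.

```python
def plan_preprocessing(height, width, image_size=None, crop_size=None,
                       crop_width=None, crop_height=None,
                       out_width=None, out_height=None):
    """Hypothetical sketch: return the (op, height, width) steps and the final shape.

    Arguments mirror --image_size, --crop_size, --crop_width, --crop_height,
    --width (out_width) and --height (out_height).
    """
    steps = []
    # Step 1: optional center crop.
    if crop_size is not None:
        height = width = crop_size
        steps.append(("center_crop", height, width))
    elif crop_width is not None and crop_height is not None:
        height, width = crop_height, crop_width
        steps.append(("center_crop", height, width))
    # Step 2a: resize to an explicit size if both --width and --height are given.
    if out_width is not None and out_height is not None:
        height, width = out_height, out_width
        steps.append(("resize", height, width))
    else:
        # Step 2b: keep the aspect ratio, shorter edge -> image_size, then square crop.
        if image_size is None:
            raise ValueError("image_size is required without --width/--height")
        scale = image_size / min(height, width)
        height, width = round(height * scale), round(width * scale)
        steps.append(("resize", height, width))
        height = width = image_size
        steps.append(("center_crop", height, width))
    # Both final dimensions must be even.
    if height % 2 or width % 2:
        raise ValueError(f"final size {height}x{width} must have even sides")
    return steps, (height, width)


steps, shape = plan_preprocessing(218, 178, image_size=160, crop_size=160)
print(steps)  # [('center_crop', 160, 160), ('resize', 160, 160), ('center_crop', 160, 160)]
print(shape)  # (160, 160)
```

Take aligned CelebA images (178 wide, 218 high) with --crop_size 160 --image_size 160, as in the example below. Here the center crop already yields the target size, so the resize and the square crop are no-ops.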

If --nlayer is set, that many down/up convolution layers are used. Otherwise, such layers are added until the feature map is smaller than 8x8. --nfeature specifies the number of features in the first convolution layer.
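Below is a hypothetical sketch of how the default depth could be computed. It assumes that each down convolution halves both spatial dimensions (integer division). It also assumes that "smaller than 8x8" means both dimensions are below 8. The repo's exact rule may differ, and `count_layers` is not part of it.

```python
def count_layers(height, width, nlayer=None, min_size=8):
    """Hypothetical: number of down/up convolution layers for a height x width input."""
    if nlayer is not None:  # mirrors --nlayer: use exactly that many
        return nlayer
    layers = 0
    # Assumption: each layer halves both sides; stop once both are below min_size.
    while height >= min_size or width >= min_size:
        height, width = height // 2, width // 2
        layers += 1
    return layers


print(count_layers(160, 160))  # 5
print(count_layers(64, 64))    # 4
```

Under these assumptions, a 160x160 input goes 160 -> 80 -> 40 -> 20 -> 10 -> 5, giving 5 layers. A 64x64 input stops at 4x4 after 4 layers.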

Set --load_path to resume a saved training run.

To test a trained model, use --final_test. Make sure to also set a larger --test_steps, since the default value is meant for the running test during training. By default it finds the best model in --load_path; to use another network, set --net.

Read the code to see how other arguments work.

Use plot.lua/py to plot the loss curves. The PyTorch version uses PyGnuplot (it sux).

Example

th main.lua --dataset folder --dataroot /path/to/img_align_celeba --crop_size 160 --image_size 160 --code_size 256 --norm weight --lr 0.00002 --save_path /path/to/save/folder

This should give you something like the following after 200,000 iterations: (image: CelebA example)

Additional notes

The WN model might fail in the first handful of iterations. This happens especially often if the network is deeper (e.g. on LSUN). Just restart training: if it gets past iteration 5, it should continue to train without trouble. This effect could be reduced by using a smaller learning rate for the first couple of iterations.
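Below is a minimal sketch of that idea, assuming a linear warm-up. `warmup_lr`, `warmup_iters` and `start_factor` are hypothetical and are not options of main.lua/py.

```python
def warmup_lr(iteration, base_lr, warmup_iters=10, start_factor=0.1):
    """Hypothetical linear warm-up from start_factor * base_lr to base_lr."""
    if iteration >= warmup_iters:
        return base_lr
    fraction = iteration / warmup_iters
    return base_lr * (start_factor + (1.0 - start_factor) * fraction)


for it in (0, 5, 10):
    print(it, warmup_lr(it, base_lr=0.00002))
```

In PyTorch, the returned value could be written into `group["lr"]` for each `group` in `optimizer.param_groups` before every step.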

Extra stuff

By request, an --ls flag was added to use the least-squares loss.
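For reference, below is the textbook least-squares GAN objective (LSGAN), with target 1 for real samples and 0 for fake ones. This NumPy sketch assumes the discriminator outputs raw scores without a sigmoid. The exact loss behind --ls may differ in its constants or targets.

```python
import numpy as np


def lsgan_d_loss(d_real, d_fake):
    """Discriminator: push scores on real data toward 1 and on fake data toward 0."""
    return 0.5 * np.mean((d_real - 1.0) ** 2) + 0.5 * np.mean(d_fake ** 2)


def lsgan_g_loss(d_fake):
    """Generator: push the discriminator's scores on fake data toward 1."""
    return 0.5 * np.mean((d_fake - 1.0) ** 2)


d_real = np.array([0.9, 1.1])
d_fake = np.array([0.2, -0.2])
print(lsgan_d_loss(d_real, d_fake))
print(lsgan_g_loss(d_fake))
```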

Owner
Sitao Xiang
Computer Graphics PhD student at University of Southern California. Twitter: StormRaiser123