RP-GAN: Stable GAN Training with Random Projections

Overview

Interpolated images from our GAN

This repository contains a reference implementation of the algorithm described in the paper:

Behnam Neyshabur, Srinadh Bhojanapalli, and Ayan Chakrabarti, "Stabilizing GAN Training with Multiple Random Projections," arXiv:1705.07831 [cs.LG], 2017.

Pre-trained generator models are not included in the repository due to their size, but are available as binary downloads as part of the release. This code and data are being released for research use. If you use the code in research that results in a publication, we request that you cite the above paper. Please direct any questions to [email protected].

Requirements

The code uses the tensorflow library and has been tested with versions 0.9 and 0.11, under both Python 2 and Python 3. You will need a modern GPU to train in a reasonable amount of time, but the sampling code should work on a CPU.

Sampling with Trained Models

We first describe usage of scripts for sampling from trained models. You can use these scripts for models you train yourself, or use the provided pre-trained models.

Pre-trained Models

We provide a number of pre-trained models in the release, corresponding to the experiments in the paper. The parameters of each model (both for training and sampling) are described in .py files in the exp/ directory. face1.py describes a face image model trained in the traditional setting with a single discriminator, while the faceNN.py files describe models trained with multiple discriminators, each acting on one of NN random low-dimensional projections. face48.py describes the main face model used in our experiments, while dog12.py is the model trained with 12 discriminators on the Imagenet-Canines set. After downloading the trained model archive files, unzip them in the repository root directory. This will create files in sub-directories of models/.

Generating Samples

Use sample.py to generate samples from any of the trained models:

$ ./sample.py expName[,seed] out.png [iteration]

where expName is the name of the experiment file (without the .py extension), and out.png is the file to save the generated samples to. The script accepts two optional parameters: seed (default 0) specifies the random seed used to generate the noise vectors provided to the generator, and iteration (default: the latest iteration for which a saved model file is available) specifies which model snapshot to use when multiple snapshots are available. For example:

$ ./sample.py face48 out.png      # Sample from the face48 experiment, using 
                                  # seed 0, and the latest model file.
$ ./sample.py face48,100 out.png  # Sample from the face48 experiment, using
                                  # seed 100, and the latest model file.
$ ./sample.py face1 out.png       # Sample from the single discriminator face
                                  # experiment, and the latest model file.
$ ./sample.py face1 out.png 40000 # Sample from the single discriminator face
                                  # experiment, and the 40k iterations model.
Interpolating in Latent Space

We also provide a script to produce interpolated images like the ones at the top of this page. However, before you can use this script, you need to create a version of the model file that contains the population mean-variance statistics of the activations, to be used in the batch-norm layers. (sample.py above uses batch statistics, which is fine since it works with a large batch of noise vectors. For interpolation, however, you will typically be working with smaller, more correlated batches, and should therefore use the stored population statistics.)
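
To make the distinction concrete, here is a minimal NumPy sketch (not from the repository) of the two normalization modes: batch statistics computed on the fly versus stored population statistics.

import numpy as np

def batchnorm(x, gamma, beta, pop_mean=None, pop_var=None, eps=1e-5):
    # Normalize activations x of shape (batch, features).
    # With pop_mean/pop_var unset, use the current batch's statistics
    # (fine for sample.py's large batch of noise vectors); otherwise use
    # the stored population statistics (what interpolation over small,
    # correlated batches needs).
    if pop_mean is None or pop_var is None:
        mean, var = x.mean(axis=0), x.var(axis=0)   # batch statistics
    else:
        mean, var = pop_mean, pop_var               # population statistics
    return gamma * (x - mean) / np.sqrt(var + eps) + beta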

To create this version of the model file, use the provided script fixbn.py as:

$ CUDA_VISIBLE_DEVICES= ./fixbn.py expName [iteration]

This will create a second version of the model weights file (with extension .bgmodel.npz instead of .gmodel.npz) that also stores the population statistics. As with sample.py, you can provide an optional second argument to select a specific model snapshot by iteration number.

Note that we call the script with CUDA_VISIBLE_DEVICES= to force tensorflow to use the CPU instead of the GPU. This is because we compute these stats over a relatively large batch which typically doesn't fit in GPU memory (and since it's only one forward pass, running time isn't really an issue).

You only need to call fixbn.py once; after that, you can use the script interp.py to create interpolated samples. The script generates multiple rows of images, each row showing samples from noise vectors interpolated, from left to right, between a pair of endpoints. You specify these pairs of noise vectors as IDs:

$ ./interp.py expName[,seed[,iteration]] out.png lid,rid lid,rid ....

The first parameter now has two optional comma-separated arguments beyond the model name, for seed and iteration. After this and the output file name, the script accepts an arbitrary number of pairs of left-right image IDs, one pair for each row of desired images in the output. These IDs correspond to the number of the image, in reading order, in the output generated by sample.py (with the same seed). For example, to create the images at the top of the page, use:

$ ./interp.py face48 out.png 137,65 146,150 15,138 54,72 38,123 36,93
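
Conceptually, each row is generated by feeding the generator a sequence of noise vectors blended between the two endpoints. A minimal sketch of linear interpolation is shown below; the batch size, vector dimension, and interpolation scheme used by interp.py itself may differ, so treat these as assumptions.

import numpy as np

def interpolate_noise(z_left, z_right, steps=8):
    # Blend linearly from z_left to z_right in `steps` increments.
    alphas = np.linspace(0.0, 1.0, steps)
    return np.stack([(1 - a) * z_left + a * z_right for a in alphas])

# Hypothetical usage: draw the same seeded batch of noise vectors that
# sample.py would (seed 0), pick two by their reading-order IDs, and feed
# each interpolated vector to the generator. The batch size and dimension
# below are assumed, not taken from the code.
rng = np.random.RandomState(0)
z_all = rng.randn(160, 100)
row = interpolate_noise(z_all[137], z_all[65])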

Training

To train your own model, you will need to create a new model file (say myown.py) in the exp/ directory. See the existing model files for reference. Here is an explanation of some of the key parameters (an illustrative sketch of such a file follows the list):

  • wts_dir: Directory in which to store model weights. This directory must already exist.
  • imsz: Resolution / Size of the images (will be square color images of size imsz x imsz).
  • lfile: Path to a list file for the images you want to train on, where each line of the file contains a path to an image.
  • crop: Boolean (True or False). Indicates whether the images are already the correct resolution, or need to be cropped. If True, these images will first be resized so that the smaller side matches imsz, and then a random crop along the other dimension will be used for training.
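
A hypothetical myown.py, purely to illustrate the parameters above; the exact set of parameters and the values expected by train.py should be checked against the provided exp/ files.

# exp/myown.py -- illustrative sketch only; compare with the provided
# exp/*.py files for the exact parameters train.py expects.
wts_dir = 'models/myown'        # must already exist; weights are saved here
imsz = 64                       # square color images of size imsz x imsz (value assumed)
lfile = 'data/myown_list.txt'   # text file with one image path per line
crop = True                     # resize smaller side to imsz, then random-crop the other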

Before you begin training, you will need to create a file called filts.npz in the weights directory, which defines the convolutional filters for the random projections. See the filts/ directory for the filters used for the pre-trained models, as well as instructions and a script for creating your own.
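
As a rough illustration of what these projection filters are, each discriminator sees the image convolved with a fixed random low-dimensional projection. The sketch below generates such a filter bank with NumPy; the array key, filter size, and channel counts are assumptions, so check the filts/ directory for the actual format expected by train.py.

import numpy as np

# Illustrative sketch only: one fixed random convolutional filter per
# discriminator. The key name ('filts'), filter size, and channel counts
# are assumed, not taken from the repository.
num_discriminators = 48   # e.g. the face48 experiment
ksize, in_ch, out_ch = 8, 3, 1

rng = np.random.RandomState(0)
filts = rng.randn(num_discriminators, ksize, ksize, in_ch, out_ch).astype(np.float32)
filts /= np.sqrt((filts ** 2).sum(axis=(1, 2, 3, 4), keepdims=True))  # unit-norm projections
np.savez('filts.npz', filts=filts)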

Once you have created the model file and prepared the directory, you can begin training by using the train.py script as:

$ ./train.py myown

where the first parameter is the name of your model file.

We also provide a script, baseline_train.py, for traditional training with a single discriminator acting on the original image. It is used in the same way, except that it does not require a filts.npz file in the weights directory.


Acknowledgments

This work was supported by the National Science Foundation under award no. IIS-1820693. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the authors, and do not necessarily reflect the views of the National Science Foundation.
