Code for the paper "Adversarial Generator-Encoder Networks"

Related tags

Deep Learninggan
Overview

This repository contains code for the paper

"Adversarial Generator-Encoder Networks" (AAAI'18) by Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky.

Pretrained models

This is how you can access the models used to generate figures in the paper.

  1. First install dev version of pytorch 0.2 and make sure you have jupyter notebook ready.

  2. Then download the models with the script:

bash download_pretrained.sh
  1. Run jupyter notebook and go through evaluate.ipynb.

Here is an example of samples and reconstructions for imagenet, celeba and cifar10 datasets generated with evaluate.ipynb.

Celeba

Samples Reconstructions

Cifar10

Samples Reconstructions

Tiny ImageNet

Samples Reconstructions

Training

Use age.py script to train a model. Here are the most important parameters:

  • --dataset: one of [celeba, cifar10, imagenet, svhn, mnist]
  • --dataroot: for datasets included in torchvision it is a directory where everything will be downloaded to; for imagenet, celeba datasets it is a path to a directory with folders train and val inside.
  • --image_size:
  • --save_dir: path to a folder, where checkpoints will be stored
  • --nz: dimensionality of latent space
  • -- batch_size: Batch size. Default 64.
  • --netG: .py file with generator definition. Searched in models directory
  • --netE: .py file with generator definition. Searched in models directory
  • --netG_chp: path to a generator checkpoint to load from
  • --netE_chp: path to an encoder checkpoint to load from
  • --nepoch: number of epoch to run
  • --start_epoch: epoch number to start from. Useful for finetuning.
  • --e_updates: Update plan for encoder. <num steps>;KL_fake:<weight>,KL_real:<weight>,match_z:<weight>,match_x:<weight>.
  • --g_updates: Update plan for generator. <num steps>;KL_fake:<weight>,match_z:<weight>,match_x:<weight>.

And misc arguments:

  • --workers: number of dataloader workers.
  • --ngf: controlles number of channels in generator
  • --ndf: controlles number of channels in encoder
  • --beta1: parameter for ADAM optimizer
  • --cpu: do not use GPU
  • --criterion: Parametric param or non-parametric nonparam way to compute KL. Parametric fits Gaussian into data, non-parametric is based on nearest neighbors. Default: param.
  • --KL: What KL to compute: qp or pq. Default is qp.
  • --noise: sphere for uniform on sphere or gaussian. Default sphere.
  • --match_z: loss to use as reconstruction loss in latent space. L1|L2|cos. Default cos.
  • --match_x: loss to use as reconstruction loss in data space. L1|L2|cos. Default L1.
  • --drop_lr: each drop_lr epochs a learning rate is dropped.
  • --save_every: controls how often intermediate results are stored. Default 50.
  • --manual_seed: random seed. Default 123.

Here is cmd you can start with:

Celeba

Let data_root to be a directory with two folders train, val, each with the images for corresponding split.

python age.py --dataset celeba --dataroot <data_root> --image_size 64 --save_dir <save_dir> --lr 0.0002 --nz 64 --batch_size 64 --netG dcgan64px --netE dcgan64px --nepoch 5 --drop_lr 5 --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' --g_updates '3;KL_fake:1,match_z:1000,match_x:0'

It is beneficial to finetune the model with larger batch_size and stronger matching weight then:

python age.py --dataset celeba --dataroot <data_root> --image_size 64 --save_dir <save_dir> --start_epoch 5 --lr 0.0002 --nz 64 --batch_size 256 --netG dcgan64px --netE dcgan64px --nepoch 6 --drop_lr 5   --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:15' --g_updates '3;KL_fake:1,match_z:1000,match_x:0' --netE_chp  <save_dir>/netE_epoch_5.pth --netG_chp <save_dir>/netG_epoch_5.pth

Imagenet

python age.py --dataset imagenet --dataroot /path/to/imagenet_dir/ --save_dir <save_dir> --image_size 32 --save_dir ${pdir} --lr 0.0002 --nz 128 --netG dcgan32px --netE dcgan32px --nepoch 6 --drop_lr 3  --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' --g_updates '2;KL_fake:1,match_z:2000,match_x:0' --workers 12

It can be beneficial to switch to 256 batch size after several epochs.

Cifar10

python age.py --dataset cifar10 --image_size 32 --save_dir <save_dir> --lr 0.0002 --nz 128 --netG dcgan32px --netE dcgan32px --nepoch 150 --drop_lr 40  --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' --g_updates '2;KL_fake:1,match_z:1000,match_x:0'

Tested with python 2.7.

Implementation is based on pyTorch DCGAN code.

Citation

If you found this code useful please cite our paper

@inproceedings{DBLP:conf/aaai/UlyanovVL18,
  author    = {Dmitry Ulyanov and
               Andrea Vedaldi and
               Victor S. Lempitsky},
  title     = {It Takes (Only) Two: Adversarial Generator-Encoder Networks},
  booktitle = {{AAAI}},
  publisher = {{AAAI} Press},
  year      = {2018}
}
Owner
Dmitry Ulyanov
Co-Founder at in3D, Phd @ Skoltech
Dmitry Ulyanov
Adaptive Graph Convolution for Point Cloud Analysis

Adaptive Graph Convolution for Point Cloud Analysis This repository contains the implementation of AdaptConv for point cloud analysis. Adaptive Graph

64 Dec 21, 2022
Transport Mode detection - can detect the mode of transport with the help of features such as acceeration,jerk etc

title emoji colorFrom colorTo sdk app_file pinned Transport_Mode_Detector 🚀 purple yellow gradio app.py false Configuration title: string Display tit

Nishant Rajadhyaksha 3 Jan 16, 2022
Code and data of the ACL 2021 paper: Few-Shot Text Ranking with Meta Adapted Synthetic Weak Supervision

MetaAdaptRank This repository provides the implementation of meta-learning to reweight synthetic weak supervision data described in the paper Few-Shot

THUNLP 5 Jun 16, 2022
Automatic deep learning for image classification.

AutoDL AutoDL automates machine learning tasks enabling you to easily achieve strong predictive performance in your applications. With just a few line

wenqi 2 Oct 12, 2022
Code for CVPR 2021 paper TransNAS-Bench-101: Improving Transferrability and Generalizability of Cross-Task Neural Architecture Search.

TransNAS-Bench-101 This repository contains the publishable code for CVPR 2021 paper TransNAS-Bench-101: Improving Transferrability and Generalizabili

Yawen Duan 17 Nov 20, 2022
Flexible Networks for Learning Physical Dynamics of Deformable Objects (2021)

Flexible Networks for Learning Physical Dynamics of Deformable Objects (2021) By Jinhyung Park, Dohae Lee, In-Kwon Lee from Yonsei University (Seoul,

Jinhyung Park 0 Jan 09, 2022
Implementation of Cross Transformer for spatially-aware few-shot transfer, in Pytorch

Cross Transformers - Pytorch (wip) Implementation of Cross Transformer for spatially-aware few-shot transfer, in Pytorch Install $ pip install cross-t

Phil Wang 40 Dec 22, 2022
Black box hyperparameter optimization made easy.

BBopt BBopt aims to provide the easiest hyperparameter optimization you'll ever do. Think of BBopt like Keras (back when Theano was still a thing) for

Evan Hubinger 70 Nov 03, 2022
Geometric Deep Learning Extension Library for PyTorch

Documentation | Paper | Colab Notebooks | External Resources | OGB Examples PyTorch Geometric (PyG) is a geometric deep learning extension library for

Matthias Fey 16.5k Jan 08, 2023
Code for the paper "M2m: Imbalanced Classification via Major-to-minor Translation" (CVPR 2020)

M2m: Imbalanced Classification via Major-to-minor Translation This repository contains code for the paper "M2m: Imbalanced Classification via Major-to

79 Oct 13, 2022
Detection of PCBA defect

Detection_of_PCBA_defect Detection_of_PCBA_defect Use yolov5 to train. $pip install -r requirements.txt Detect.py will detect file(jpg,mp4...) in cu

6 Nov 28, 2022
A simple and useful implementation of LPIPS.

lpips-pytorch Description Developing perceptual distance metrics is a major topic in recent image processing problems. LPIPS[1] is a state-of-the-art

So Uchida 121 Dec 24, 2022
This project demonstrates the use of neural networks and computer vision to create a classifier that interprets the Brazilian Sign Language.

LIBRAS-Image-Classifier This project demonstrates the use of neural networks and computer vision to create a classifier that interprets the Brazilian

Aryclenio Xavier Barros 26 Oct 14, 2022
A Graph Neural Network Tool for Recovering Dense Sub-graphs in Random Dense Graphs.

PYGON A Graph Neural Network Tool for Recovering Dense Sub-graphs in Random Dense Graphs. Installation This code requires to install and run the graph

Yoram Louzoun's Lab 0 Jun 25, 2021
[CVPR 2021] "The Lottery Tickets Hypothesis for Supervised and Self-supervised Pre-training in Computer Vision Models" Tianlong Chen, Jonathan Frankle, Shiyu Chang, Sijia Liu, Yang Zhang, Michael Carbin, Zhangyang Wang

The Lottery Tickets Hypothesis for Supervised and Self-supervised Pre-training in Computer Vision Models Codes for this paper The Lottery Tickets Hypo

VITA 59 Dec 28, 2022
Source code for 2021 ICCV paper "In-the-Wild Single Camera 3D Reconstruction Through Moving Water Surfaces"

In-the-Wild Single Camera 3D Reconstruction Through Moving Water Surfaces This is the PyTorch implementation for 2021 ICCV paper "In-the-Wild Single C

27 Dec 06, 2022
PyTorch inference for "Progressive Growing of GANs" with CelebA snapshot

Progressive Growing of GANs inference in PyTorch with CelebA training snapshot Description This is an inference sample written in PyTorch of the origi

320 Nov 21, 2022
Deep Inertial Prediction (DIPr)

Deep Inertial Prediction For more information and context related to this repo, please refer to our website. Getting Started (non Docker) Note: you wi

Arcturus Industries 12 Nov 11, 2022
Dynamica causal Bayesian optimisation

Dynamic Causal Bayesian Optimization This is a Python implementation of Dynamic Causal Bayesian Optimization as presented at NeurIPS 2021. Abstract Th

nd308 18 Nov 22, 2022
KeypointDeformer: Unsupervised 3D Keypoint Discovery for Shape Control

KeypointDeformer: Unsupervised 3D Keypoint Discovery for Shape Control Tomas Jakab, Richard Tucker, Ameesh Makadia, Jiajun Wu, Noah Snavely, Angjoo Ka

Tomas Jakab 87 Nov 30, 2022