Code for the paper "Adversarial Generator-Encoder Networks"

Related tags

Deep Learninggan
Overview

This repository contains code for the paper

"Adversarial Generator-Encoder Networks" (AAAI'18) by Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky.

Pretrained models

This is how you can access the models used to generate figures in the paper.

  1. First install dev version of pytorch 0.2 and make sure you have jupyter notebook ready.

  2. Then download the models with the script:

bash download_pretrained.sh
  1. Run jupyter notebook and go through evaluate.ipynb.

Here is an example of samples and reconstructions for imagenet, celeba and cifar10 datasets generated with evaluate.ipynb.

Celeba

Samples Reconstructions

Cifar10

Samples Reconstructions

Tiny ImageNet

Samples Reconstructions

Training

Use age.py script to train a model. Here are the most important parameters:

  • --dataset: one of [celeba, cifar10, imagenet, svhn, mnist]
  • --dataroot: for datasets included in torchvision it is a directory where everything will be downloaded to; for imagenet, celeba datasets it is a path to a directory with folders train and val inside.
  • --image_size:
  • --save_dir: path to a folder, where checkpoints will be stored
  • --nz: dimensionality of latent space
  • -- batch_size: Batch size. Default 64.
  • --netG: .py file with generator definition. Searched in models directory
  • --netE: .py file with generator definition. Searched in models directory
  • --netG_chp: path to a generator checkpoint to load from
  • --netE_chp: path to an encoder checkpoint to load from
  • --nepoch: number of epoch to run
  • --start_epoch: epoch number to start from. Useful for finetuning.
  • --e_updates: Update plan for encoder. <num steps>;KL_fake:<weight>,KL_real:<weight>,match_z:<weight>,match_x:<weight>.
  • --g_updates: Update plan for generator. <num steps>;KL_fake:<weight>,match_z:<weight>,match_x:<weight>.

And misc arguments:

  • --workers: number of dataloader workers.
  • --ngf: controlles number of channels in generator
  • --ndf: controlles number of channels in encoder
  • --beta1: parameter for ADAM optimizer
  • --cpu: do not use GPU
  • --criterion: Parametric param or non-parametric nonparam way to compute KL. Parametric fits Gaussian into data, non-parametric is based on nearest neighbors. Default: param.
  • --KL: What KL to compute: qp or pq. Default is qp.
  • --noise: sphere for uniform on sphere or gaussian. Default sphere.
  • --match_z: loss to use as reconstruction loss in latent space. L1|L2|cos. Default cos.
  • --match_x: loss to use as reconstruction loss in data space. L1|L2|cos. Default L1.
  • --drop_lr: each drop_lr epochs a learning rate is dropped.
  • --save_every: controls how often intermediate results are stored. Default 50.
  • --manual_seed: random seed. Default 123.

Here is cmd you can start with:

Celeba

Let data_root to be a directory with two folders train, val, each with the images for corresponding split.

python age.py --dataset celeba --dataroot <data_root> --image_size 64 --save_dir <save_dir> --lr 0.0002 --nz 64 --batch_size 64 --netG dcgan64px --netE dcgan64px --nepoch 5 --drop_lr 5 --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' --g_updates '3;KL_fake:1,match_z:1000,match_x:0'

It is beneficial to finetune the model with larger batch_size and stronger matching weight then:

python age.py --dataset celeba --dataroot <data_root> --image_size 64 --save_dir <save_dir> --start_epoch 5 --lr 0.0002 --nz 64 --batch_size 256 --netG dcgan64px --netE dcgan64px --nepoch 6 --drop_lr 5   --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:15' --g_updates '3;KL_fake:1,match_z:1000,match_x:0' --netE_chp  <save_dir>/netE_epoch_5.pth --netG_chp <save_dir>/netG_epoch_5.pth

Imagenet

python age.py --dataset imagenet --dataroot /path/to/imagenet_dir/ --save_dir <save_dir> --image_size 32 --save_dir ${pdir} --lr 0.0002 --nz 128 --netG dcgan32px --netE dcgan32px --nepoch 6 --drop_lr 3  --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' --g_updates '2;KL_fake:1,match_z:2000,match_x:0' --workers 12

It can be beneficial to switch to 256 batch size after several epochs.

Cifar10

python age.py --dataset cifar10 --image_size 32 --save_dir <save_dir> --lr 0.0002 --nz 128 --netG dcgan32px --netE dcgan32px --nepoch 150 --drop_lr 40  --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' --g_updates '2;KL_fake:1,match_z:1000,match_x:0'

Tested with python 2.7.

Implementation is based on pyTorch DCGAN code.

Citation

If you found this code useful please cite our paper

@inproceedings{DBLP:conf/aaai/UlyanovVL18,
  author    = {Dmitry Ulyanov and
               Andrea Vedaldi and
               Victor S. Lempitsky},
  title     = {It Takes (Only) Two: Adversarial Generator-Encoder Networks},
  booktitle = {{AAAI}},
  publisher = {{AAAI} Press},
  year      = {2018}
}
Owner
Dmitry Ulyanov
Co-Founder at in3D, Phd @ Skoltech
Dmitry Ulyanov
Interpretable and Generalizable Person Re-Identification with Query-Adaptive Convolution and Temporal Lifting

QAConv Interpretable and Generalizable Person Re-Identification with Query-Adaptive Convolution and Temporal Lifting This PyTorch code is proposed in

Shengcai Liao 166 Dec 28, 2022
Source code of the paper "Deep Learning of Latent Variable Models for Industrial Process Monitoring".

Source code of the paper "Deep Learning of Latent Variable Models for Industrial Process Monitoring".

Xiangyin Kong 7 Nov 08, 2022
Yas CRNN model training - Yet Another Genshin Impact Scanner

Yas-Train Yet Another Genshin Impact Scanner 又一个原神圣遗物导出器 介绍 该仓库为 Yas 的模型训练程序 相关资料 MobileNetV3 CRNN 使用 假设你会设置基本的pytorch环境。 生成数据集 python main.py gen 训练

wormtql 18 Jan 08, 2023
Mini-hmc-jax - A simple implementation of Hamiltonian Monte Carlo in JAX

mini-hmc-jax This is a simple implementation of Hamiltonian Monte Carlo in JAX t

Martin Marek 6 Mar 03, 2022
An example project demonstrating how the Autonomous Learning Library can be used to build new reinforcement learning agents.

About This repository shows how Autonomous Learning Library can be used to build new reinforcement learning agents. In particular, it contains a model

Chris Nota 5 Aug 30, 2022
GBK-GNN: Gated Bi-Kernel Graph Neural Networks for Modeling Both Homophily and Heterophily

GBK-GNN: Gated Bi-Kernel Graph Neural Networks for Modeling Both Homophily and Heterophily Abstract Graph Neural Networks (GNNs) are widely used on a

10 Dec 20, 2022
[AAAI22] Reliable Propagation-Correction Modulation for Video Object Segmentation

Reliable Propagation-Correction Modulation for Video Object Segmentation (AAAI22) Preview version paper of this work is available at: https://arxiv.or

Xiaohao Xu 70 Dec 04, 2022
dyld_shared_cache processing / Single-Image loading for BinaryNinja

Dyld Shared Cache Parser Author: cynder (kat) Dyld Shared Cache Support for BinaryNinja Without any of the fuss of requiring manually loading several

cynder 76 Dec 28, 2022
Experiments on Flood Segmentation on Sentinel-1 SAR Imagery with Cyclical Pseudo Labeling and Noisy Student Training

Flood Detection Challenge This repository contains code for our submission to the ETCI 2021 Competition on Flood Detection (Winning Solution #2). Acco

Siddha Ganju 108 Dec 28, 2022
A simple pytorch pipeline for semantic segmentation.

SegmentationPipeline -- Pytorch A simple pytorch pipeline for semantic segmentation. Requirements : torch=1.9.0 tqdm albumentations=1.0.3 opencv-pyt

petite7 4 Feb 22, 2022
DGCNN - Dynamic Graph CNN for Learning on Point Clouds

DGCNN is the author's re-implementation of Dynamic Graph CNN, which achieves state-of-the-art performance on point-cloud-related high-level tasks including category classification, semantic segmentat

Wang, Yue 1.3k Dec 26, 2022
Learning High-Speed Flight in the Wild

Learning High-Speed Flight in the Wild This repo contains the code associated to the paper Learning Agile Flight in the Wild. For more information, pl

Robotics and Perception Group 391 Dec 29, 2022
Framework web SnakeServer.

SnakeServer - Framework Web 🐍 Documentação oficial do framework SnakeServer. Conteúdo Sobre Como contribuir Enviar relatórios de segurança Pull reque

Jaedson Silva 0 Jul 21, 2022
Model Quantization Benchmark

Introduction MQBench is an open-source model quantization toolkit based on PyTorch fx. The envision of MQBench is to provide: SOTA Algorithms. With MQ

500 Jan 06, 2023
Official pytorch implementation of "DSPoint: Dual-scale Point Cloud Recognition with High-frequency Fusion"

DSPoint Official implementation of "DSPoint: Dual-scale Point Cloud Recognition with High-frequency Fusion". Paper link: https://arxiv.org/abs/2111.10

Ziyao Zeng 14 Feb 26, 2022
Using Tensorflow Object Detection API to detect Waymo open dataset

Waymo-2D-Object-Detection Using Tensorflow Object Detection API to detect Waymo open dataset Result CenterNet Training Loss SSD ResNet Training Loss C

76 Dec 12, 2022
[ICLR 2021] Is Attention Better Than Matrix Decomposition?

Enjoy-Hamburger 🍔 Official implementation of Hamburger, Is Attention Better Than Matrix Decomposition? (ICLR 2021) Under construction. Introduction T

Gsunshine 271 Dec 29, 2022
Pytorch Code for "Medical Transformer: Gated Axial-Attention for Medical Image Segmentation"

Medical-Transformer Pytorch Code for the paper "Medical Transformer: Gated Axial-Attention for Medical Image Segmentation" About this repo: This repo

Jeya Maria Jose 615 Dec 25, 2022
Replication Package for AequeVox:Automated Fariness Testing for Speech Recognition Systems

AequeVox Replication Package for AequeVox:Automated Fariness Testing for Speech Recognition Systems README under development. Python Packages Required

Sai Sathiesh 2 Aug 28, 2022
T-LOAM: Truncated Least Squares Lidar-only Odometry and Mapping in Real-Time

T-LOAM: Truncated Least Squares Lidar-only Odometry and Mapping in Real-Time The first Lidar-only odometry framework with high performance based on tr

Pengwei Zhou 183 Dec 01, 2022