TensorFlow implementation of "Improved Transformer for High-Resolution GANs" (NeurIPS 2021).

Overview

HiT-GAN Official TensorFlow Implementation

HiT-GAN presents a Transformer-based generator trained within the Generative Adversarial Network (GAN) framework. It achieves state-of-the-art performance for high-resolution image synthesis. Please see our NeurIPS 2021 paper "Improved Transformer for High-Resolution GANs" for more details.

This implementation is based on TensorFlow 2.x. We use tf.keras layers for building the model and use tf.data for our input pipeline. The model is trained using a custom training loop with tf.distribute on multiple TPUs/GPUs.
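In a tf.distribute custom training loop, each replica processes its own shard of the global batch, and the per-replica loss is usually divided by the global batch size (the scaling that tf.nn.compute_average_loss applies) so that summing across replicas yields the global mean. Because gradients are linear in the loss, the summed replica gradients then equal the gradient of that mean. The pure-Python sketch below mimics only this arithmetic; the function names are hypothetical and the code is not taken from this repository's training loop.

```python
from typing import Iterable, List, Sequence


def split_global_batch(batch: Sequence[float], num_replicas: int) -> List[List[float]]:
    """Split a global batch into equal, contiguous per-replica shards."""
    if len(batch) % num_replicas != 0:
        raise ValueError("global batch size must be divisible by num_replicas")
    per_replica = len(batch) // num_replicas
    return [list(batch[i * per_replica:(i + 1) * per_replica])
            for i in range(num_replicas)]


def replica_loss(per_example_losses: Sequence[float], global_batch_size: int) -> float:
    """Sum this replica's losses and divide by the GLOBAL batch size
    (the same scaling tf.nn.compute_average_loss applies)."""
    return sum(per_example_losses) / global_batch_size


def all_reduce_sum(values: Iterable[float]) -> float:
    """Stand-in for the cross-replica SUM reduction performed by tf.distribute."""
    return sum(values)


if __name__ == "__main__":
    losses = [0.5, 1.5, 2.0, 4.0, 1.0, 3.0, 0.0, 2.0]  # toy per-example losses
    shards = split_global_batch(losses, num_replicas=4)
    total = all_reduce_sum(replica_loss(s, len(losses)) for s in shards)
    print(total, sum(losses) / len(losses))  # 1.75 1.75
```

Dividing by the per-replica batch size instead would make the summed result num_replicas times too large, which is why the global size is used.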

Environment setup

We recommend using TPUs for distributed training and GPUs for evaluation. The code is compatible with TensorFlow 2.x. See requirements.txt for all prerequisites, which you can install with the following command:

pip install -r requirements.txt

ImageNet

Before first use, download ImageNet by following the tensorflow_datasets instructions in the official guide.

Train on ImageNet

To train the model on ImageNet with Cloud TPUs, first read the Google Cloud TPU tutorial for basic information on using Google Cloud TPUs.

Once you have created a virtual machine with Cloud TPUs and downloaded the ImageNet data for tensorflow_datasets, set the following environment variables:

TPU_NAME=<tpu-name>
STORAGE_BUCKET=gs://<storage-bucket>
DATA_DIR=$STORAGE_BUCKET/<path-to-tensorflow-dataset>
MODEL_DIR=$STORAGE_BUCKET/<path-to-store-checkpoints>
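Cloud TPU jobs generally read input data and write checkpoints through Google Cloud Storage, which is why DATA_DIR and MODEL_DIR live under the gs:// bucket. The hypothetical helper below checks these variables before a job is launched. It is not part of this repository, and when reading the real environment it assumes the variables were exported (for example `export DATA_DIR=...`) so that child processes can see them.

```python
import os
from typing import Dict, Mapping, Optional

# The four variables set in the README for TPU training.
REQUIRED = ("TPU_NAME", "STORAGE_BUCKET", "DATA_DIR", "MODEL_DIR")


def check_tpu_env(env: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return the required variables, raising if any is missing or malformed."""
    env = os.environ if env is None else env
    missing = [name for name in REQUIRED if not env.get(name)]
    if missing:
        raise KeyError(f"missing environment variables: {', '.join(missing)}")
    values = {name: env[name] for name in REQUIRED}
    bucket = values["STORAGE_BUCKET"].rstrip("/")
    if not bucket.startswith("gs://"):
        raise ValueError("STORAGE_BUCKET should be a gs:// URI")
    for name in ("DATA_DIR", "MODEL_DIR"):
        # Both directories are expected to sit inside the storage bucket.
        if not values[name].startswith(bucket + "/"):
            raise ValueError(f"{name} should be a path under {bucket}")
    return values


if __name__ == "__main__":
    example = {  # placeholder values for illustration only
        "TPU_NAME": "my-tpu",
        "STORAGE_BUCKET": "gs://my-bucket",
        "DATA_DIR": "gs://my-bucket/tensorflow_datasets",
        "MODEL_DIR": "gs://my-bucket/hit_gan/imagenet",
    }
    print(check_tpu_env(example)["MODEL_DIR"])
```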

The following command trains a model on ImageNet on TPUv2 4x4, using the default hyperparameters from our paper:

python run.py --mode=train --dataset=imagenet2012 \
  --train_batch_size=256 --train_steps=1000000 \
  --image_crop_size=128 --image_crop_proportion=0.875 \
  --save_every_n_steps=2000 \
  --latent_dim=256 --generator_lr=0.0001 \
  --discriminator_lr=0.0001 --channel_multiplier=1 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=True --master=$TPU_NAME
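The names --image_crop_size and --image_crop_proportion suggest a central crop: keep a centered region covering 87.5% of the limiting image side, then resize it to 128x128. This reading is an assumption based on the flag names and is not confirmed by this README. The sketch below only computes the crop box (in the style of a SimCLR-like center crop) and leaves resizing to an image library.

```python
from typing import Tuple


def center_crop_box(height: int, width: int, crop_proportion: float,
                    aspect_ratio: float = 1.0) -> Tuple[int, int, int, int]:
    """Return (offset_h, offset_w, crop_h, crop_w) for a central crop.

    The crop keeps `crop_proportion` of the limiting side while matching the
    requested output aspect ratio (width / height); 1.0 means a square crop.
    """
    if aspect_ratio > width / height:
        # Target is wider than the image, so the image width limits the crop.
        crop_w = round(crop_proportion * width)
        crop_h = round(crop_proportion / aspect_ratio * width)
    else:
        # Image is at least as wide as the target, so the height limits the crop.
        crop_h = round(crop_proportion * height)
        crop_w = round(crop_proportion * aspect_ratio * height)
    offset_h = (height - crop_h + 1) // 2
    offset_w = (width - crop_w + 1) // 2
    return offset_h, offset_w, crop_h, crop_w


if __name__ == "__main__":
    # A 375x500 (height x width) photo with a square target.
    print(center_crop_box(375, 500, crop_proportion=0.875))  # (24, 86, 328, 328)
```

With --image_crop_proportion=1.0, as in the CelebA-HQ and FFHQ commands, a square 256x256 image would be kept whole under this interpretation.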

To train the model on ImageNet with multiple GPUs, try the following command:

python run.py --mode=train --dataset=imagenet2012 \
  --train_batch_size=256 --train_steps=1000000 \
  --image_crop_size=128 --image_crop_proportion=0.875 \
  --save_every_n_steps=2000 \
  --latent_dim=256 --generator_lr=0.0001 \
  --discriminator_lr=0.0001 --channel_multiplier=1 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=False

Set train_batch_size according to the number of GPUs used for training. Note that storing Exponential Moving Average (EMA) models is currently not supported on GPUs (hence --use_ema_model=False), so training with GPUs leads to a slight drop in performance.
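An EMA model keeps a running average of the generator weights, updated after each training step as ema = decay * ema + (1 - decay) * weights. The averaged weights are presumably what --use_ema_model=True selects at evaluation time. The plain-Python sketch below shows the update rule on a flat list of parameters; the default decay of 0.999 is a hypothetical placeholder, not the value used for HiT-GAN.

```python
from typing import List, Sequence


def ema_update(ema_weights: Sequence[float], weights: Sequence[float],
               decay: float = 0.999) -> List[float]:
    """One EMA step: ema <- decay * ema + (1 - decay) * weights.

    The default decay=0.999 is a hypothetical value chosen for illustration.
    """
    if len(ema_weights) != len(weights):
        raise ValueError("weight lists must have the same length")
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


if __name__ == "__main__":
    ema = [0.0, 0.0]
    for _ in range(3):
        current = [1.0, 2.0]  # pretend the generator weights stay fixed
        ema = ema_update(ema, current, decay=0.5)
    print(ema)  # [0.875, 1.75], moving toward [1.0, 2.0]
```

In TensorFlow itself, tf.train.ExponentialMovingAverage implements this same update for variables.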

Evaluate on ImageNet

Run the following command to evaluate the model on GPUs:

python run.py --mode=eval --dataset=imagenet2012 \
  --eval_batch_size=128 --train_steps=1000000 \
  --image_crop_size=128 --image_crop_proportion=0.875 \
  --latent_dim=256 --channel_multiplier=1 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=True

This command evaluates the model on 8 P100 GPUs. Set eval_batch_size according to the number of GPUs used for evaluation. Also note that train_steps and use_ema_model should match the values used for training.
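Because train_steps, use_ema_model and the model-shape flags (latent_dim, channel_multiplier) have to agree between training and evaluation, it can help to derive the evaluation flags from the training ones instead of retyping them. The helper below is a hypothetical sketch, not part of run.py. It reuses only flag names that appear in the commands in this README and, as an assumption, treats eval_batch_size as a global batch that must split evenly across GPUs (128 / 8 = 16 images per GPU here).

```python
from typing import Dict, List

# Flags that the evaluation commands repeat from the corresponding training command.
SHARED_FLAGS = ("dataset", "train_steps", "image_crop_size",
                "image_crop_proportion", "latent_dim", "channel_multiplier",
                "data_dir", "model_dir")


def eval_args(train_flags: Dict[str, str], eval_batch_size: int, num_gpus: int,
              trained_with_ema: bool) -> List[str]:
    """Build `python run.py --mode=eval ...` arguments from training flags.

    Raises KeyError if a shared flag is missing from `train_flags`.
    """
    if eval_batch_size % num_gpus != 0:
        raise ValueError("eval_batch_size should be divisible by the number of GPUs")
    args = ["python", "run.py", "--mode=eval"]
    args += [f"--{name}={train_flags[name]}" for name in SHARED_FLAGS]
    args += [f"--eval_batch_size={eval_batch_size}",
             "--use_tpu=False",
             f"--use_ema_model={trained_with_ema}"]
    return args


if __name__ == "__main__":
    imagenet_train = {
        "dataset": "imagenet2012", "train_steps": "1000000",
        "image_crop_size": "128", "image_crop_proportion": "0.875",
        "latent_dim": "256", "channel_multiplier": "1",
        "data_dir": "$DATA_DIR", "model_dir": "$MODEL_DIR",
    }
    print(" ".join(eval_args(imagenet_train, eval_batch_size=128, num_gpus=8,
                             trained_with_ema=True)))
```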

CelebA-HQ

Before first use, download CelebA-HQ by following the tensorflow_datasets instructions in the official guide.

Train on CelebA-HQ

The following command trains a model on CelebA-HQ on TPUv2 4x4, using the default hyperparameters from our paper for 256x256 resolution:

python run.py --mode=train --dataset=celeb_a_hq/256 \
  --train_batch_size=256 --train_steps=250000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --save_every_n_steps=1000 \
  --latent_dim=512 --generator_lr=0.00005 \
  --discriminator_lr=0.00005 --channel_multiplier=2 \
  --use_consistency_regularization=True \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=True --master=$TPU_NAME
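The --use_consistency_regularization flag presumably enables a consistency-regularization term for the discriminator. In the formulation of Zhang et al. ("Consistency Regularization for Generative Adversarial Networks", ICLR 2020), the discriminator is penalized when its output on an image changes under augmentation: L_cr = λ · (D(x) − D(T(x)))², averaged over the batch. Whether HiT-GAN uses exactly this variant (or a balanced one that also covers generated images) is not stated here, so treat the sketch below, including its weight of 10.0, as an illustration under that assumption.

```python
from typing import Callable, Sequence

Image = Sequence[float]
Discriminator = Callable[[Image], float]
Augment = Callable[[Image], Image]


def consistency_loss(images: Sequence[Image], discriminator: Discriminator,
                     augment: Augment, weight: float = 10.0) -> float:
    """weight * batch mean of (D(x) - D(T(x)))**2.

    The default weight=10.0 is a hypothetical choice, not a value from HiT-GAN.
    """
    diffs = [(discriminator(x) - discriminator(augment(x))) ** 2 for x in images]
    return weight * sum(diffs) / len(diffs)


if __name__ == "__main__":
    def toy_d(x: Image) -> float:
        return sum(x) / len(x)  # toy "discriminator": mean pixel value

    def flip(x: Image) -> Image:
        return list(reversed(x))  # 1-D stand-in for a horizontal flip

    def brighten(x: Image) -> Image:
        return [v + 0.5 for v in x]  # shifts every pixel value

    batch = [[0.0, 1.0], [0.25, 0.75]]
    print(consistency_loss(batch, toy_d, flip))      # 0.0: the mean ignores flips
    print(consistency_loss(batch, toy_d, brighten))  # 2.5
```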

Evaluate on CelebA-HQ

Run the following command to evaluate the model on 8 P100 GPUs:

python run.py --mode=eval --dataset=celeb_a_hq/256 \
  --eval_batch_size=128 --train_steps=250000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --latent_dim=512 --channel_multiplier=2 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=True

FFHQ

Before first use, download the FFHQ TFRecords from the official site and place them in $DATA_DIR.

Train on FFHQ

The following command trains a model on FFHQ on TPUv2 4x4, using the default hyperparameters from our paper for 256x256 resolution:

python run.py --mode=train --dataset=ffhq/256 \
  --train_batch_size=256 --train_steps=500000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --save_every_n_steps=1000 \
  --latent_dim=512 --generator_lr=0.00005 \
  --discriminator_lr=0.00005 --channel_multiplier=2 \
  --use_consistency_regularization=True \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=True --master=$TPU_NAME

Evaluate on FFHQ

Run the following command to evaluate the model on 8 P100 GPUs:

python run.py --mode=eval --dataset=ffhq/256 \
  --eval_batch_size=128 --train_steps=500000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --latent_dim=512 --channel_multiplier=2 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=True

Cite

@inproceedings{zhao2021improved,
  title = {Improved Transformer for High-Resolution {GANs}},
  author = {Long Zhao and Zizhao Zhang and Ting Chen and Dimitris Metaxas and Han Zhang},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year = {2021}
}

Disclaimer

This is not an officially supported Google product.
