CVPR2021 Content-Aware GAN Compression

Overview

Content-Aware GAN Compression [ArXiv]

Paper accepted to CVPR2021.

@inproceedings{liu2021content,
  title     = {Content-Aware GAN Compression},
  author    = {Liu, Yuchen and Shu, Zhixin and Li, Yijun and Lin, Zhe and Perazzi, Federico and Kung, S.Y.},
  booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2021},
}

We propose a novel content-aware approach to GAN compression. With content awareness, our 11x-accelerated GAN performs comparably to the full-size model on both image generation and image editing.

Image Generation

We show an example above comparing the generative ability of our 11x-accelerated generator with that of the full-size one. In particular, our model generates the contents of interest with visual quality comparable to the full-size model.

Image Editing

The example above illustrates the effectiveness of our compressed StyleGAN2 for image style mixing and morphing. When we mix in the middle styles from B, the original full-size model suffers a significant identity loss, while our approach better preserves the person's identity. We also observe that our morphed images show a smoother expression transition than those of the full-size model, notably in the beard region, which supports our advantage in latent-space smoothness.

We provide an additional example above.

Methodology

In our work, we make the first attempt to bring content awareness into channel pruning and knowledge distillation.

Specifically, we leverage a content-parsing network to identify the contents of interest (COI), i.e., a set of spatial locations with salient semantic concepts, within the generated images. We design a content-aware pruning metric (with a forward and a backward path) to remove the channels that are least sensitive to the COI in the generated images. For knowledge distillation, we restrict the distillation region to the COI of the teacher's outputs, which further enhances the distillation of the target contents.
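The idea of restricting distillation to the COI can be illustrated with a minimal pure-Python sketch. It assumes the content-parsing network has already produced a per-pixel label map; the helpers `coi_mask` and `masked_mse`, the label ids, and the use of plain pixel-wise MSE are all hypothetical stand-ins, not the repository's actual losses.

```python
from typing import List, Set

# Toy stand-ins: a single-channel H x W image and an H x W label map.
Image = List[List[float]]
LabelMap = List[List[int]]


def coi_mask(parsing_map: LabelMap, coi_labels: Set[int]) -> List[List[int]]:
    """Binary mask that is 1 where the parsed label belongs to the COI, else 0."""
    return [[1 if label in coi_labels else 0 for label in row] for row in parsing_map]


def masked_mse(teacher: Image, student: Image, mask: List[List[int]]) -> float:
    """Mean squared teacher/student difference over COI pixels only.

    Pixels outside the mask contribute nothing; returns 0.0 for an empty mask.
    """
    total, count = 0.0, 0
    for t_row, s_row, m_row in zip(teacher, student, mask):
        for t, s, m in zip(t_row, s_row, m_row):
            if m:
                total += (t - s) ** 2
                count += 1
    return total / count if count else 0.0


# Usage: label 1 is a hypothetical "face" class, label 0 is background.
parsing = [[0, 1],
           [1, 0]]
teacher_img = [[0.0, 1.0],
               [1.0, 0.0]]
student_img = [[0.5, 0.8],
               [0.6, 0.9]]
mask = coi_mask(parsing, {1})
loss = masked_mse(teacher_img, student_img, mask)  # only the two label-1 pixels count
```

The large background errors (0.5 and 0.9) are ignored, so `loss` reflects only the COI pixels. Dividing by the number of COI pixels rather than the full image size keeps the loss scale independent of how large the content region is; this normalization is a design choice of the sketch.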

Usage

Prerequisite

We have tested our code under the following environment:

python == 3.6.5
pytorch == 1.6.0
torchvision == 0.7.0
CUDA == 10.2
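Because only these exact versions were tested, checking the local environment before a long run can be worthwhile. The following is a hypothetical helper (not part of the repository) that compares installed versions against the tested ones, using only `sys.version_info`, `torch.__version__`, `torch.version.cuda`, and `torchvision.__version__`.

```python
import sys
from typing import Dict, List, Optional

TESTED = {"python": "3.6.5", "torch": "1.6.0", "torchvision": "0.7.0", "cuda": "10.2"}


def base_version(version: str) -> str:
    """Strip local build tags such as '+cu102' from a version string."""
    return version.split("+", 1)[0]


def mismatches(installed: Dict[str, Optional[str]],
               tested: Dict[str, str] = TESTED) -> List[str]:
    """Describe every package whose installed version differs from the tested one."""
    out = []
    for name, want in tested.items():
        have = installed.get(name)
        if have is None:
            out.append(f"{name}: not found (tested with {want})")
        elif base_version(have) != want:
            out.append(f"{name}: {have} (tested with {want})")
    return out


def installed_versions() -> Dict[str, Optional[str]]:
    """Collect local versions; missing packages map to None."""
    found: Dict[str, Optional[str]] = {
        "python": "{}.{}.{}".format(*sys.version_info[:3])
    }
    try:
        import torch
        found["torch"] = str(torch.__version__)
        found["cuda"] = torch.version.cuda  # None for CPU-only builds
    except ImportError:
        found["torch"] = found["cuda"] = None
    try:
        import torchvision
        found["torchvision"] = str(torchvision.__version__)
    except ImportError:
        found["torchvision"] = None
    return found


if __name__ == "__main__":
    for line in mismatches(installed_versions()):
        print("warning:", line)
```

The check is deliberately strict (even patch versions must match); newer versions may well work, so a warning is a hint rather than an error.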

Pretrained Full-Size Generator Checkpoint

To start, download a full-size generator checkpoint from:

256px StyleGAN2

1024px StyleGAN2

and place it under the folder ./Model/full_size_model/.

Pruning

Once you have the full-size checkpoint, you can prune the generator with:

python3 prune.py \
	--generated_img_size=256 \
	--ckpt=/path/to/full/size/model/ \
	--remove_ratio=0.7 \
	--info_print

We adopt a uniform channel pruning ratio for every layer: the command above removes 70% of the channels in each layer of the generator. The pruned checkpoint will be saved at ./Model/pruned_model/.
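Uniform pruning can be sketched as keeping the top-scoring `(1 - remove_ratio)` fraction of channels in every layer. In this minimal sketch the per-channel scores are just a list of numbers standing in for the content-aware metric, and the helper names and rounding rule are hypothetical (prune.py may round differently). It also gives a back-of-the-envelope reason why `--remove_ratio=0.7` lands near 11x: a convolution's cost scales with C_in × C_out, so pruning both by 70% leaves 0.3 × 0.3 = 0.09 of the FLOPs, i.e. about 1/0.09 ≈ 11.1x fewer. The estimate ignores layers whose input or output width is not pruned.

```python
from typing import List, Sequence


def num_channels_to_keep(num_channels: int, remove_ratio: float) -> int:
    """Channels left in one layer after removing `remove_ratio` of them.

    Rounds to the nearest integer and always keeps at least one channel
    (an assumption of this sketch).
    """
    return max(1, round(num_channels * (1.0 - remove_ratio)))


def select_channels(scores: Sequence[float], remove_ratio: float) -> List[int]:
    """Indices of the highest-scoring channels, returned in their original order."""
    k = num_channels_to_keep(len(scores), remove_ratio)
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])


def approx_speedup(remove_ratio: float) -> float:
    """Rough FLOPs reduction of a conv layer whose input AND output channels
    are both pruned by the same ratio (cost scales with C_in * C_out)."""
    keep = 1.0 - remove_ratio
    return 1.0 / (keep * keep)


print(num_channels_to_keep(512, 0.7))              # 154 channels kept out of 512
print(select_channels([0.1, 0.9, 0.5, 0.3], 0.5))  # [1, 2]
print(round(approx_speedup(0.7), 1))               # 11.1
```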

Retraining

We then retrain the pruned generator with:

python3 train.py \
	--size=256 \
	--path=/path/to/ffhq/data/folder/ \
	--ckpt=/path/to/pruned/model/ \
	--teacher_ckpt=/path/to/full/size/model/ \
	--iter=450001 \
	--batch_size=16

You may adjust the variables gpu_device_ids and primary_device in train_hyperparams.py to match your GPU setup.
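As an illustration, here is a hypothetical setting for the two-GPU 256px run listed in the training log, plus a consistency check. The variable names come from this README, but the layout of train_hyperparams.py is an assumption. The check also assumes the training script wraps the generator in `torch.nn.DataParallel`, which requires the module's parameters and buffers to live on `device_ids[0]`; this is a guess based on the variable names and is not confirmed here.

```python
from typing import List

# Hypothetical values: two GPUs, as in the 256px run of the training log.
# A 1024px run on four GPUs would use e.g. [0, 1, 2, 3].
gpu_device_ids: List[int] = [0, 1]
primary_device: str = "cuda:0"


def check_gpu_setup(device_ids: List[int], primary: str) -> None:
    """Raise ValueError if `primary` is not the first listed GPU.

    torch.nn.DataParallel expects the wrapped module on device_ids[0],
    so the primary device should point at that GPU.
    """
    if not device_ids:
        raise ValueError("gpu_device_ids must list at least one GPU")
    expected = f"cuda:{device_ids[0]}"
    if primary != expected:
        raise ValueError(f"primary_device should be {expected!r}, got {primary!r}")


check_gpu_setup(gpu_device_ids, primary_device)  # passes silently
```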

Training Log

Time needed to retrain the 11x-compressed models on V100 GPUs:

| Model | Batch Size | Iterations | # GPUs | Time (hours) |
|---|---|---|---|---|
| 256px StyleGAN2 | 16 | 450k | 2 | 131 |
| 1024px StyleGAN2 | 16 | 450k | 4 | 251 |
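The training-log figures translate into a per-iteration cost, which is handy for budgeting a shorter or longer run. This small sketch assumes wall-clock time scales linearly with the iteration count on the same hardware; the helper names are hypothetical.

```python
def seconds_per_iter(iterations: int, hours: float) -> float:
    """Average wall-clock seconds per training iteration."""
    return hours * 3600.0 / iterations


def estimate_hours(iterations: int, sec_per_iter: float) -> float:
    """Projected wall-clock hours for a run of `iterations` steps."""
    return iterations * sec_per_iter / 3600.0


# Figures from the training-log table.
s256 = seconds_per_iter(450_000, 131)   # about 1.05 s/iter on 2 V100s
s1024 = seconds_per_iter(450_000, 251)  # about 2.01 s/iter on 4 V100s
print(round(s256, 2), round(s1024, 2))
print(round(estimate_hours(100_000, s256), 1))  # a 100k-iteration 256px run
```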

A typical training curve for the 11x-compressed 256px StyleGAN2:

Evaluation

To evaluate the model quantitatively, we provide get_fid.py and get_ppl.py, which compute the model's FID and PPL scores.

FID Evaluation:

python3 get_fid.py \
	--generated_img_size=256 \
	--ckpt=/path/to/model/ \
	--n_sample=50000 \
	--batch_size=64 \
	--info_print
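For reference, FID fits a Gaussian to Inception-v3 features of real images (mean $\mu_r$, covariance $\Sigma_r$) and of generated images ($\mu_g$, $\Sigma_g$) and measures the Fréchet distance between the two; lower is better. The acknowledgements credit the pytorch-fid implementation, whose default features are 2048-dimensional pool activations. Presumably `--n_sample=50000` sets how many generated images enter $\mu_g$ and $\Sigma_g$.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```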

PPL Evaluation:

python3 get_ppl.py \
	--generated_img_size=256 \
	--ckpt=/path/to/model/ \
	--n_sample=5000 \
	--eps=1e-4 \
	--info_print
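PPL measures how much the output changes under a tiny step `eps` along an interpolation path in latent space, scaled by 1/eps². The following is a minimal sketch, not the repository's get_ppl.py: the generator, latent sampler, and distance are caller-supplied callables, and the toy usage substitutes squared Euclidean distance for LPIPS (the perceptual distance used in StyleGAN-style PPL). It interpolates linearly, as is usual in W space (Z-space PPL uses slerp instead), and discards values outside the 1st–99th percentile, similar to the outlier filtering in StyleGAN's reference PPL evaluation. Whether get_ppl.py filters the same way is an assumption.

```python
import itertools
import math
import random
from typing import Callable, List, Sequence

Vector = List[float]


def lerp(a: Sequence[float], b: Sequence[float], t: float) -> Vector:
    """Linear interpolation between two latent vectors."""
    return [x + (y - x) * t for x, y in zip(a, b)]


def ppl(generate: Callable[[Vector], Vector],
        distance: Callable[[Vector, Vector], float],
        sample_latent: Callable[[random.Random], Vector],
        n_sample: int = 5000,
        eps: float = 1e-4,
        seed: int = 0) -> float:
    """Mean of d(G(lerp(t)), G(lerp(t + eps))) / eps^2 over random latent pairs,
    keeping only values between the 1st and 99th percentiles."""
    if n_sample < 1:
        raise ValueError("n_sample must be positive")
    rng = random.Random(seed)
    dists = []
    for _ in range(n_sample):
        z1, z2 = sample_latent(rng), sample_latent(rng)
        t = rng.random()
        img_a = generate(lerp(z1, z2, t))
        img_b = generate(lerp(z1, z2, t + eps))
        dists.append(distance(img_a, img_b) / eps ** 2)
    dists.sort()
    last = len(dists) - 1
    lo = dists[math.floor(last * 0.01)]  # "lower" 1st percentile
    hi = dists[math.ceil(last * 0.99)]   # "higher" 99th percentile
    kept = [d for d in dists if lo <= d <= hi]
    return sum(kept) / len(kept)


# Toy usage: G doubles its input and every pair is ([0, 0], [1, 0]),
# so each term is ||2 * eps * (z2 - z1)||^2 / eps^2 = 4.
pair = itertools.cycle([[0.0, 0.0], [1.0, 0.0]])
toy = ppl(lambda z: [2.0 * x for x in z],
          lambda a, b: sum((x - y) ** 2 for x, y in zip(a, b)),
          lambda rng: next(pair),
          n_sample=100)
print(round(toy, 6))
```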

We also provide an image projector that returns a (real image, projected image) pair, saved in Image_Projection_Visualization.png, together with the PSNR and LPIPS scores between the two images:

python3 get_projected_image.py \
	--generated_img_size=256 \
	--ckpt=/path/to/model/ \
	--image_file=/path/to/an/RGB/image/ \
	--num_iters=800 \
	--info_print
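PSNR, one of the two scores reported by the projector, is 10 · log10(MAX² / MSE), where MAX is the largest possible pixel value. This minimal sketch works on flattened pixel lists; the helper name is hypothetical, and `max_val` must match the value range the script actually uses (1.0 for [0, 1], 255.0 for 8-bit images, 2.0 for [-1, 1]), which is not stated here.

```python
import math
from typing import Sequence


def psnr(real: Sequence[float], projected: Sequence[float], max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB between two equally sized pixel lists."""
    if len(real) != len(projected) or not real:
        raise ValueError("images must be non-empty and of the same size")
    mse = sum((a - b) ** 2 for a, b in zip(real, projected)) / len(real)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)


print(round(psnr([0.0] * 4, [0.1] * 4), 2))  # MSE = 0.01, so 20.0 dB
```

Higher PSNR means a closer pixel-level match; LPIPS complements it with a perceptual distance, where lower is better.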

An example of Image_Projection_Visualization.png projected by a full-size 256px StyleGAN2:

Helen-Set55

We provide the Helen-Set55 on Google Drive.

11x-Accelerated Generator Checkpoint

We provide the following checkpoints of our content-aware compressed StyleGAN2:

Compressed 256px StyleGAN2

Compressed 1024px StyleGAN2

Acknowledgement

PyTorch StyleGAN2: https://github.com/rosinality/stylegan2-pytorch

Face Parsing BiSeNet: https://github.com/zllrunning/face-parsing.PyTorch

Fréchet Inception Distance: https://github.com/mseitzer/pytorch-fid

Learned Perceptual Image Patch Similarity: https://github.com/richzhang/PerceptualSimilarity

Owner
Yuchen Liu, Ph.D. Candidate at Princeton University