
Vision-aided GAN


[NEW!] Vision-aided GAN training with BigGAN and StyleGAN3

[NEW!] Using vision-aided Discriminator in your own GAN training with pip install vision-aided-loss


Can the collective knowledge from a large bank of pretrained vision models be leveraged to improve GAN training? If so, with so many models to choose from, which one(s) should be selected, and in what manner are they most effective?

We find that pretrained computer vision models can significantly improve performance when used in an ensemble of discriminators. We propose an effective selection mechanism, by probing the linear separability between real and fake samples in pretrained model embeddings, choosing the most accurate model, and progressively adding it to the discriminator ensemble. Our method can improve GAN training in both limited data and large-scale settings.

Ensembling Off-the-shelf Models for GAN Training
Nupur Kumari, Richard Zhang, Eli Shechtman, Jun-Yan Zhu
In CVPR 2022
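
The selection step described above can be sketched roughly as follows. This is a minimal, dependency-free illustration and not the code used in the paper: it assumes the embeddings of real and fake images from each candidate model are already available as lists of floats, fits a small logistic-regression probe per model, and picks the model whose probe separates held-out real and fake samples most accurately. The function names, the even/odd train/validation split, and the toy data are all hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple

Features = Sequence[Sequence[float]]


def _logit(w: List[float], b: float, x: Sequence[float]) -> float:
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def train_linear_probe(real: Features, fake: Features,
                       steps: int = 200, lr: float = 0.5) -> Tuple[List[float], float]:
    """Fit logistic regression (real=1, fake=0) with plain gradient descent."""
    dim = len(real[0])
    w, b = [0.0] * dim, 0.0
    data = [(x, 1.0) for x in real] + [(x, 0.0) for x in fake]
    for _ in range(steps):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in data:
            z = max(-30.0, min(30.0, _logit(w, b, x)))  # clamp to avoid overflow
            err = 1.0 / (1.0 + math.exp(-z)) - y
            for d in range(dim):
                grad_w[d] += err * x[d]
            grad_b += err
        w = [wi - lr * g / len(data) for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / len(data)
    return w, b


def probe_accuracy(w: List[float], b: float, real: Features, fake: Features) -> float:
    """Fraction of samples the linear probe classifies correctly."""
    correct = sum(_logit(w, b, x) > 0 for x in real)
    correct += sum(_logit(w, b, x) <= 0 for x in fake)
    return correct / (len(real) + len(fake))


def select_model(embeddings: Dict[str, Tuple[Features, Features]]) -> Tuple[str, Dict[str, float]]:
    """Return the model whose embeddings give the highest probe accuracy."""
    scores = {}
    for name, (real, fake) in embeddings.items():
        # Assumption: even-indexed samples train the probe, odd-indexed ones validate it.
        w, b = train_linear_probe(real[0::2], fake[0::2])
        scores[name] = probe_accuracy(w, b, real[1::2], fake[1::2])
    return max(scores, key=scores.get), scores


# Toy embeddings: model_a separates real from fake, model_b does not.
toy = {
    "model_a": ([[1.0], [1.2], [0.9], [1.1]], [[-1.0], [-1.1], [-0.9], [-1.2]]),
    "model_b": ([[0.5]] * 4, [[0.5]] * 4),
}
best, scores = select_model(toy)
```

In the progressive setting, a step like this would presumably be repeated during training over the models not yet in the ensemble, each time adding the selected model's discriminator.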

Quantitative Comparison


Our method outperforms recent GAN training methods by a large margin, especially in the limited-sample setting. For LSUN Cat, we achieve an FID similar to that of StyleGAN2 trained on the full dataset while using only 0.7% of the dataset. On the full dataset, our method improves FID by 1.5x to 2x on the cat, church, and horse categories of LSUN.

Example Results

Below, we show visual comparisons between the baseline StyleGAN2-ADA and our model (Vision-aided GAN) for the same randomly sampled latent codes on the 100-shot Bridge-of-Sighs and AnimalFace Dog datasets.

Interpolation Videos

Latent interpolation results of models trained with our method on AnimalFace Cat (160 images), Dog (389 images), and Bridge-of-Sighs (100 photos).

Worst sample visualization

We randomly sample 5k images and sort them by Mahalanobis distance, computed using the mean and variance of real samples in Inception feature space. The visualization below shows the bottom 30 images by this distance for StyleGAN2-ADA (left) and our model (right).

AFHQ Dog

AFHQ Cat

AFHQ Wild
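
The ranking used for these visualizations can be sketched as below. This is an illustrative simplification, not the repository's evaluation code. It assumes Inception features are already extracted as lists of floats, uses a per-dimension (diagonal) variance rather than a full covariance matrix, and treats the largest distances as the "worst" samples. All function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Features = Sequence[Sequence[float]]


def fit_mean_and_variance(real_features: Features,
                          eps: float = 1e-8) -> Tuple[List[float], List[float]]:
    """Per-dimension mean and variance of the real-image features."""
    n = len(real_features)
    dim = len(real_features[0])
    mean = [sum(f[d] for f in real_features) / n for d in range(dim)]
    # eps keeps the variance strictly positive for constant dimensions.
    var = [sum((f[d] - mean[d]) ** 2 for f in real_features) / n + eps
           for d in range(dim)]
    return mean, var


def mahalanobis_diag(x: Sequence[float], mean: Sequence[float],
                     var: Sequence[float]) -> float:
    """Mahalanobis distance under a diagonal-covariance assumption."""
    return math.sqrt(sum((xi - m) ** 2 / v for xi, m, v in zip(x, mean, var)))


def worst_sample_indices(fake_features: Features, real_features: Features,
                         k: int = 30) -> List[int]:
    """Indices of the k generated samples farthest from the real distribution."""
    mean, var = fit_mean_and_variance(real_features)
    dists = [mahalanobis_diag(f, mean, var) for f in fake_features]
    order = sorted(range(len(dists)), key=dists.__getitem__, reverse=True)
    return order[:k]


# Toy 2-D features standing in for Inception embeddings.
real = [[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0]]
fake = [[1.0, 1.0], [5.0, 5.0], [1.0, 2.0]]
worst = worst_sample_indices(fake, real, k=2)
```

A full-covariance variant would replace the per-dimension division with a multiplication by the inverse covariance matrix of the real features.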

Pretrained Models

StyleGAN2 models

BigGAN models

All pre-trained models can be downloaded at this link as well.

Vision-aided StyleGAN2 training

Please see stylegan2 README for training StyleGAN2 models with our method. This code will reproduce all StyleGAN2-based results from our paper.

Vision-aided Discriminator in a custom GAN model

Install the library via pip install git+https://github.com/nupurkmr9/vision-aided-gan.git or pip install vision-aided-loss

For details on the off-the-shelf models, please see MODELS.md.

import vision_aided_loss

device='cuda'
discr = vision_aided_loss.Discriminator(cv_type='clip', loss_type='multilevel_sigmoid_s', device=device).to(device)
discr.cv_ensemble.requires_grad_(False) # Freeze feature extractor

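# Sample images (sample_real_image, G, and z are placeholders for your own
# data loader, generator, and batch of latent codes)
real = sample_real_image()
fake = G(z)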

# Update discriminator discr
lossD = discr(real, for_real=True) + discr(fake, for_real=False)
lossD.backward()

# Update generator G
lossG = discr(fake, for_G=True)
lossG.backward()

# We recommend adding the vision-aided adversarial loss only after training the GAN with the standard loss for a few warm-up iterations (warmup_iter).
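
One way to apply the warm-up recommendation is to gate the vision-aided term on the iteration count. The helper below is a hypothetical sketch, not part of the library. It takes the vision-aided loss as a zero-argument callable, so the off-the-shelf models are not evaluated at all during warm-up. The value chosen for warmup_iter is up to the user.

```python
from typing import Callable


def total_loss(standard_loss, vision_aided_loss_fn: Callable[[], object],
               iteration: int, warmup_iter: int):
    """Add the vision-aided term only once warm-up has finished.

    standard_loss: loss from the original discriminator (float or tensor).
    vision_aided_loss_fn: zero-argument callable returning the vision-aided
        loss, e.g. lambda: discr(fake, for_G=True).
    """
    if iteration < warmup_iter:
        return standard_loss
    return standard_loss + vision_aided_loss_fn()


# Toy usage with plain floats standing in for tensors.
during_warmup = total_loss(1.0, lambda: 0.5, iteration=0, warmup_iter=10)
after_warmup = total_loss(1.0, lambda: 0.5, iteration=10, warmup_iter=10)
```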

Arg details:

  • cv_type: name of the off-the-shelf model, one of [clip, dino, swin, vgg, det_coco, seg_ade, face_seg, face_normals]. Multiple models can be used by separating model names with '+'.
  • output_type: output feature type from the off-the-shelf models. Should be one of [conv, conv_multi_level]. conv_multi_level is supported only for clip and dino. For multiple models, output_type should be a '+'-separated list with one output_type per model.
  • diffaug: if True, performs DiffAugment on the vision-aided discriminator with policy color,translation,cutout. Recommended to keep this as True.
  • num_classes: for conditional training, use num_classes>0. A projection discriminator is used, similar to BigGAN.
  • loss_type: should be one of [sigmoid, multilevel_sigmoid, sigmoid_s, multilevel_sigmoid_s, hinge, multilevel_hinge]. Appending _s enables label smoothing. If loss_type is None, the output is a list of logits, one for each vision-aided discriminator.
  • device: device for the off-the-shelf model weights.
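
For an ensemble of several models, the '+'-separated strings can be assembled programmatically. The helper below is a hypothetical convenience function, not part of vision_aided_loss. It only encodes the constraints listed above: the allowed model names, the allowed output types, and conv_multi_level being limited to clip and dino.

```python
from typing import Dict, Sequence, Tuple

# Names copied from the argument list above.
CV_TYPES = {"clip", "dino", "swin", "vgg", "det_coco", "seg_ade",
            "face_seg", "face_normals"}
OUTPUT_TYPES = {"conv", "conv_multi_level"}
MULTI_LEVEL_MODELS = {"clip", "dino"}


def build_ensemble_args(models: Sequence[Tuple[str, str]]) -> Dict[str, str]:
    """Validate (cv_type, output_type) pairs and join each field with '+'."""
    for cv_type, output_type in models:
        if cv_type not in CV_TYPES:
            raise ValueError(f"unknown cv_type: {cv_type}")
        if output_type not in OUTPUT_TYPES:
            raise ValueError(f"unknown output_type: {output_type}")
        if output_type == "conv_multi_level" and cv_type not in MULTI_LEVEL_MODELS:
            raise ValueError(f"conv_multi_level is not supported for {cv_type}")
    return {
        "cv_type": "+".join(cv for cv, _ in models),
        "output_type": "+".join(out for _, out in models),
    }


args = build_ensemble_args([("clip", "conv_multi_level"), ("swin", "conv")])
```

The resulting dictionary could then be unpacked into the constructor alongside the other arguments, for example vision_aided_loss.Discriminator(**args, loss_type='multilevel_sigmoid_s', device=device).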

Vision-aided StyleGAN3 training

Please see stylegan3 README for training StyleGAN3 models with our method.

Vision-aided BigGAN training

Please see biggan README for training BigGAN models with our method.

To add your own pretrained model

  • Create the class file to extract pretrained features as vision_module/<custom_model>.py.
  • Add the class path to class_name_dict in the vision_module.cvmodel.CVBackbone class.
  • Update the architecture of the trainable classifier head over pretrained features in vision_module.cv_discriminator.
  • Reinstall the library manually via pip install .

References

@InProceedings{kumari2021ensembling,
  title={Ensembling Off-the-shelf Models for GAN Training},
  author={Kumari, Nupur and Zhang, Richard and Shechtman, Eli and Zhu, Jun-Yan},
  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month     = {June},
  year      = {2022}
}

Acknowledgments

We thank Muyang Li, Sheng-Yu Wang, and Chonghyuk (Andrew) Song for proofreading the draft. We are also grateful to Alexei A. Efros, Sheng-Yu Wang, Taesung Park, and William Peebles for helpful comments and discussion. Our codebase is built on stylegan2-ada-pytorch and DiffAugment.