Net2net - Network-to-Network Translation with Conditional Invertible Neural Networks

Overview

Net2Net

Code accompanying the NeurIPS 2020 oral paper

Network-to-Network Translation with Conditional Invertible Neural Networks
Robin Rombach*, Patrick Esser*, Björn Ommer
* equal contribution

tl;dr Our approach distills the residual information of one model with respect to another's and thereby enables translation between fixed off-the-shelf expert models such as BERT and BigGAN without having to modify or finetune them.

arXiv | BibTeX | Project Page

News: Dec 19th, 2020: added SBERT-to-BigGAN, SBERT-to-BigBiGAN and SBERT-to-AE (COCO).

Requirements

A suitable conda environment named net2net can be created and activated with:

conda env create -f environment.yaml
conda activate net2net

Datasets

  • CelebA: Create a symlink 'data/CelebA' pointing to a folder which contains the following files:
    .
    ├── identity_CelebA.txt
    ├── img_align_celeba
    ├── list_attr_celeba.txt
    └── list_eval_partition.txt
    
    These files can be obtained here.
  • CelebA-HQ: Create a symlink data/celebahq pointing to a folder containing the .npy files of CelebA-HQ (instructions to obtain them can be found in the PGGAN repository).
  • FFHQ: Create a symlink data/ffhq pointing to the images1024x1024 folder obtained from the FFHQ repository.
  • Anime Faces: First download the face images from the Anime Crop dataset and then apply the preprocessing of FFHQ to those images. We only keep images in which the underlying dlib face recognition model recognizes a face. Finally, create a symlink data/anime pointing to a folder that contains the processed anime face images.
  • Oil Portraits: Download here. Unpack the content and place the files in data/portraits. The dataset consists of 18k oil portraits, which were obtained by running dlib on a subset of the WikiArt dataset, kindly provided by A Style-Aware Content Loss for Real-time HD Style Transfer.
  • COCO: Create a symlink data/coco containing the images from the 2017 split in train2017 and val2017, and their annotations in annotations. Files can be obtained from the COCO webpage.
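Before training, it can help to verify that the dataset links listed above are in place. The following is a minimal sketch, assuming it is run from the repository root; the EXPECTED_PATHS mapping and the missing_dataset_paths helper are hypothetical and simply mirror the locations given in the list.

```python
from pathlib import Path

# Locations quoted from the dataset list above; this helper is illustrative
# and not part of the net2net code base.
EXPECTED_PATHS = {
    "CelebA": [
        "data/CelebA/identity_CelebA.txt",
        "data/CelebA/img_align_celeba",
        "data/CelebA/list_attr_celeba.txt",
        "data/CelebA/list_eval_partition.txt",
    ],
    "CelebA-HQ": ["data/celebahq"],
    "FFHQ": ["data/ffhq"],
    "Anime Faces": ["data/anime"],
    "Oil Portraits": ["data/portraits"],
    "COCO": [
        "data/coco/train2017",
        "data/coco/val2017",
        "data/coco/annotations",
    ],
}


def missing_dataset_paths(root="."):
    """Return {dataset name: [missing relative paths]} for incomplete datasets."""
    root = Path(root)
    missing = {}
    for name, rel_paths in EXPECTED_PATHS.items():
        # Path.exists() follows symlinks, so a dangling link counts as missing.
        absent = [p for p in rel_paths if not (root / p).exists()]
        if absent:
            missing[name] = absent
    return missing


if __name__ == "__main__":
    for name, paths in missing_dataset_paths().items():
        print(f"{name}: missing {', '.join(paths)}")
```

Only the datasets needed for a given experiment have to be present, so a report listing, for example, COCO as missing is harmless when training the face models.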

ML4Creativity Demo

We include a streamlit demo, which utilizes our approach to demonstrate biases of datasets and their creative applications. More information can be found in our paper A Note on Data Biases in Generative Models from the Machine Learning for Creativity and Design workshop at NeurIPS 2020. Download the models from

and place them into logs. Run the demo with

streamlit run ml4cad.py

Training

Our code uses PyTorch Lightning and thus natively supports features such as 16-bit precision, multi-GPU training and gradient accumulation. Training details for any model need to be specified in a dedicated .yaml file. In general, such a config file is structured as follows:

model:
  base_learning_rate: 4.5e-6
  target: ...
  params:
    ...
data:
  target: translation.DataModuleFromConfig
  params:
    batch_size: ...
    num_workers: ...
    train:
      target: ...
      params:
        ...
    validation:
      target: ...
      params:
        ...

Any PyTorch Lightning model specified under model.target is then trained on the specified data by running the command:

python translation.py --base <path/to/config.yaml> -t --gpus 0,

All available PyTorch Lightning trainer arguments can be added via the command line, e.g. run

python translation.py --base <path/to/config.yaml> -t --gpus 0,1,2,3 --precision 16 --accumulate_grad_batches 2

to train a model on 4 GPUs using 16-bit precision while accumulating gradients over 2 batches. More details are provided in the examples below.
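Each config entry names a class via target and passes params to it as keyword arguments. A minimal sketch of how such an entry could be turned into an object, assuming target is a plain dotted import path; get_obj_from_str and instantiate_from_config are illustrative helpers here, not necessarily the exact functions used by translation.py.

```python
import importlib


def get_obj_from_str(path):
    """Resolve a dotted path such as 'package.module.ClassName'."""
    module_name, attr = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr)


def instantiate_from_config(config):
    """Call the object named by config['target'] with config['params'] as kwargs."""
    if "target" not in config:
        raise KeyError("expected key 'target' in config")
    # A missing params section means "call with no arguments".
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
```

For example, instantiate_from_config({"target": "datetime.timedelta", "params": {"days": 2}}) returns timedelta(days=2). Nested entries such as data.params.train would presumably be resolved the same way by whichever module receives them.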

Training a cINN

Training a cINN for network-to-network translation usually uses the LightningModule net2net.models.flows.flow.Net2NetFlow and makes a few further assumptions about the configuration file and model interface:

model:
  base_learning_rate: 4.5e-6
  target: net2net.models.flows.flow.Net2NetFlow
  params:
    flow_config:
      target: ...
      params:
        ...

    cond_stage_config:
      target: ...
      params:
        ...

    first_stage_config:
      target: ...
      params:
        ...

Here, the entries under flow_config specify the architecture and parameters of the conditional INN; cond_stage_config specifies the first network, whose representation is to be translated into the representation of another network specified by first_stage_config. Our model net2net.models.flows.flow.Net2NetFlow expects the first network to have an .encode() method that produces the representation of interest, while the second network should have both an .encode() and a .decode() method, such that applying them sequentially produces the network's output. This allows for a modular combination of arbitrary models of interest. For more details, see the examples below.
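To make this interface concrete, here is a toy, pure-Python sketch of the translation path: expert 1 only needs .encode(), expert 2 needs .encode() and .decode(), and a single conditional affine coupling block stands in for the cINN. All classes are hypothetical, and the fixed scale and shift functions are placeholders for learned subnetworks. Training would maximize the likelihood of v = flow.forward(first_stage.encode(x), cond_stage.encode(y)) under a standard normal, which additionally needs the log-determinant (the sum of the scales); that part is omitted here.

```python
import math
import random


class ToyCondStage:
    """Expert 1: only .encode() is needed (e.g. a text model)."""

    def encode(self, x):
        # Hypothetical 1-D "representation": the length of the input.
        return [float(len(x))]


class ToyFirstStage:
    """Expert 2: .decode(.encode(x)) reproduces the network's output."""

    def encode(self, x):
        return [2.0 * xi for xi in x]

    def decode(self, z):
        return [0.5 * zi for zi in z]


class ConditionalAffineCoupling:
    """One conditional affine coupling block with fixed toy weights."""

    def _scale_shift(self, z1, c):
        # Hypothetical stand-ins for learned subnetworks s(z1, c) and t(z1, c).
        h = sum(z1) + sum(c)
        scale = [math.tanh(h + i) for i in range(len(z1))]  # bounded log-scale
        shift = [0.5 * h - i for i in range(len(z1))]
        return scale, shift

    def forward(self, z, c):
        """Map expert 2's representation z to the residual v, given c."""
        assert len(z) % 2 == 0, "toy block assumes an even dimension"
        d = len(z) // 2
        z1, z2 = z[:d], z[d:]
        s, t = self._scale_shift(z1, c)
        return z1 + [b * math.exp(si) + ti for b, si, ti in zip(z2, s, t)]

    def reverse(self, v, c):
        """Exact inverse of forward() for the same condition c."""
        assert len(v) % 2 == 0, "toy block assumes an even dimension"
        d = len(v) // 2
        v1, v2 = v[:d], v[d:]
        s, t = self._scale_shift(v1, c)  # v1 == z1, so s and t match forward()
        return v1 + [(b - ti) * math.exp(-si) for b, si, ti in zip(v2, s, t)]


def translate(x_cond, cond_stage, first_stage, flow, dim, rng=random):
    """Sample v ~ N(0, I), invert the flow given expert 1's code, then decode."""
    c = cond_stage.encode(x_cond)
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    return first_stage.decode(flow.reverse(v, c))
```

Calling translate("a photo of a dog", ToyCondStage(), ToyFirstStage(), ConditionalAffineCoupling(), dim=4, rng=random.Random(0)) yields one sample; repeated calls with fresh v give different outputs for the same condition, which is what makes the translation stochastic. The coupling design keeps half of z unchanged so that the inverse can recompute the same scale and shift exactly.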

Training a cINN - Superresolution

Training details for a cINN that concatenates two autoencoders trained at different image scales for stochastic superresolution are specified in configs/translation/faces32-to-faces256.yaml.

To train a model for translating from 32 x 32 images to 256 x 256 images on GPU 0, run

python translation.py --base configs/translation/faces32-to-faces256.yaml -t --gpus 0, 

and add any additional trainer arguments as described above. Note that this setup requires two pretrained autoencoder models, one on 32 x 32 images and the other on 256 x 256 images. If you want to train them yourself on a combination of FFHQ and CelebA-HQ, run

python translation.py --base configs/autoencoder/faces32.yaml -t --gpus <gpu_id>,

for the 32 x 32 images; and

python translation.py --base configs/autoencoder/faces256.yaml -t --gpus <gpu_id>,

for the model on 256 x 256 images. After training, adapt the corresponding model paths in configs/translation/faces32-to-faces256.yaml. Additionally, we provide weights of pretrained autoencoders for both settings: Weights 32x32; Weights 256x256. To run the training as described above, put them into logs/2020-10-16T17-11-42_FacesFQ32x32/checkpoints/last.ckpt and logs/2020-09-16T16-23-39_FacesXL256z128/checkpoints/last.ckpt, respectively.
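Placing the downloaded weights can be scripted. A minimal sketch, assuming you pass whatever local filename the download was saved under (the original filenames are not specified above); SUPERRES_CHECKPOINTS and place_checkpoint are hypothetical helpers, while the target paths are quoted from the paragraph above.

```python
import shutil
from pathlib import Path

# Target locations quoted from the paragraph above.
SUPERRES_CHECKPOINTS = {
    "32x32": "logs/2020-10-16T17-11-42_FacesFQ32x32/checkpoints/last.ckpt",
    "256x256": "logs/2020-09-16T16-23-39_FacesXL256z128/checkpoints/last.ckpt",
}


def place_checkpoint(downloaded_file, target, root="."):
    """Copy a downloaded checkpoint to the path the config expects."""
    dest = Path(root) / target
    # Create logs/<run>/checkpoints/ if it does not exist yet.
    dest.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(downloaded_file, dest)
    return dest
```

For example, place_checkpoint("downloads/weights32.ckpt", SUPERRES_CHECKPOINTS["32x32"]) copies a (hypothetically named) download into place; copying rather than moving keeps the original file around in case the path needs to change.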

Training a cINN - Unpaired Translation

All training scenarios for unpaired translation are specified in the configs under configs/creativity. We provide code and pretrained autoencoder models for three different translation tasks:

  • Anime ↔ Photography; see configs/creativity/anime_photography_256.yaml. Download the autoencoder checkpoint (Download Anime+Photography) and place it into logs/2020-09-30T21-40-22_AnimeAndFHQ/checkpoints/epoch=000007.ckpt.
  • Oil-Portrait ↔ Photography; see configs/creativity/portraits_photography_256.yaml. Download the autoencoder checkpoint (Download Portrait+Photography) and place it into logs/2020-09-29T23-47-10_PortraitsAndFFHQ/checkpoints/epoch=000004.ckpt.
  • FFHQ ↔ CelebA-HQ ↔ CelebA; see configs/creativity/celeba_celebahq_ffhq_256.yaml. Download the autoencoder checkpoint (Download FFHQ+CelebAHQ+CelebA) and place it into logs/2020-09-16T16-23-39_FacesXL256z128/checkpoints/last.ckpt. Note that this is the same autoencoder checkpoint as for the stochastic superresolution experiment.

To train a cINN on one of these unpaired transfer tasks using the first GPU, simply run

python translation.py --base configs/creativity/<config>.yaml -t --gpus 0,

where <config>.yaml is one of portraits_photography_256.yaml, celeba_celebahq_ffhq_256.yaml or anime_photography_256.yaml. Providing additional arguments to the PyTorch Lightning trainer object is also possible as described above.
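The task-to-checkpoint pairing above can be checked before launching a run. A minimal sketch, assuming the repository root as working directory; UNPAIRED_TASKS and unpaired_command are hypothetical helpers whose config names and checkpoint paths are copied from the list above.

```python
from pathlib import Path

# Config stem -> expected autoencoder checkpoint, quoted from the list above.
UNPAIRED_TASKS = {
    "anime_photography_256":
        "logs/2020-09-30T21-40-22_AnimeAndFHQ/checkpoints/epoch=000007.ckpt",
    "portraits_photography_256":
        "logs/2020-09-29T23-47-10_PortraitsAndFFHQ/checkpoints/epoch=000004.ckpt",
    "celeba_celebahq_ffhq_256":
        "logs/2020-09-16T16-23-39_FacesXL256z128/checkpoints/last.ckpt",
}


def unpaired_command(task, gpus="0,", extra_args=(), root="."):
    """Return the argv list that trains the cINN for one unpaired task."""
    if task not in UNPAIRED_TASKS:
        raise ValueError(f"unknown task {task!r}; choose from {sorted(UNPAIRED_TASKS)}")
    ckpt = Path(root) / UNPAIRED_TASKS[task]
    # Fail early instead of inside the trainer if the autoencoder is missing.
    if not ckpt.exists():
        raise FileNotFoundError(f"autoencoder checkpoint not found: {ckpt}")
    return ["python", "translation.py",
            "--base", f"configs/creativity/{task}.yaml",
            "-t", "--gpus", gpus, *extra_args]
```

The resulting list can be passed to subprocess.run(..., check=True); returning argv instead of a shell string avoids quoting issues with the '=' in the checkpoint filenames.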

In our framework, unpaired translation between domains is formulated as a translation between expert 1, a model which can infer the domain a given image belongs to, and expert 2, a model which can synthesize images of each domain. In the examples provided, we assume that the domain label comes with the dataset and provide the net2net.modules.labels.model.Labelator module, which simply returns a one-hot encoding of this label. However, one could also use a classification model which infers the domain label from the image itself. For expert 2, our examples use an autoencoder trained jointly on all domains, which is easily achieved by concatenating datasets. The provided net2net.data.base.ConcatDatasetWithIndex concatenates datasets and returns the corresponding dataset label for each example, which can then be used by the Labelator class for the translation. The training configurations for the autoencoders used in the creativity experiments are included in configs/autoencoder/anime_photography_256.yaml, configs/autoencoder/celeba_celebahq_ffhq_256.yaml and configs/autoencoder/portraits_photography_256.yaml.
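The data flow described above can be sketched in plain Python: a concatenated dataset that returns each example together with the index of the dataset it came from, and a one-hot encoding of that index as the domain label. ConcatWithIndex and one_hot_label are hypothetical stand-ins for net2net.data.base.ConcatDatasetWithIndex and the Labelator; the real classes build on PyTorch and may differ in details such as the format of the returned example.

```python
import bisect


class ConcatWithIndex:
    """Concatenate datasets; each item also reports its source dataset index."""

    def __init__(self, datasets):
        self.datasets = list(datasets)
        self.cumulative_sizes = []
        total = 0
        for ds in self.datasets:
            total += len(ds)
            self.cumulative_sizes.append(total)

    def __len__(self):
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, idx):
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # First dataset whose cumulative size exceeds idx (skips empty ones).
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        start = self.cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0
        return self.datasets[dataset_idx][idx - start], dataset_idx


def one_hot_label(dataset_idx, n_domains):
    """Expert 1 in the unpaired setting: the domain label as a one-hot vector."""
    vec = [0.0] * n_domains
    vec[dataset_idx] = 1.0
    return vec
```

Locating an item with bisect over cumulative sizes is the same strategy torch.utils.data.ConcatDataset uses, so lookups stay O(log k) in the number of domains.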

Unpaired Translation on Custom Datasets

Create PyTorch datasets for each of your domains, create a concatenated dataset with ConcatDatasetWithIndex (follow the example in net2net.data.faces.CCFQTrain), train an autoencoder on the concatenated dataset (adjust the data section in configs/autoencoder/celeba_celebahq_ffhq_256.yaml) and finally train a net2net translation model between a Labelator and your autoencoder (adjust the sections data and first_stage_config in configs/creativity/celeba_celebahq_ffhq_256.yaml). You can then also add your new model to the available modes in the ml4cad.py demo to visualize the results.

Training a cINN - Text-to-Image

We provide code to obtain a text-to-image model by translating between a text model (SBERT) and an image decoder. To show the flexibility of our approach, we include code for three different decoders: BigGAN, as described in the paper; BigBiGAN, which is only available as a TensorFlow model and thus nicely shows how our approach can work with black-box experts; and an autoencoder.

SBERT-to-BigGAN

Train with

python translation.py --base configs/translation/sbert-to-biggan256.yaml -t --gpus 0,

When running it for the first time, the required models will be downloaded automatically.

SBERT-to-BigBiGAN

Since BigBiGAN is only available on TensorFlow Hub, this example has an additional dependency on TensorFlow. A suitable environment is provided in env_bigbigan.yaml, and you will need COCO for training. You can then start training with

python translation.py --base configs/translation/sbert-to-bigbigan.yaml -t --gpus 0,

Note that the BigBiGAN class is just a naive wrapper, which converts PyTorch tensors to NumPy arrays, feeds them to the TensorFlow graph and converts the result back to PyTorch tensors. It does not require gradients of the expert model and serves as a good example of how to use black-box experts.
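The wrapping pattern itself is framework-agnostic. A minimal sketch, assuming the black box exposes plain encode and decode callables; BlackBoxExpert is a hypothetical class, and the converter functions are supplied by the caller (in a PyTorch setting they would be something like t.detach().cpu().numpy() and torch.from_numpy). The example converters below use tuples and lists purely for illustration.

```python
class BlackBoxExpert:
    """Expose a foreign, non-differentiable model through .encode()/.decode().

    to_external converts our data into what the black box accepts, and
    to_internal converts its result back. No gradients flow through the
    black box; the experts stay frozen, so only their outputs are needed.
    """

    def __init__(self, encode_fn, decode_fn, to_external, to_internal):
        self.encode_fn = encode_fn
        self.decode_fn = decode_fn
        self.to_external = to_external
        self.to_internal = to_internal

    def _call(self, fn, x):
        # Convert in, run the foreign model, convert back out.
        return self.to_internal(fn(self.to_external(x)))

    def encode(self, x):
        return self._call(self.encode_fn, x)

    def decode(self, z):
        return self._call(self.decode_fn, z)


# Toy "foreign framework" that only understands tuples.
expert = BlackBoxExpert(
    encode_fn=lambda xs: tuple(v / 2 for v in xs),
    decode_fn=lambda zs: tuple(v * 2 for v in zs),
    to_external=tuple,
    to_internal=list,
)
```

Because the wrapper satisfies the same .encode()/.decode() interface described in the cINN section, it can be dropped into first_stage_config without the translation model knowing which framework runs underneath.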

SBERT-to-AE

Similarly to the other examples, you can also train your own autoencoder on COCO with

python translation.py --base configs/autoencoder/coco256.yaml -t --gpus 0,

or download a pre-trained one, and translate to it by running

python translation.py --base configs/translation/sbert-to-ae-coco256.yaml -t --gpus 0,

Shout-outs

Thanks to everyone who makes their code and models available.

BibTeX

@misc{rombach2020networktonetwork,
      title={Network-to-Network Translation with Conditional Invertible Neural Networks},
      author={Robin Rombach and Patrick Esser and Björn Ommer},
      year={2020},
      eprint={2005.13580},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@misc{esser2020note,
      title={A Note on Data Biases in Generative Models}, 
      author={Patrick Esser and Robin Rombach and Björn Ommer},
      year={2020},
      eprint={2012.02516},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
Owner
CompVis Heidelberg
Computer Vision research group at the Ruprecht-Karls-University Heidelberg