On Nonlinear Latent Transformations for GAN-based Image Editing - PyTorch implementation

On Nonlinear Latent Transformations for GAN-based Image Editing
Valentin Khrulkov, Leyla Mirvakhabova, Ivan Oseledets, Artem Babenko

Overview

We replace linear shifts commonly used for image editing with a flow of a trainable Neural ODE in the latent space.

w' = NN(w; \theta)
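To make the equation concrete, the sketch below integrates w' = f(w) with a fixed-step Euler scheme in plain Python. It is only an illustration: the repository integrates a trained network with torchdiffeq, whereas euler_flow, constant_field, and toy_field are hypothetical names, and toy_field is an arbitrary stand-in for NN(w; \theta). With a constant right-hand side the flow reduces to the linear shift w0 + t * c.

```python
import math
from typing import Callable, List

Vector = List[float]


def euler_flow(field: Callable[[Vector], Vector], w0: Vector,
               t: float, steps: int = 100) -> Vector:
    """Approximate w(t) for w' = field(w), w(0) = w0, with fixed-step Euler."""
    w = list(w0)
    dt = t / steps
    for _ in range(steps):
        dw = field(w)
        w = [wi + dt * di for wi, di in zip(w, dw)]
    return w


def constant_field(c: Vector) -> Callable[[Vector], Vector]:
    """Right-hand side w' = c; its flow is the linear shift w0 + t * c."""
    return lambda w: list(c)


def toy_field(w: Vector) -> Vector:
    """Arbitrary nonlinear stand-in for NN(w; theta): the direction depends on w."""
    return [math.tanh(w[1]), -math.tanh(w[0])]


w0 = [0.5, -1.0]
shifted = euler_flow(constant_field([1.0, 2.0]), w0, t=1.0)  # linear edit
curved = euler_flow(toy_field, w0, t=1.0)                    # nonlinear edit
```

The only difference between the two edits is the right-hand side: the integration loop is shared, which is why a linear shift can be seen as a special case of the flow.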

The RHS of this Neural ODE is trained end-to-end using pre-trained attribute regressors by enforcing

  • change of the desired attribute;
  • invariance of remaining attributes.
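The two constraints can be written down as a toy objective. The sketch below is an assumption for illustration, not the loss used in train_ode.py: the function name edit_loss, the squared-error form, and the invariance_weight parameter are all hypothetical. It takes regressor outputs before and after an edit and penalizes both missing the desired change on the target attribute and any drift in the other attributes.

```python
from typing import Sequence


def edit_loss(attrs_before: Sequence[float], attrs_after: Sequence[float],
              target: int, desired_change: float,
              invariance_weight: float = 1.0) -> float:
    """Toy objective: move attribute `target` by `desired_change`, keep the rest fixed."""
    # Term 1: the target attribute should change by the requested amount.
    change_term = (attrs_after[target] - attrs_before[target] - desired_change) ** 2
    # Term 2: every other attribute should stay where it was.
    invariance_term = sum(
        (after - before) ** 2
        for i, (before, after) in enumerate(zip(attrs_before, attrs_after))
        if i != target
    )
    return change_term + invariance_weight * invariance_term


before = [0.0, 0.5, 0.25]
clean_edit = [1.0, 0.5, 0.25]   # only attribute 0 moved
leaky_edit = [1.0, 0.0, 0.25]   # attribute 1 drifted as well
```

Here clean_edit incurs no penalty, while leaky_edit is penalized through the invariance term alone.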

Installation and usage

Data

The data required to run the code is available at this Dropbox link (2.5 GB).

Path                  Description
data                  data hosted on Dropbox
  ├  models           pretrained GAN models and attribute regressors
  ├  log              pretrained nonlinear edits (Neural ODEs of depth 1) for a variety of attributes on CUB, FFHQ, Places2
  ├  data_to_rectify  100,000 precomputed pairs (w, R[G[w]]), i.e., style vectors and corresponding semantic attributes
  ├  configs          parameters of StyleGAN 2 generators for each dataset (n_mlp, channel_width, etc.)
  └  inverses         precomputed inverses (elements of W-plus) for sample FFHQ images

To download and unpack the data run get_data.sh.
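After unpacking, a quick sanity check of the folder layout can save a failed training run. The helper below is a hypothetical sketch, not part of the repository; it assumes that the five folders listed above sit directly under data, which get_data.sh may or may not guarantee.

```python
from pathlib import Path
from typing import List

# Folder names taken from the data table; their placement directly under
# the root is an assumption.
EXPECTED_SUBDIRS = ("models", "log", "data_to_rectify", "configs", "inverses")


def missing_subdirs(root: str) -> List[str]:
    """Return the expected folder names that are absent under `root`."""
    root_path = Path(root)
    return [name for name in EXPECTED_SUBDIRS if not (root_path / name).is_dir()]


if __name__ == "__main__":
    missing = missing_subdirs("data")
    print("missing folders:", ", ".join(missing) if missing else "none")
```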

Training

We used torch 1.7 for training; however, the code should work with earlier versions as well. An example training command to rectify all the attributes:

CUDA_VISIBLE_DEVICES=0 python train_ode.py --dataset ffhq \
--nb-iter 5000 \
--alpha 8 \
--depth 1

For selected attributes:

CUDA_VISIBLE_DEVICES=0 python train_ode.py --dataset ffhq \
--nb-iter 5000 \
--alpha 8 \
--dir 4 8 15 16 23 32 \
--depth 1
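When launching several runs, it can be convenient to assemble these commands programmatically. The helper below is a hypothetical sketch that uses only the flags shown above (--dataset, --nb-iter, --alpha, --dir, --depth); build_train_command is not part of the repository.

```python
from typing import List, Optional, Sequence


def build_train_command(dataset: str, nb_iter: int, alpha: float, depth: int,
                        dirs: Optional[Sequence[int]] = None) -> List[str]:
    """Assemble the argument list for train_ode.py from the flags shown above."""
    cmd = ["python", "train_ode.py",
           "--dataset", dataset,
           "--nb-iter", str(nb_iter),
           "--alpha", str(alpha)]
    if dirs:
        # Restrict training to the selected attribute indices.
        cmd += ["--dir", *[str(d) for d in dirs]]
    cmd += ["--depth", str(depth)]
    return cmd


all_attributes = build_train_command("ffhq", nb_iter=5000, alpha=8, depth=1)
selected = build_train_command("ffhq", nb_iter=5000, alpha=8, depth=1,
                               dirs=[4, 8, 15, 16, 23, 32])
```

The resulting list can be passed to subprocess.run, with CUDA_VISIBLE_DEVICES set in the environment as in the shell examples.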

Custom dataset

For training on a custom dataset, you have to provide

  • Generator and attribute regressor weights
  • a dictionary {dataset}_all.pt (stored in data_to_rectify) of the form {"ws": ws, "labels": labels}, where ws is a torch.Tensor of size N x 512 and labels is a torch.Tensor of size N x D, with D being the number of semantic factors. Construct labels by evaluating the corresponding attribute regressor on the synthetic images generator(ws[i]). This dictionary is used to sample batches for training.
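Before training on a custom dataset, it may help to verify that the dictionary has the layout described above. The checker below is a hypothetical sketch (check_rectify_dict is not part of the repository); it relies only on the .shape attribute, so it works with torch.Tensor objects without importing torch.

```python
from typing import Any, Mapping, Tuple

STYLE_DIM = 512  # width of a style vector, as stated above


def check_rectify_dict(data: Mapping[str, Any]) -> Tuple[int, int]:
    """Validate the {"ws", "labels"} layout and return (N, D)."""
    for key in ("ws", "labels"):
        if key not in data:
            raise KeyError(f"missing key: {key!r}")
    ws_shape = tuple(data["ws"].shape)
    labels_shape = tuple(data["labels"].shape)
    # ws must be N x 512.
    if len(ws_shape) != 2 or ws_shape[1] != STYLE_DIM:
        raise ValueError(f"ws must have shape (N, {STYLE_DIM}), got {ws_shape}")
    # labels must be N x D with the same N as ws.
    if len(labels_shape) != 2 or labels_shape[0] != ws_shape[0]:
        raise ValueError(
            f"labels must have shape ({ws_shape[0]}, D), got {labels_shape}"
        )
    return ws_shape[0], labels_shape[1]
```

With PyTorch installed, the dictionary loaded from data_to_rectify/{dataset}_all.pt via torch.load can be passed to the checker directly.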

Visualization

Please see explore.ipynb for example visualizations. lib/utils.py contains FlowFactory, a utility wrapper for building and loading the Neural ODE models.

Restoring from checkpoint

import torch
from lib.utils import FlowFactory, LatentFlow
from torchdiffeq import odeint_adjoint as odeint
device = torch.device("cuda")
flow_factory = FlowFactory(dataset="ffhq", device=device)
odeblock = flow_factory._build_odeblock(depth=1)
# depth = -1 corresponds to a constant right hand side (w' = c)
# depth >= 1 corresponds to an MLP with depth layers
odeblock.load_state_dict(...)

# some style vector (generator.style(z))
w0 = ...

# You can directly call odeint
with torch.no_grad():
    odeint(odeblock.odefunc, w0, torch.FloatTensor([0, 1]).to(device))

# Or utilize the wrapper 
flow = LatentFlow(odefunc=odeblock.odefunc, device=device, name="Bald")
flow.flow(w=w0, t=1)

# To flow real images:
w = torch.load("inverses/actors.pt").to(device)
flow.flow(w, t=6, truncate_real=6)
# truncate_real specifies which portion of a W-plus vector to modify
# (e.g., the first 6 out of 14 vectors)
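The truncate_real comment suggests that only a leading slice of the W-plus code is edited while the remaining style vectors are left as they are. The sketch below illustrates that reading with plain Python lists; it is an assumption about the behaviour, not the LatentFlow implementation, and edit_first_k is a hypothetical name.

```python
from typing import Callable, List

Vector = List[float]


def edit_first_k(w_plus: List[Vector], k: int,
                 edit: Callable[[Vector], Vector]) -> List[Vector]:
    """Apply `edit` to the first k style vectors and copy the rest unchanged."""
    if not 0 <= k <= len(w_plus):
        raise ValueError(f"k must be between 0 and {len(w_plus)}, got {k}")
    return [edit(list(w)) if i < k else list(w) for i, w in enumerate(w_plus)]


# A toy W-plus code: 14 style vectors (2-dimensional here instead of 512).
w_plus = [[0.0, 0.0] for _ in range(14)]
edited = edit_first_k(w_plus, k=6, edit=lambda w: [x + 1.0 for x in w])
```

In this toy run the first 6 vectors are shifted and the last 8 are untouched, mirroring the "first 6 out of 14" example in the comment.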

A sample command to generate a movie:

CUDA_VISIBLE_DEVICES=0 python make_movie.py --attribute Bald --dataset ffhq

Examples

FFHQ

Bald Goatee Wavy_Hair Arched_Eyebrows
Bangs Young Blond_Hair Chubby

Places2

lush rugged fog

Citation

Coming soon.
