Finding an Unsupervised Image Segmenter in each of your Deep Generative Models

Overview

Finding an Unsupervised Image Segmenter in each of your Deep Generative Models

Paper

Description

Recent research has shown that numerous human-interpretable directions exist in the latent space of GANs. In this paper, we develop an automatic procedure for finding directions that lead to foreground-background image separation, and we use these directions to train an image segmentation model without human supervision. Our method is generator-agnostic, producing strong segmentation results with a wide range of different GAN architectures. Furthermore, by leveraging GANs pretrained on large datasets such as ImageNet, we are able to segment images from a range of domains without further training or finetuning. Evaluating our method on image segmentation benchmarks, we compare favorably to prior work while using neither human supervision nor access to the training data. Broadly, our results demonstrate that automatically extracting foreground-background structure from pretrained deep generative models can serve as a remarkably effective substitute for human supervision.

How to run

Dependencies

This code depends on pytorch-pretrained-gans, a repository I developed that exposes a standard interface for a variety of pretrained GANs. Install it with:

pip install git+https://github.com/lukemelas/pytorch-pretrained-gans

The pretrained weights for most GANs are downloaded automatically. For those that are not, I have provided scripts in that repository.

There are also some standard dependencies:

Install them with:

pip install hydra-core==1.1.0dev5 pytorch_lightning albumentations tqdm retry kornia

General Approach

Our unsupervised segmentation approach has two steps: (1) finding a good direction in latent space, and (2) training a segmentation model from data and masks that are generated using this direction.

In detail, this means:

  1. We use optimization/main.py finds a salient direction (or two salient directions) in the latent space of a given pretrained GAN that leads to foreground-background image separation.
  2. We use segmentation/main.py to train a standard segmentation network (a UNet) on generated data. The data can be generated in two ways: (1) you can generate the images on-the-fly during training, or (2) you can generate the images before training the segmentation model using segmentation/generate_and_save.py and then train the segmentation network afterward. The second approach is faster, but requires more disk space (~10GB for 1 million images). We will also provide a pre-generated dataset (coming soon).

Configuration and Logging

We use Hydra for configuration and Weights and Biases for logging. With Hydra, you can specify a config file (found in configs/) with --config-name=myconfig.yaml. You can also override the config from the command line by specifying the overriding arguments (without --). For example, you can enable Weights and Biases with wandb=True and you can name the run with name=myname.

The structure of the configs is as follows:

config
├── data_gen
│   ├── generated.yaml  # <- for generating data with 1 latent direction
│   ├── generated-dual.yaml   # <- for generating data with 2 latent directions
│   ├── generator  # <- different types of GANs for generating data
│   │   ├── bigbigan.yaml
│   │   ├── pretrainedbiggan.yaml
│   │   ├── selfconditionedgan.yaml
│   │   ├── studiogan.yaml
│   │   └── stylegan2.yaml 
│   └── saved.yaml  # <- for using pre-generated data
├── optimize.yaml  # <- for optimization
└── segment.yaml   # <- for segmentation

Code Structure

The code is structured as follows:

src
├── models  # <- segmentation model
│   ├── __init__.py
│   ├── latent_shift_model.py  # <- shifts direction in latent space
│   ├── unet_model.py  # <- segmentation model
│   └── unet_parts.py
├── config  # <- configuration, explained above
│   ├── ... 
├── datasets  # <- classes for loading datasets during segmentation/generation
│   ├── __init__.py
│   ├── gan_dataset.py  # <- for generating dataset
│   ├── saved_gan_dataset.py  # <- for pre-generated dataset
│   └── real_dataset.py  # <- for evaluation datasets (i.e. real images)
├── optimization
│   ├── main.py  # <- main script
│   └── utils.py  # <- helper functions
└── segmentation
    ├── generate_and_save.py  # <- for generating a dataset and saving it to disk
    ├── main.py  # <- main script, uses PyTorch Lightning 
    ├── metrics.py  # <- for mIoU/F-score calculations
    └── utils.py  # <- helper functions

Datasets

The datasets should have the following structure. You can easily add you own datasets or use only a subset of these datasets by modifying config/segment.yaml. You should specify your directory by modifying root in that file on line 19, or by passing data_seg.root=MY_DIR using the command line whenever you call python segmentation/main.py.

├── DUT_OMRON
│   ├── DUT-OMRON-image
│   │   └── ...
│   └── pixelwiseGT-new-PNG
│       └── ...
├── DUTS
│   ├── DUTS-TE
│   │   ├── DUTS-TE-Image
│   │   │   └── ...
│   │   └── DUTS-TE-Mask
│   │       └── ...
│   └── DUTS-TR
│       ├── DUTS-TR-Image
│       │   └── ...
│       └── DUTS-TR-Mask
│           └── ...
├── ECSSD
│   ├── ground_truth_mask
│   │   └── ...
│   └── images
│       └── ...
├── CUB_200_2011
│   ├── train_images
│   │   └── ...
│   ├── train_segmentations
│   │   └── ...
│   ├── test_images
│   │   └── ...
│   └── test_segmentations
│       └── ...
└── Flowers
    ├── train_images
    │   └── ...
    ├── train_segmentations
    │   └── ...
    ├── test_images
    │   └── ...
    └── test_segmentations
        └── ...

The datasets can be downloaded from:

Training

Before training, make sure you understand the general approach (explained above).

Note: All commands are called from within the src directory.

In the example commands below, we use BigBiGAN. You can easily switch out BigBiGAN for another model if you would like to.

Optimization

PYTHONPATH=. python optimization/main.py data_gen/generator=bigbigan name=NAME

This should take less than 5 minutes to run. The output will be saved in outputs/optimization/fixed-BigBiGAN-NAME/DATE/, with the final checkpoint in latest.pth.

Segmentation with precomputed generations

The recommended way of training is to generate the data first and train afterward. An example generation script would be:

PYTHONPATH=. python segmentation/generate_and_save.py \
name=NAME \
data_gen=generated \
data_gen/generator=bigbigan \
data_gen.checkpoint="YOUR_OPTIMIZATION_DIR_FROM_ABOVE/latest.pth" \
data_gen.save_dir="YOUR_OUTPUT_DIR" \
data_gen.save_size=1000000 \
data_gen.kwargs.batch_size=1 \
data_gen.kwargs.generation_batch_size=128

This will generate 1 million image-label pairs and save them to YOUR_OUTPUT_DIR/images. Note that YOUR_OUTPUT_DIR should be an absolute path, not a relative one, because Hydra changes the working directory. You may also want to tune the generation_batch_size to maximize GPU utilization on your machine. It takes around 3-4 hours to generate 1 million images on a single V100 GPU.

Once you have generated data, you can train a segmentation model:

PYTHONPATH=. python segmentation/main.py \
name=NAME \
data_gen=saved \
data_gen.data.root="YOUR_OUTPUT_DIR_FROM_ABOVE"

It takes around 3 hours on 1 GPU to complete 18000 iterations, by which point the model has converged (in fact you can probably get away with fewer steps, I would guess around ~5000).

Segmentation with on-the-fly generations

Alternatively, you can generate data while training the segmentation model. An example script would be:

PYTHONPATH=. python segmentation/main.py \
name=NAME \
data_gen=generated \
data_gen/generator=bigbigan \
data_gen.checkpoint="YOUR_OPTIMIZATION_DIR_FROM_ABOVE/latest.pth" \
data_gen.kwargs.generation_batch_size=128

Evaluation

To evaluate, set the train argument to False. For example:

python train.py \
name="eval" \
train=False \
eval_checkpoint=${checkpoint} \
data_seg.root=${DATASETS_DIR} 

Pretrained models

  • ... are coming soon!

Available GANs

It should be possible to use any GAN from pytorch-pretrained-gans, including:

Citation

@inproceedings{melaskyriazi2021finding,
  author    = {Melas-Kyriazi, Luke and Rupprecht, Christian and Laina, Iro and Vedaldi, Andrea},
  title     = {Finding an Unsupervised Image Segmenter in each of your Deep Generative Models},
  booktitle = arxiv,
  year      = {2021}
}
You might also like...
pytorch implementation of
pytorch implementation of "Contrastive Multiview Coding", "Momentum Contrast for Unsupervised Visual Representation Learning", and "Unsupervised Feature Learning via Non-Parametric Instance-level Discrimination"

Unofficial implementation: MoCo: Momentum Contrast for Unsupervised Visual Representation Learning (Paper) InsDis: Unsupervised Feature Learning via N

pyhsmm - library for approximate unsupervised inference in Bayesian Hidden Markov Models (HMMs) and explicit-duration Hidden semi-Markov Models (HSMMs), focusing on the Bayesian Nonparametric extensions, the HDP-HMM and HDP-HSMM, mostly with weak-limit approximations. The pytorch implementation of  DG-Font: Deformable Generative Networks for Unsupervised Font Generation
The pytorch implementation of DG-Font: Deformable Generative Networks for Unsupervised Font Generation

DG-Font: Deformable Generative Networks for Unsupervised Font Generation The source code for 'DG-Font: Deformable Generative Networks for Unsupervised

Minimal PyTorch implementation of Generative Latent Optimization from the paper
Minimal PyTorch implementation of Generative Latent Optimization from the paper "Optimizing the Latent Space of Generative Networks"

Minimal PyTorch implementation of Generative Latent Optimization This is a reimplementation of the paper Piotr Bojanowski, Armand Joulin, David Lopez-

Deep generative modeling for time-stamped heterogeneous data, enabling high-fidelity models for a large variety of spatio-temporal domains.
Deep generative modeling for time-stamped heterogeneous data, enabling high-fidelity models for a large variety of spatio-temporal domains.

Neural Spatio-Temporal Point Processes [arxiv] Ricky T. Q. Chen, Brandon Amos, Maximilian Nickel Abstract. We propose a new class of parameterizations

Official repository for the ICLR 2021 paper Evaluating the Disentanglement of Deep Generative Models with Manifold Topology

Official repository for the ICLR 2021 paper Evaluating the Disentanglement of Deep Generative Models with Manifold Topology Sharon Zhou, Eric Zelikman

source code for https://arxiv.org/abs/2005.11248 "Accelerating Antimicrobial Discovery with Controllable Deep Generative Models and Molecular Dynamics"

Accelerating Antimicrobial Discovery with Controllable Deep Generative Models and Molecular Dynamics This work will be published in Nature Biomedical

DeepCAD: A Deep Generative Network for Computer-Aided Design Models
DeepCAD: A Deep Generative Network for Computer-Aided Design Models

DeepCAD This repository provides source code for our paper: DeepCAD: A Deep Generative Network for Computer-Aided Design Models Rundi Wu, Chang Xiao,

TAug :: Time Series Data Augmentation using Deep Generative Models

TAug :: Time Series Data Augmentation using Deep Generative Models Note!!! The package is under development so be careful for using in production! Fea

Comments
  • pip install git+https://github.com/lukemelas/pytorch-pretrained-gans

    pip install git+https://github.com/lukemelas/pytorch-pretrained-gans

    Hi, is the repo in the pytorch-pretrained-gans step public or is that the right URL for it? I got prompted for username and password when I tried the pip install git+ and don't see the repo at that URL: https://github.com/lukemelas/pytorch-pretrained-gans (Get 404)

    Thanks.

    opened by ModMorph 2
  • Help producing results with the StyleGAN models

    Help producing results with the StyleGAN models

    Hi there!

    I'm having trouble producing meaningful results on StyleGAN2 on AFHQ. I've been using the default setup and hyperparameters. After 50 iterations (with the default batch size of 32) I get visualisations that look initially promising: (https://i.imgur.com/eR79Wyd.png). But as training progresses, and indeed when it reaches 300 iterations, these are the visualisation results: https://i.imgur.com/36zhBzT.png.

    I've tried playing with the learning rate, and the number of iterations with no success yet. Did you have tips here or ideas as to what might be going wrong here?

    Thanks! James.

    opened by james-oldfield 1
  • bug

    bug

    Firstly, I ran PYTHONPATH=. python optimization/main.py data_gen/generator=bigbigan name=NAME. And then, I ran PYTHONPATH=. python segmentation/generate_and_save.py \ name=NAME \ data_gen=generated \ data_gen/generator=bigbigan \ data_gen.checkpoint="YOUR_OPTIMIZATION_DIR_FROM_ABOVE/latest.pth" \ data_gen.save_dir="YOUR_OUTPUT_DIR" \ data_gen.save_size=1000000 \ data_gen.kwargs.batch_size=1 \ data_gen.kwargs.generation_batch_size=128 When I ran PYTHONPATH=. python segmentation/main.py \ name=NAME \ data_gen=saved \ data_gen.data.root="YOUR_OUTPUT_DIR_FROM_ABOVE" An error occurred. The error is: Traceback (most recent call last): File "segmentation/main.py", line 98, in main kwargs = dict(images_dir=_cfg.images_dir, labels_dir=_cfg.labels_dir, omegaconf.errors.InterpolationResolutionError: KeyError raised while resolving interpolation: "Environment variable '/raid/name/gaochengli/segmentation/src/images' not found" full_key: data_seg.data[0].images_dir object_type=dict According to what you wrote, I modified the root (config/segment.yaml on line 19). Just like this "/raid/name/gaochengli/segmentation/src/images". And the folder contains all data sets,whose name is images. I wonder why such a mistake happened.

    opened by Lee-Gao 1
Owner
Luke Melas-Kyriazi
I'm student at Harvard University studying mathematics and computer science, always open to collaborate on interesting projects!
Luke Melas-Kyriazi
(ICCV'21) Official PyTorch implementation of Relational Embedding for Few-Shot Classification

Relational Embedding for Few-Shot Classification (ICCV 2021) Dahyun Kang, Heeseung Kwon, Juhong Min, Minsu Cho [paper], [project hompage] We propose t

Dahyun Kang 82 Dec 24, 2022
Diverse graph algorithms implemented using JGraphT library.

# 1. Installing Maven & Pandas First, please install Java (JDK11) and Python 3 if they are not already. Next, make sure that Maven (for importing J

See Woo Lee 3 Dec 17, 2022
Official repo of the paper "Surface Form Competition: Why the Highest Probability Answer Isn't Always Right"

Surface Form Competition This is the official repo of the paper "Surface Form Competition: Why the Highest Probability Answer Isn't Always Right" We p

Peter West 46 Dec 23, 2022
Sequence to Sequence Models with PyTorch

Sequence to Sequence models with PyTorch This repository contains implementations of Sequence to Sequence (Seq2Seq) models in PyTorch At present it ha

Sandeep Subramanian 708 Dec 19, 2022
Get started learning C# with C# notebooks powered by .NET Interactive and VS Code.

.NET Interactive Notebooks for C# Welcome to the home of .NET interactive notebooks for C#! How to Install Download the .NET Coding Pack for VS Code f

.NET Platform 425 Dec 25, 2022
Bridging Vision and Language Model

BriVL BriVL (Bridging Vision and Language Model) 是首个中文通用图文多模态大规模预训练模型。BriVL模型在图文检索任务上有着优异的效果,超过了同期其他常见的多模态预训练模型(例如UNITER、CLIP)。 BriVL论文:WenLan: Bridgi

235 Dec 27, 2022
ViSER: Video-Specific Surface Embeddings for Articulated 3D Shape Reconstruction

ViSER: Video-Specific Surface Embeddings for Articulated 3D Shape Reconstruction. NeurIPS 2021.

Gengshan Yang 59 Nov 25, 2022
Alpha-IoU: A Family of Power Intersection over Union Losses for Bounding Box Regression

Alpha-IoU: A Family of Power Intersection over Union Losses for Bounding Box Regression YOLOv5 with alpha-IoU losses implemented in PyTorch. Example r

Jacobi(Jiabo He) 147 Dec 05, 2022
Dilated Convolution for Semantic Image Segmentation

Multi-Scale Context Aggregation by Dilated Convolutions Introduction Properties of dilated convolution are discussed in our ICLR 2016 conference paper

Fisher Yu 764 Dec 26, 2022
Super Resolution for images using deep learning.

Neural Enhance Example #1 — Old Station: view comparison in 24-bit HD, original photo CC-BY-SA @siv-athens. As seen on TV! What if you could increase

Alex J. Champandard 11.7k Dec 29, 2022
Using LSTM to detect spoofing attacks in an Air-Ground network

Using LSTM to detect spoofing attacks in an Air-Ground network Specifications IDE: Spider Packages: Tensorflow 2.1.0 Keras NumPy Scikit-learn Matplotl

Tiep M. H. 1 Nov 20, 2021
Python Interview Questions

Python Interview Questions Clone the code to your computer. You need to understand the code in main.py and modify the content in if __name__ =='__main

ClassmateLin 575 Dec 28, 2022
ULMFiT for Genomic Sequence Data

Genomic ULMFiT This is an implementation of ULMFiT for genomics classification using Pytorch and Fastai. The model architecture used is based on the A

Karl 276 Dec 12, 2022
[ICLR 2021] Rank the Episodes: A Simple Approach for Exploration in Procedurally-Generated Environments.

[ICLR 2021] RAPID: A Simple Approach for Exploration in Reinforcement Learning This is the Tensorflow implementation of ICLR 2021 paper Rank the Episo

Daochen Zha 48 Nov 21, 2022
StyleGAN - Official TensorFlow Implementation

StyleGAN — Official TensorFlow Implementation Picture: These people are not real – they were produced by our generator that allows control over differ

NVIDIA Research Projects 13.1k Jan 09, 2023
Generating Images with Recurrent Adversarial Networks

Generating Images with Recurrent Adversarial Networks Python (Theano) implementation of Generating Images with Recurrent Adversarial Networks code pro

Daniel Jiwoong Im 121 Sep 08, 2022
A semismooth Newton method for elliptic PDE-constrained optimization

sNewton4PDEOpt The Python module implements a semismooth Newton method for solving finite-element discretizations of the strongly convex, linear ellip

2 Dec 08, 2022
Pytorch code for semantic segmentation using ERFNet

ERFNet (PyTorch version) This code is a toolbox that uses PyTorch for training and evaluating the ERFNet architecture for semantic segmentation. For t

Edu 394 Jan 01, 2023
GND-Nets (Graph Neural Diffusion Networks) in TensorFlow.

GNDC For submission to IEEE TKDE. Overview Here we provide the implementation of GND-Nets (Graph Neural Diffusion Networks) in TensorFlow. The reposit

Wei Ye 3 Aug 08, 2022
[Official] Exploring Temporal Coherence for More General Video Face Forgery Detection(ICCV 2021)

Exploring Temporal Coherence for More General Video Face Forgery Detection(FTCN) Yinglin Zheng, Jianmin Bao, Dong Chen, Ming Zeng, Fang Wen Accepted b

57 Dec 28, 2022