Finding an Unsupervised Image Segmenter in each of your Deep Generative Models

Description

Recent research has shown that numerous human-interpretable directions exist in the latent space of GANs. In this paper, we develop an automatic procedure for finding directions that lead to foreground-background image separation, and we use these directions to train an image segmentation model without human supervision. Our method is generator-agnostic, producing strong segmentation results with a wide range of different GAN architectures. Furthermore, by leveraging GANs pretrained on large datasets such as ImageNet, we are able to segment images from a range of domains without further training or finetuning. Evaluating our method on image segmentation benchmarks, we compare favorably to prior work while using neither human supervision nor access to the training data. Broadly, our results demonstrate that automatically extracting foreground-background structure from pretrained deep generative models can serve as a remarkably effective substitute for human supervision.

How to run

Dependencies

This code depends on pytorch-pretrained-gans, a repository I developed that exposes a standard interface for a variety of pretrained GANs. Install it with:

pip install git+https://github.com/lukemelas/pytorch-pretrained-gans

The pretrained weights for most GANs are downloaded automatically. For those that are not, I have provided scripts in that repository.

There are also some standard dependencies (Hydra, PyTorch Lightning, Albumentations, tqdm, retry, and Kornia), which you can install with:

pip install hydra-core==1.1.0dev5 pytorch_lightning albumentations tqdm retry kornia

General Approach

Our unsupervised segmentation approach has two steps: (1) finding a good direction in latent space, and (2) training a segmentation model from data and masks that are generated using this direction.

In detail, this means:

  1. We use optimization/main.py to find a salient direction (or two salient directions) in the latent space of a given pretrained GAN that leads to foreground-background image separation.
  2. We use segmentation/main.py to train a standard segmentation network (a UNet) on generated data. The data can be generated in two ways: (1) you can generate the images on-the-fly during training, or (2) you can generate the images before training the segmentation model using segmentation/generate_and_save.py and then train the segmentation network afterward. The second approach is faster, but requires more disk space (~10GB for 1 million images). We will also provide a pre-generated dataset (coming soon).
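The two steps above can be sketched in a few lines. This is a toy illustration only, not the code in optimization/main.py or segmentation/main.py: toy_generator, the hand-picked direction, and the threshold rule for deriving a mask are all assumptions, chosen so that the example runs without a GPU or a pretrained GAN.

```python
# Toy sketch: derive a foreground mask by shifting a latent code along a direction.
# Everything here (generator, direction, threshold rule) is hypothetical.

def toy_generator(z):
    """Stand-in for a pretrained GAN: maps a 2-D latent code to a 4x4 grayscale image.

    z[0] sets the overall brightness; z[1] brightens only the central 2x2 "object".
    """
    image = []
    for row in range(4):
        pixels = []
        for col in range(4):
            is_object = row in (1, 2) and col in (1, 2)
            value = 0.5 + 0.1 * z[0] + (0.4 * z[1] if is_object else 0.0)
            pixels.append(value)
        image.append(pixels)
    return image


def shift(z, direction, alpha=1.0):
    """Move the latent code z by alpha along the given direction."""
    return [zi + alpha * di for zi, di in zip(z, direction)]


def mask_from_shift(original, shifted, threshold=0.2):
    """Label a pixel as foreground (1) if the shift changed it by more than threshold."""
    return [
        [1 if abs(s - o) > threshold else 0 for o, s in zip(orig_row, shift_row)]
        for orig_row, shift_row in zip(original, shifted)
    ]


# Step 1 (assumed already done): a direction that changes the object but not the background.
direction = [0.0, 1.0]

# Step 2: generate an (image, mask) training pair for the segmentation network.
z = [0.3, 0.0]
image = toy_generator(z)
mask = mask_from_shift(image, toy_generator(shift(z, direction)))

for row in mask:
    print(row)
```

In the real pipeline the generator is a pretrained GAN, the direction is learned rather than hand-picked, and the resulting image-mask pairs are used to train the UNet.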

Configuration and Logging

We use Hydra for configuration and Weights and Biases for logging. With Hydra, you can specify a config file (found in config/) with --config-name=myconfig.yaml. You can also override individual config values from the command line by passing key=value arguments (without a leading --). For example, you can enable Weights and Biases with wandb=True and name the run with name=myname.
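To make the override syntax concrete, here is a minimal pure-Python sketch of how dotted key=value arguments map onto a nested config. It is not Hydra's implementation: apply_overrides and parse_value are hypothetical helpers, the default values are made up, and Hydra's real grammar is richer (for instance, data_gen/generator=bigbigan selects a YAML file from a config group rather than setting a value, which this sketch does not handle).

```python
import copy


def parse_value(raw):
    """Convert a command-line string to bool/int/float where possible."""
    if raw.lower() in ("true", "false"):
        return raw.lower() == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_overrides(config, overrides):
    """Apply 'dotted.key=value' strings to a nested dict and return a new dict."""
    result = copy.deepcopy(config)
    for item in overrides:
        key, _, raw = item.partition("=")
        *parents, leaf = key.split(".")
        node = result
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = parse_value(raw)
    return result


# Hypothetical defaults; the real ones live in the YAML files under config/.
defaults = {"name": "default", "wandb": False, "data_seg": {"root": "/path/to/datasets"}}

cfg = apply_overrides(defaults, ["wandb=True", "name=myname", "data_seg.root=MY_DIR"])
print(cfg)
```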

The structure of the configs is as follows:

config
├── data_gen
│   ├── generated.yaml  # <- for generating data with 1 latent direction
│   ├── generated-dual.yaml   # <- for generating data with 2 latent directions
│   ├── generator  # <- different types of GANs for generating data
│   │   ├── bigbigan.yaml
│   │   ├── pretrainedbiggan.yaml
│   │   ├── selfconditionedgan.yaml
│   │   ├── studiogan.yaml
│   │   └── stylegan2.yaml 
│   └── saved.yaml  # <- for using pre-generated data
├── optimize.yaml  # <- for optimization
└── segment.yaml   # <- for segmentation

Code Structure

The code is structured as follows:

src
├── models  # <- segmentation model
│   ├── __init__.py
│   ├── latent_shift_model.py  # <- shifts direction in latent space
│   ├── unet_model.py  # <- segmentation model
│   └── unet_parts.py
├── config  # <- configuration, explained above
│   ├── ... 
├── datasets  # <- classes for loading datasets during segmentation/generation
│   ├── __init__.py
│   ├── gan_dataset.py  # <- for generating dataset
│   ├── saved_gan_dataset.py  # <- for pre-generated dataset
│   └── real_dataset.py  # <- for evaluation datasets (i.e. real images)
├── optimization
│   ├── main.py  # <- main script
│   └── utils.py  # <- helper functions
└── segmentation
    ├── generate_and_save.py  # <- for generating a dataset and saving it to disk
    ├── main.py  # <- main script, uses PyTorch Lightning 
    ├── metrics.py  # <- for mIoU/F-score calculations
    └── utils.py  # <- helper functions
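For readers unfamiliar with the metrics named above, the following sketch shows what IoU and F-score measure for binary masks. It is not the code in segmentation/metrics.py: the function names are hypothetical, masks are plain nested lists of 0/1, and beta_sq defaults to 1.0 (plain F1) because this README does not say which weighting the repository uses.

```python
def confusion_counts(pred, target):
    """Count true positives, false positives and false negatives over all pixels."""
    tp = fp = fn = 0
    for pred_row, target_row in zip(pred, target):
        for p, t in zip(pred_row, target_row):
            if p and t:
                tp += 1
            elif p:
                fp += 1
            elif t:
                fn += 1
    return tp, fp, fn


def iou(pred, target):
    """Intersection over union of the foreground regions."""
    tp, fp, fn = confusion_counts(pred, target)
    union = tp + fp + fn
    return tp / union if union else 1.0  # two empty masks agree perfectly


def mean_iou(preds, targets):
    """mIoU: the IoU averaged over a list of (prediction, ground truth) pairs."""
    scores = [iou(p, t) for p, t in zip(preds, targets)]
    return sum(scores) / len(scores)


def f_score(pred, target, beta_sq=1.0):
    """Weighted harmonic mean of precision and recall."""
    tp, fp, fn = confusion_counts(pred, target)
    if tp + fp + fn == 0:
        return 1.0  # two empty masks agree perfectly
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return (1 + beta_sq) * precision * recall / (beta_sq * precision + recall)


pred = [[1, 1, 0], [0, 1, 0]]
target = [[1, 0, 0], [0, 1, 1]]
print(iou(pred, target), f_score(pred, target))
```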

Datasets

The datasets should have the following structure. You can easily add your own datasets or use only a subset of these datasets by modifying config/segment.yaml. You should specify your dataset directory by modifying root in that file (line 19), or by passing data_seg.root=MY_DIR on the command line whenever you call python segmentation/main.py.

├── DUT_OMRON
│   ├── DUT-OMRON-image
│   │   └── ...
│   └── pixelwiseGT-new-PNG
│       └── ...
├── DUTS
│   ├── DUTS-TE
│   │   ├── DUTS-TE-Image
│   │   │   └── ...
│   │   └── DUTS-TE-Mask
│   │       └── ...
│   └── DUTS-TR
│       ├── DUTS-TR-Image
│       │   └── ...
│       └── DUTS-TR-Mask
│           └── ...
├── ECSSD
│   ├── ground_truth_mask
│   │   └── ...
│   └── images
│       └── ...
├── CUB_200_2011
│   ├── train_images
│   │   └── ...
│   ├── train_segmentations
│   │   └── ...
│   ├── test_images
│   │   └── ...
│   └── test_segmentations
│       └── ...
└── Flowers
    ├── train_images
    │   └── ...
    ├── train_segmentations
    │   └── ...
    ├── test_images
    │   └── ...
    └── test_segmentations
        └── ...
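Before launching training or evaluation, it can help to check that the directory layout above is in place. The following is a small sketch, not part of the repository: missing_dataset_dirs is a hypothetical helper, and the list of expected directories is copied from the tree above. If you use only a subset of the datasets, trim the list accordingly.

```python
import tempfile
from pathlib import Path

# Sub-directories expected under the dataset root, copied from the tree above.
EXPECTED_DIRS = [
    "DUT_OMRON/DUT-OMRON-image",
    "DUT_OMRON/pixelwiseGT-new-PNG",
    "DUTS/DUTS-TE/DUTS-TE-Image",
    "DUTS/DUTS-TE/DUTS-TE-Mask",
    "DUTS/DUTS-TR/DUTS-TR-Image",
    "DUTS/DUTS-TR/DUTS-TR-Mask",
    "ECSSD/ground_truth_mask",
    "ECSSD/images",
    "CUB_200_2011/train_images",
    "CUB_200_2011/train_segmentations",
    "CUB_200_2011/test_images",
    "CUB_200_2011/test_segmentations",
    "Flowers/train_images",
    "Flowers/train_segmentations",
    "Flowers/test_images",
    "Flowers/test_segmentations",
]


def missing_dataset_dirs(root):
    """Return the expected sub-directories that do not exist under root."""
    root = Path(root)
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]


# Demonstration on a temporary directory that contains only ECSSD.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "ECSSD" / "images").mkdir(parents=True)
    (Path(tmp) / "ECSSD" / "ground_truth_mask").mkdir(parents=True)
    missing = missing_dataset_dirs(tmp)

print(f"{len(missing)} of {len(EXPECTED_DIRS)} expected directories are missing")
```

To check a real installation, call missing_dataset_dirs with the same path you pass as data_seg.root.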

The datasets can be downloaded from:

Training

Before training, make sure you understand the general approach (explained above).

Note: All commands are called from within the src directory.

In the example commands below, we use BigBiGAN. You can easily switch out BigBiGAN for another model if you would like to.

Optimization

PYTHONPATH=. python optimization/main.py data_gen/generator=bigbigan name=NAME

This should take less than 5 minutes to run. The output will be saved in outputs/optimization/fixed-BigBiGAN-NAME/DATE/, with the final checkpoint in latest.pth.

Segmentation with precomputed generations

The recommended way of training is to generate the data first and train afterward. An example generation script would be:

PYTHONPATH=. python segmentation/generate_and_save.py \
name=NAME \
data_gen=generated \
data_gen/generator=bigbigan \
data_gen.checkpoint="YOUR_OPTIMIZATION_DIR_FROM_ABOVE/latest.pth" \
data_gen.save_dir="YOUR_OUTPUT_DIR" \
data_gen.save_size=1000000 \
data_gen.kwargs.batch_size=1 \
data_gen.kwargs.generation_batch_size=128

This will generate 1 million image-label pairs and save them to YOUR_OUTPUT_DIR/images. Note that YOUR_OUTPUT_DIR should be an absolute path, not a relative one, because Hydra changes the working directory. You may also want to tune the generation_batch_size to maximize GPU utilization on your machine. It takes around 3-4 hours to generate 1 million images on a single V100 GPU.
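The absolute-path requirement and the batch-size setting can be made concrete with a little arithmetic. The sketch below is an illustration only: save_dir_override is a hypothetical helper that builds the command-line override string, and the throughput range is simply derived from the 3-4 hour figure quoted above.

```python
import math
import os


def save_dir_override(path):
    """Build the data_gen.save_dir override with an absolute path."""
    return f'data_gen.save_dir="{os.path.abspath(path)}"'


save_size = 1_000_000
generation_batch_size = 128

# Number of generator forward passes needed to produce save_size images.
num_batches = math.ceil(save_size / generation_batch_size)

# Throughput implied by "3-4 hours for 1 million images" (images per second).
slowest = save_size / (4 * 3600)
fastest = save_size / (3 * 3600)

print(save_dir_override("generated_data"))
print(f"{num_batches} batches, roughly {slowest:.0f}-{fastest:.0f} images/s")
```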

Once you have generated data, you can train a segmentation model:

PYTHONPATH=. python segmentation/main.py \
name=NAME \
data_gen=saved \
data_gen.data.root="YOUR_OUTPUT_DIR_FROM_ABOVE"

It takes around 3 hours on 1 GPU to complete 18000 iterations, by which point the model has converged. You can probably get away with fewer steps; I would guess around 5000.

Segmentation with on-the-fly generations

Alternatively, you can generate data while training the segmentation model. An example script would be:

PYTHONPATH=. python segmentation/main.py \
name=NAME \
data_gen=generated \
data_gen/generator=bigbigan \
data_gen.checkpoint="YOUR_OPTIMIZATION_DIR_FROM_ABOVE/latest.pth" \
data_gen.kwargs.generation_batch_size=128

Evaluation

To evaluate, set the train argument to False. For example:

PYTHONPATH=. python segmentation/main.py \
name="eval" \
train=False \
eval_checkpoint=${checkpoint} \
data_seg.root=${DATASETS_DIR} 

Pretrained models

  • ... are coming soon!

Available GANs

It should be possible to use any GAN from pytorch-pretrained-gans, including:

Citation

@inproceedings{melaskyriazi2021finding,
  author    = {Melas-Kyriazi, Luke and Rupprecht, Christian and Laina, Iro and Vedaldi, Andrea},
  title     = {Finding an Unsupervised Image Segmenter in each of your Deep Generative Models},
  booktitle = arxiv,
  year      = {2021}
}
Comments
  • pip install git+https://github.com/lukemelas/pytorch-pretrained-gans

    Hi, is the repo in the pytorch-pretrained-gans step public or is that the right URL for it? I got prompted for username and password when I tried the pip install git+ and don't see the repo at that URL: https://github.com/lukemelas/pytorch-pretrained-gans (Get 404)

    Thanks.

    opened by ModMorph 2
  • Help producing results with the StyleGAN models

    Hi there!

    I'm having trouble producing meaningful results on StyleGAN2 on AFHQ. I've been using the default setup and hyperparameters. After 50 iterations (with the default batch size of 32) I get visualisations that look initially promising: (https://i.imgur.com/eR79Wyd.png). But as training progresses, and indeed when it reaches 300 iterations, these are the visualisation results: https://i.imgur.com/36zhBzT.png.

    I've tried playing with the learning rate, and the number of iterations with no success yet. Did you have tips here or ideas as to what might be going wrong here?

    Thanks! James.

    opened by james-oldfield 1
  • bug

    Firstly, I ran:

    PYTHONPATH=. python optimization/main.py data_gen/generator=bigbigan name=NAME

    And then, I ran:

    PYTHONPATH=. python segmentation/generate_and_save.py \
    name=NAME \
    data_gen=generated \
    data_gen/generator=bigbigan \
    data_gen.checkpoint="YOUR_OPTIMIZATION_DIR_FROM_ABOVE/latest.pth" \
    data_gen.save_dir="YOUR_OUTPUT_DIR" \
    data_gen.save_size=1000000 \
    data_gen.kwargs.batch_size=1 \
    data_gen.kwargs.generation_batch_size=128

    When I ran:

    PYTHONPATH=. python segmentation/main.py \
    name=NAME \
    data_gen=saved \
    data_gen.data.root="YOUR_OUTPUT_DIR_FROM_ABOVE"

    An error occurred. The error is:

    Traceback (most recent call last):
      File "segmentation/main.py", line 98, in main
        kwargs = dict(images_dir=_cfg.images_dir, labels_dir=_cfg.labels_dir,
    omegaconf.errors.InterpolationResolutionError: KeyError raised while resolving interpolation: "Environment variable '/raid/name/gaochengli/segmentation/src/images' not found"
        full_key: data_seg.data[0].images_dir
        object_type=dict

    According to what you wrote, I modified the root (config/segment.yaml on line 19), like this: "/raid/name/gaochengli/segmentation/src/images". The folder, whose name is images, contains all data sets. I wonder why such a mistake happened.

    opened by Lee-Gao 1
Owner
Luke Melas-Kyriazi
I'm a student at Harvard University studying mathematics and computer science, always open to collaborating on interesting projects!