Labels4Free: Unsupervised Segmentation using StyleGAN

Overview

Labels4Free: Unsupervised Segmentation using StyleGAN

ICCV 2021

image Figure: Some segmentation masks predicted by Labels4Free Framework on real and synthetic images

We propose an unsupervised segmentation framework for StyleGAN generated objects. We build on two main observations. First, the features generated by StyleGAN hold valuable information that can be utilized towards training segmentation networks. Second, the foreground and background can often be treated to be largely independent and be swapped across images to produce plausible composited images. For our solution, we propose to augment the Style-GAN2 generator architecture with a segmentation branch and to split the generator into a foreground and background network. This enables us to generate soft segmentation masks for the foreground object in an unsupervised fashion. On multiple object classes, we report comparable results against state-of-the-art supervised segmentation networks, while against the best unsupervised segmentation approach we demonstrate a clear improvement, both in qualitative and quantitative metrics.

Labels4Free: Unsupervised Segmentation Using StyleGAN (ICCV 2021)
Rameen Abdal, Peihao Zhu, Niloy Mitra, Peter Wonka
KAUST, Adobe Research

[Paper] [Project Page] [Video]

Installation

Clone this repo.

git clone https://github.com/RameenAbdal/Labels4Free.git
cd Labels4Free/

This repo is based on the Pytorch implementation of StyleGAN2 (rosinality/stylegan2-pytorch). Refer to this repo for setting up the environment, preparation of LMDB datasets and downloading pretrained weights of the models.

Download the pretrained weights of Alpha Networks here

Training the models

The models were trained on 4 RTX 2080 (24 GB) GPUs. In order to train the models using the settings in the paper use the following commands for each dataset.

Checkpoints and samples are saved in ./checkpoint and ./sample folders.

FFHQ dataset

python -m torch.distributed.launch --nproc_per_node=4 train.py --size 1024 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [FFHQ_CONFIG-F_CHECKPOINT]--loss_multiplier 1.2 --iter 1200 --trunc 1.0 --lr 0.0002 --reproduce_model

LSUN-Horse dataset

python -m torch.distributed.launch --nproc_per_node=4 train.py --size 256 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_HORSE_CONFIG-F_CHECKPOINT] --loss_multiplier 3 --iter 500 --trunc 1.0 --lr 0.0002 --reproduce_model

LSUN-Cat dataset

python -m torch.distributed.launch --nproc_per_node=4 train.py --size 256 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_CAT_CONFIG-F_CHECKPOINT]  --loss_multiplier 3 --iter 900 --trunc 0.5 --lr 0.0002 --reproduce_model

LSUN-Car dataset

python train.py --size 512 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_CAR_CONFIG-F_CHECKPOINT] --loss_multiplier 10 --iter 50 --trunc 0.3 --lr 0.002 --sat_weight 1.0 --model_save_freq 25 --reproduce_model --use_disc

In order to train your own models using different settings e.g on a single GPU, using different samples, iterations etc. use the following commands.

FFHQ dataset

python train.py --size 1024 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [FFHQ_CONFIG-F_CHECKPOINT] --loss_multiplier 1.2 --iter 2000 --trunc 1.0 --lr 0.0002 --bg_coverage_wt 3 --bg_coverage_value 0.4

LSUN-Horse dataset

python train.py --size 256 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_HORSE_CONFIG-F_CHECKPOINT] --loss_multiplier 3 --iter 2000 --trunc 1.0 --lr 0.0002 --bg_coverage_wt 6 --bg_coverage_value 0.6

LSUN-Cat dataset

python train.py --size 256 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_CAT_CONFIG-F_CHECKPOINT] --loss_multiplier 3 --iter 2000 --trunc 0.5 --lr 0.0002 --bg_coverage_wt 4 --bg_coverage_value 0.35

LSUN-Car dataset

python train.py --size 512 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_CAR_CONFIG-F_CHECKPOINT] --loss_multiplier 20 --iter 750 --trunc 0.3 --lr 0.0008 --sat_weight 0.1 --bg_coverage_wt 40 --bg_coverage_value 0.75 --model_save_freq 50

Sample from the pretrained model

Samples are saved in ./test_sample folder.

python test_sample.py --size [SIZE] --batch 2 --n_sample 100 --ckpt_bg_extractor [ALPHANETWORK_MODEL] --ckpt_generator [GENERATOR_MODEL] --th 0.9

Results on Custom dataset

Folder: Custom dataset, predicted and ground truth masks.

python test_customdata.py --path_gt [GT_Folder] --path_pred [PRED_FOLDER]

Citation

@InProceedings{Abdal_2021_ICCV,
    author    = {Abdal, Rameen and Zhu, Peihao and Mitra, Niloy J. and Wonka, Peter},
    title     = {Labels4Free: Unsupervised Segmentation Using StyleGAN},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {13970-13979}
}

Acknowledgments

This implementation builds upon the Pytorch implementation of StyleGAN2 (rosinality/stylegan2-pytorch). This work was supported by Adobe Research and KAUST Office of Sponsored Research (OSR).

Owner
PhD @ KAUST
Implicit Model Specialization through DAG-based Decentralized Federated Learning

Federated Learning DAG Experiments This repository contains software artifacts to reproduce the experiments presented in the Middleware '21 paper "Imp

Operating Systems and Middleware Group 5 Oct 16, 2022
[ ICCV 2021 Oral ] Our method can estimate camera poses and neural radiance fields jointly when the cameras are initialized at random poses in complex scenarios (outside-in scenes, even with less texture or intense noise )

GNeRF This repository contains official code for the ICCV 2021 paper: GNeRF: GAN-based Neural Radiance Field without Posed Camera. This implementation

Quan Meng 191 Dec 26, 2022
Weakly supervised medical named entity classification

Trove Trove is a research framework for building weakly supervised (bio)medical named entity recognition (NER) and other entity attribute classifiers

60 Nov 18, 2022
MatryODShka: Real-time 6DoF Video View Synthesis using Multi-Sphere Images

Main repo for ECCV 2020 paper MatryODShka: Real-time 6DoF Video View Synthesis using Multi-Sphere Images. visual.cs.brown.edu/matryodshka

Brown University Visual Computing Group 75 Dec 13, 2022
HALO: A Skeleton-Driven Neural Occupancy Representation for Articulated Hands

HALO: A Skeleton-Driven Neural Occupancy Representation for Articulated Hands Oral Presentation, 3DV 2021 Korrawe Karunratanakul, Adrian Spurr, Zicong

Korrawe Karunratanakul 43 Oct 07, 2022
Data Engineering ZoomCamp

Data Engineering ZoomCamp I'm partaking in a Data Engineering Bootcamp / Zoomcamp and will be tracking my progress here. I can't promise these notes w

Aaron 61 Jan 06, 2023
PyTorch Live is an easy to use library of tools for creating on-device ML demos on Android and iOS.

PyTorch Live is an easy to use library of tools for creating on-device ML demos on Android and iOS. With Live, you can build a working mobile app ML demo in minutes.

559 Jan 01, 2023
A PyTorch implementation of Learning to learn by gradient descent by gradient descent

Intro PyTorch implementation of Learning to learn by gradient descent by gradient descent. Run python main.py TODO Initial implementation Toy data LST

Ilya Kostrikov 300 Dec 11, 2022
Implementation of MA-Trace - a general-purpose multi-agent RL algorithm for cooperative environments.

Off-Policy Correction For Multi-Agent Reinforcement Learning This repository is the official implementation of Off-Policy Correction For Multi-Agent R

4 Aug 18, 2022
This is an official implementation for "SimMIM: A Simple Framework for Masked Image Modeling".

SimMIM By Zhenda Xie*, Zheng Zhang*, Yue Cao*, Yutong Lin, Jianmin Bao, Zhuliang Yao, Qi Dai and Han Hu*. This repo is the official implementation of

Microsoft 674 Dec 26, 2022
Hi Guys, here I am providing examples, which will help you in Lerarning Python

LearningPython Hi guys, here I am trying to include as many practice examples of Python Language, as i Myself learn, and hope these will help you in t

4 Feb 03, 2022
PyTorch 1.5 implementation for paper DECOR-GAN: 3D Shape Detailization by Conditional Refinement.

DECOR-GAN PyTorch 1.5 implementation for paper DECOR-GAN: 3D Shape Detailization by Conditional Refinement, Zhiqin Chen, Vladimir G. Kim, Matthew Fish

Zhiqin Chen 72 Dec 31, 2022
Source code for ZePHyR: Zero-shot Pose Hypothesis Rating @ ICRA 2021

ZePHyR: Zero-shot Pose Hypothesis Rating ZePHyR is a zero-shot 6D object pose estimation pipeline. The core is a learned scoring function that compare

R-Pad - Robots Perceiving and Doing 18 Aug 22, 2022
Official repository of my book: "Deep Learning with PyTorch Step-by-Step: A Beginner's Guide"

This is the official repository of my book "Deep Learning with PyTorch Step-by-Step". Here you will find one Jupyter notebook for every chapter in the book.

Daniel Voigt Godoy 340 Jan 01, 2023
This repo implements several applications of the proposed generalized Bures-Wasserstein (GBW) geometry on symmetric positive definite matrices.

GBW This repo implements several applications of the proposed generalized Bures-Wasserstein (GBW) geometry on symmetric positive definite matrices. Ap

Andi Han 0 Oct 22, 2021
PyTorch implementation of our method for adversarial attacks and defenses in hyperspectral image classification.

Self-Attention Context Network for Hyperspectral Image Classification PyTorch implementation of our method for adversarial attacks and defenses in hyp

22 Dec 02, 2022
Tzer: TVM Implementation of "Coverage-Guided Tensor Compiler Fuzzing with Joint IR-Pass Mutation (OOPSLA'22)“.

Artifact • Reproduce Bugs • Quick Start • Installation • Extend Tzer Coverage-Guided Tensor Compiler Fuzzing with Joint IR-Pass Mutation This is the s

12 Dec 29, 2022
Tool which allow you to detect and translate text.

Text detection and recognition This repository contains tool which allow to detect region with text and translate it one by one. Description Two pretr

Damian Panek 176 Nov 28, 2022
Buffon’s needle: one of the oldest problems in geometric probability

Buffon-s-Needle Buffon’s needle is one of the oldest problems in geometric proba

3 Feb 18, 2022
Set of methods to ensemble boxes from different object detection models, including implementation of "Weighted boxes fusion (WBF)" method.

Set of methods to ensemble boxes from different object detection models, including implementation of "Weighted boxes fusion (WBF)" method.

1.4k Jan 05, 2023