Official implementation of the paper Label-Efficient Semantic Segmentation with Diffusion Models

Overview

Label-Efficient Semantic Segmentation with Diffusion Models

Official implementation of the paper Label-Efficient Semantic Segmentation with Diffusion Models

This code is based on datasetGAN and guided-diffusion.

Note: use --recurse-submodules when clone.

 

Overview

The paper investigates the representations learned by the state-of-the-art DDPMs and shows that they capture high-level semantic information valuable for downstream vision tasks. We design a simple segmentation approach that exploits these representations and outperforms the alternatives in the few-shot operating point in the context of semantic segmentation.

DDPM-based Segmentation

 

Dependencies

  • Python >= 3.7
  • Packages: see requirements.txt

 

Datasets

The evaluation is performed on 6 collected datasets with a few annotated images in the training set: Bedroom-18, FFHQ-34, Cat-15, Horse-21, CelebA-19 and ADE-Bedroom-30. The number corresponds to the number of semantic classes.

datasets.tar.gz (~47Mb)

 

DDPM

Pretrained DDPMs

The models trained on LSUN are adopted from guided-diffusion. FFHQ-256 is trained by ourselves using the same model parameters as for the LSUN models.

LSUN-Bedroom: lsun_bedroom.pt
FFHQ-256: ffhq.pt
LSUN-Cat: lsun_cat.pt
LSUN-Horse: lsun_horse.pt

Run

  1. Download the datasets:
      bash datasets/download_datasets.sh
  2. Download the DDPM checkpoint:
       bash checkpoints/ddpm/download_checkpoint.sh
  3. Check paths in experiments/ /ddpm.json
  4. Run: bash scripts/ddpm/train_interpreter.sh

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30

How to improve the performance

  1. Set input_activations=true in experiments/ /ddpm.json .
       In this case, the feature dimension is 18432.
  2. Tune for a particular task what diffusion steps and UNet blocks to use.

 

DatasetDDPM

Synthetic datasets

To download DDPM-produced synthetic datasets (50000 samples, ~7Gb):
bash synthetic-datasets/gan/download_synthetic_dataset.sh

Run | Option #1

  1. Download the synthetic dataset:
       bash synthetic-datasets/ddpm/download_synthetic_dataset.sh
  2. Check paths in experiments/ /datasetDDPM.json
  3. Run: bash scripts/datasetDDPM/train_deeplab.sh

Run | Option #2

  1. Download the datasets:
       bash datasets/download_datasets.sh

  2. Download the DDPM checkpoint:
       bash checkpoints/ddpm/download_checkpoint.sh

  3. Check paths in experiments/ /datasetDDPM.json

  4. Train an interpreter on a few DDPM-produced annotated samples:
       bash scripts/datasetDDPM/train_interpreter.sh

  5. Generate a synthetic dataset:
       bash scripts/datasetDDPM/generate_dataset.sh
        Please specify the hyperparameters in this script for the available resources.
        On 8xA100 80Gb, it takes about 12 hours to generate 10000 samples.

  6. Run: bash scripts/datasetDDPM/train_deeplab.sh
       One needs to specify the path to the generated data. See comments in the script.

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21

 

SwAV

Pretrained SwAVs

We pretrain SwAV models using the official implementation on the LSUN and FFHQ-256 datasets:

LSUN-Bedroom: lsun_bedroom.pth
FFHQ-256: ffhq.pth
LSUN-Cat: lsun_cat.pth
LSUN-Horse: lsun_horse.pth

Training setup:

Dataset epochs batch-size multi-crop num-prototypes
LSUN 200 1792 2x256 + 6x108 1000
FFHQ-256 400 2048 2x224 + 6x96 200

Run

  1. Download the datasets:
       bash datasets/download_datasets.sh
  2. Download the SwAV checkpoint:
       bash checkpoints/swav/download_checkpoint.sh
  3. Check paths in experiments/ /swav.json
  4. Run: bash scripts/swav/train_interpreter.sh

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30

 

DatasetGAN

Opposed to the official implementation, more recent StyleGAN2(-ADA) models are used.

Synthetic datasets

To download GAN-produced synthetic datasets (50000 samples):

bash synthetic-datasets/gan/download_synthetic_dataset.sh

Run

Since we almost fully adopt the official implementation, we don't provide our reimplementation here. However, one can still reproduce our results:

  1. Download the synthetic dataset:
      bash synthetic-datasets/gan/download_synthetic_dataset.sh
  2. Change paths in experiments/ /datasetDDPM.json
  3. Change paths and run: bash scripts/datasetDDPM/train_deeplab.sh

Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21

 

Results

  • Performance in terms of mean IoU:
Method Bedroom-28 FFHQ-34 Cat-15 Horse-21 CelebA-19 ADE-Bedroom-30
ALAE 20.0 ± 1.0 48.1 ± 1.3 -- -- 49.7 ± 0.7 15.0 ± 0.5
VDVAE -- 57.3 ± 1.1 -- -- 54.1 ± 1.0 --
GAN Inversion 13.9 ± 0.6 51.7 ± 0.8 21.4 ± 1.7 17.7 ± 0.4 51.5 ± 2.3 11.1 ± 0.2
GAN Encoder 22.4 ± 1.6 53.9 ± 1.3 32.0 ± 1.8 26.7 ± 0.7 53.9 ± 0.8 15.7 ± 0.3
SwAV 41.0 ± 2.3 54.7 ± 1.4 44.1 ± 2.1 51.7 ± 0.5 53.2 ± 1.0 30.3 ± 1.5
DatasetGAN 31.3 ± 2.7 57.0 ± 1.0 36.5 ± 2.3 45.4 ± 1.4 -- --
DatasetDDPM 46.9 ± 2.8 56.0 ± 0.9 45.4 ± 2.8 60.4 ± 1.2 -- --
DDPM 46.1 ± 1.9 57.0 ± 1.4 52.3 ± 3.0 63.1 ± 0.9 57.0 ± 1.0 32.3 ± 1.5

 

  • Examples of segmentation masks predicted by the DDPM-based method:
DDPM-based Segmentation

 

Cite

@misc{baranchuk2021labelefficient,
      title={Label-Efficient Semantic Segmentation with Diffusion Models}, 
      author={Dmitry Baranchuk and Ivan Rubachev and Andrey Voynov and Valentin Khrulkov and Artem Babenko},
      year={2021},
      eprint={2112.03126},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
Owner
Yandex Research
Yandex Research
Official repository for MixFaceNets: Extremely Efficient Face Recognition Networks

MixFaceNets This is the official repository of the paper: MixFaceNets: Extremely Efficient Face Recognition Networks. (Accepted in IJCB2021) https://i

Fadi Boutros 51 Dec 13, 2022
CondenseNet V2: Sparse Feature Reactivation for Deep Networks

CondenseNetV2 This repository is the official Pytorch implementation for "CondenseNet V2: Sparse Feature Reactivation for Deep Networks" paper by Le Y

Haojun Jiang 74 Dec 12, 2022
PyTorch Implementations for DeeplabV3 and PSPNet

Pytorch-segmentation-toolbox DOC Pytorch code for semantic segmentation. This is a minimal code to run PSPnet and Deeplabv3 on Cityscape dataset. Shor

Zilong Huang 746 Dec 15, 2022
Code for Ditto: Building Digital Twins of Articulated Objects from Interaction

Ditto: Building Digital Twins of Articulated Objects from Interaction Zhenyu Jiang, Cheng-Chun Hsu, Yuke Zhu CVPR 2022, Oral Project | arxiv News 2022

UT Robot Perception and Learning Lab 78 Dec 22, 2022
This is a five-step framework for the development of intrusion detection systems (IDS) using machine learning (ML) considering model realization, and performance evaluation.

AB-TRAP: building invisibility shields to protect network devices The AB-TRAP framework is applicable to the development of Network Intrusion Detectio

Lab-C2DC - Laboratory of Command and Control and Cyber-security 17 Jan 04, 2023
Implementation for "Manga Filling Style Conversion with Screentone Variational Autoencoder" (SIGGRAPH ASIA 2020 issue)

Manga Filling with ScreenVAE SIGGRAPH ASIA 2020 | Project Website | BibTex This repository is for ScreenVAE introduced in the following paper "Manga F

30 Dec 24, 2022
Official pytorch code for SSAT: A Symmetric Semantic-Aware Transformer Network for Makeup Transfer and Removal

SSAT: A Symmetric Semantic-Aware Transformer Network for Makeup Transfer and Removal This is the official pytorch code for SSAT: A Symmetric Semantic-

ForeverPupil 57 Dec 13, 2022
Instance-wise Occlusion and Depth Orders in Natural Scenes (CVPR 2022)

Instance-wise Occlusion and Depth Orders in Natural Scenes Official source code. Appears at CVPR 2022 This repository provides a new dataset, named In

27 Dec 27, 2022
Model of an AI powered sign language interpreter.

TEXT AND SPEECH TO SIGN LANGUAGE. A web application which takes in text or live audio speech recording as input, converts and displays the relevant Si

Mark Gatere 4 Mar 30, 2022
Qt-GUI implementation of the YOLOv5 algorithm (ver.6 and ver.5)

YOLOv5-GUI 🎉 YOLOv5算法(ver.6及ver.5)的Qt-GUI实现 🎉 Qt-GUI implementation of the YOLOv5 algorithm (ver.6 and ver.5). 基于YOLOv5的v5版本和v6版本及Javacr大佬的UI逻辑进行编写

EricFang 12 Dec 28, 2022
The implementation of CVPR2021 paper Temporal Query Networks for Fine-grained Video Understanding, by Chuhan Zhang, Ankush Gupta and Andrew Zisserman.

Temporal Query Networks for Fine-grained Video Understanding 📋 This repository contains the implementation of CVPR2021 paper Temporal_Query_Networks

55 Dec 21, 2022
Unofficial implementation of Alias-Free Generative Adversarial Networks. (https://arxiv.org/abs/2106.12423) in PyTorch

alias-free-gan-pytorch Unofficial implementation of Alias-Free Generative Adversarial Networks. (https://arxiv.org/abs/2106.12423) This implementation

Kim Seonghyeon 502 Jan 03, 2023
Code to reproduce the results for Statistically Robust Neural Network Classification, published in UAI 2021

Code to reproduce the results for Statistically Robust Neural Network Classification, published in UAI 2021

1 Jun 02, 2022
Unsupervised Image Generation with Infinite Generative Adversarial Networks

Unsupervised Image Generation with Infinite Generative Adversarial Networks Here is the implementation of MICGANs using DCGAN architecture on MNIST da

16 Dec 24, 2021
Semi-supervised Stance Detection of Tweets Via Distant Network Supervision

SANDS This is an annonymous repository containing code and data necessary to reproduce the results published in "Semi-supervised Stance Detection of T

2 Sep 22, 2022
Dogs classification with Deep Metric Learning using some popular losses

Tsinghua Dogs classification with Deep Metric Learning 1. Introduction Tsinghua Dogs dataset Tsinghua Dogs is a fine-grained classification dataset fo

QuocThangNguyen 45 Nov 09, 2022
Simple-Neural-Network From Scratch in Python

Simple-Neural-Network From Scratch in Python This is a simple Neural Network created without any Machine Learning Libraries. The only dependencies are

Aum Shah 1 Dec 28, 2021
SeisComP/SeisBench interface to enable deep-learning (re)picking in SeisComP

scdlpicker SeisComP/SeisBench interface to enable deep-learning (re)picking in SeisComP Objective This is a simple deep learning (DL) repicker module

Joachim Saul 6 May 13, 2022
BookMyShowPC - Movie Ticket Reservation App made with Tkinter

Book My Show PC What is this? Movie Ticket Reservation App made with Tkinter. Tk

The Nithin Balaji 3 Dec 09, 2022
Official Keras Implementation for UNet++ in IEEE Transactions on Medical Imaging and DLMIA 2018

UNet++: A Nested U-Net Architecture for Medical Image Segmentation UNet++ is a new general purpose image segmentation architecture for more accurate i

Zongwei Zhou 1.8k Jan 07, 2023