A minimal yet resourceful implementation of diffusion models (along with pretrained models + synthetic images for nine datasets)

Overview

Minimal implementation of diffusion models

A minimal implementation of diffusion models with the goal to democratize the use of synthetic data from these models.

Check out the experimental results section for quantitative numbers on quality of synthetic data and FAQs for a broader discussion. We experiments with nine commonly used datasets, and released all assets, including models and synthetic data for each of them.

Requirements: pip install scipy opencv-python. We assume torch and torchvision are already installed.

Structure

main.py  - Train or sample from a diffusion model.
unets.py - UNet based network architecture for diffusion model.
data.py  - Common datasets and their metadata.
──  scripts
     └── train.sh  - Training scripts for all datasets.
     └── sample.sh - Sampling scripts for all datasets.

Training

Use the following command to train the diffusion model on four gpus.

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py \
  --arch UNet --dataset cifar10 --class-cond --epochs 500

We provide the exact script used for training in ./scripts/train.sh.

Sampling

We reuse main.py for sampling but with the --sampling-only only flag. Use the following command to sample 50K images from a pretrained diffusion model.

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py \
  --arch UNet --dataset cifar10 --class-cond --sampling-only --sampling-steps 250 \
  --num-sampled-images 50000 --pretrained-ckpt path_to_pretrained_model

We provide the exact script used for sampling in ./scripts/sample.sh.

How useful is synthetic data from diffusion models? 🤔

Takeaway: Across all datasets, training only on synthetic data suffice to achieve a competitive classification score on real data.

Goal: Our goal is to not only measure photo-realism of synthetic images but also measure how well synthetic images cover the data distribution, i.e., how diverse is synthetic data. Note that a generative model, commonly GANs, can generate high-quality images, but still fail to generate diverse images.

Choice of datasets: We use nine commonly used datasets in image recognition. The goal was to multiple datasets was to capture enough diversity in terms of the number of samples, the number of classes, and coarse vs fine-grained classification. In addition, by using a common setup across datasets, we can test the success of diffusion models without any assumptions about the dataset.

Diffusion model: For each dataset, we train a class-conditional diffusion model. We choose a modest size network and train it for a limited number of hours on a 4xA4000 cluster, as highlighted by the training time in the table below. Next, we sample 50,000 synthetic images from the diffusion model.

Metric to measure synthetic data quality: We train a ResNet50 classifier on only real images and another one on only synthetic images and measure their accuracy on the validation set of real images. This metric is also referred to as classification accuracy score and it provides us a way to measure both quality and diversity of synthetic data in a unified manner across datasets.

Released assets for each dataset: Pre-trained Diffusion models, 50,000 synthetic images for each dataset, and downstream clasifiers trained with real-only or synthetic-only dataset.

Table 1: Training images and classes refer to the number of training images and the number of classes in the dataset. Training time refers to the time taken to train the diffusion model. Real only is the test set accuracy of ResNet-50 model trained on only real training images. Synthetic accuracy is the test accuracy of the ResNet-50 model trained on only 50K synthetic images.

Dataset Training images Classes Training time (hours) Real only Synthetic only
MNIST 60,000 10 2.1 99.6 99.0
MNIST-M 60,000 10 5.3 99.3 97.3
CIFAR-10 50,000 10 10.7 93.8 87.3
Skin Cancer* 33126 2 19.1 69.7 64.1
AFHQ 14630 3 8.6 97.9 98.7
CelebA 109036 4 12.8 90.1 88.9
Standford Cars 8144 196 7.4 33.7 76.6
Oxford Flowers 2040 102 6.0 29.3 76.3
Traffic signs 39252 43 8.3 96.6 96.1

* Due to heavy class imbalance, we use AUROC to measure classification performance.

Note: Except for CIFAR10, MNIST, MNIST-M, and GTSRB, we use 64x64 image resolution for all datasets. The key reason to use a lower resolution was to reduce the computational resources needed to train the diffusion model.

Discussion: Across most datasets training only on synthetic data achieves competitive performance with training on real data. It shows that the synthetic data 1) has high-quality images, otherwise the model wouldn't have learned much from it 2) high coverage of distribution, otherwise, the model trained on synthetic data won't do well on the whole test set. Even more, the synthetic dataset has a unique advantage: we can easily generate a very large amount of it. This difference is clearly visible for the low-data regime (flowers and cars dataset), where training on synthetic data (50K images) achieves much better performance than training on real data, which has less than 10K images. A more principled investigation of sample complexity, i.e., performance vs number-of-synthetic-images is available in one of my previous papers (fig. 9).

FAQs

Q. Why use diffusion models?
A. This question is super broad and has multiple answers. 1) They are super easy to train. Unlike GANs, there are no training instabilities in the optimization process. 2) The mode coverage of the diffusion models is excellent where at the same time the generated images are quite photorealistic. 3) The training pipeline is also consistent across datasets, i.e., no assumption about the data. For all datasets above, the only parameter we changed was the amount of training time.

Q. Is synthetic data from diffusion models much different from other generative models, in particular GANs?
A. As mentioned in the previous answer, synthetic data from diffusion models have much higher coverage than GANs, while having a similar image quality. Check out the this previous paper by Prafulla Dhariwal and Alex Nichol where they provide extensive results supporting this claim. In the regime of robust training, you can find a more quantitive comparison of diffusion models with multiple GANs in one of my previous papers.

Q. Why classification accuracy on some datasets is so low (e.g., flowers), even when training with real data?
A. Due to many reasons, current classification numbers aren't meant to be competitive with state-of-the-art. 1) We don't tune any hyperparameters across datasets. For each dataset, we train a ResNet50 model with 0.1 learning rate, 1e-4 weight decay, 0.9 momentum, and cosine learning rate decay. 2) Instead of full resolution (commonly 224x224), we use low-resolution images (64x64), which makes classification harder.

Q. Using only synthetic data, how to further improve the test accuracy on real data?
A. Diffusion models benefit tremendously from scaling of the training setup. One can do so by increasing the network width (base_width) and training the network for more epochs (2-4x).

References

This implementation was originally motivated by the original implmentation of diffusion models by Jonathan Ho. I followed the recent PyTorch implementation by OpenAI for common design choices in diffusion models.

The experiments to test out the potential of synthetic data from diffusion models are inspired by one of my previous work. We found that using synthetic data from the diffusion model alone surpasses benefits from multiple algorithmic innovations in robust training, which is one of the simple yet extremely hard problems to solve for neural networks. The next step is to repeat the Table-1 experiments, but this time with robust training.

Visualizing real and synthetic images

For each data, we plot real images on the left and synthetic images on the right. Each row corresponds to a unique class while classes for real and synthetic data are identical.

Light     Dark

MNIST

Light     Dark

MNIST-M

Light     Dark

CIFAR-10

Light     Dark

GTSRB

Light     Dark

Celeb-A

Light     Dark

AFHQ

Light     Dark

Cars

Light     Dark

Flowers

Light     Dark

Melanoma (Skin cancer)

Light     Dark

Note: Real images for each dataset follow the same license as their respective dataset.

Owner
Vikash Sehwag
PhD candidate at Princeton University. Interested in problems at the intersection of Security, Privacy, and Machine leanring.
Vikash Sehwag
Code for the paper: Fighting Fake News: Image Splice Detection via Learned Self-Consistency

Fighting Fake News: Image Splice Detection via Learned Self-Consistency [paper] [website] Minyoung Huh *12, Andrew Liu *1, Andrew Owens1, Alexei A. Ef

minyoung huh (jacob) 174 Dec 09, 2022
Metric learning algorithms in Python

metric-learn: Metric Learning in Python metric-learn contains efficient Python implementations of several popular supervised and weakly-supervised met

1.3k Dec 28, 2022
3D position tracking for soccer players with multi-camera videos

This repo contains a full pipeline to support 3D position tracking of soccer players, with multi-view calibrated moving/fixed video sequences as inputs.

Yuchang Jiang 72 Dec 27, 2022
Vertical Federated Principal Component Analysis and Its Kernel Extension on Feature-wise Distributed Data based on Pytorch Framework

VFedPCA+VFedAKPCA This is the official source code for the Paper: Vertical Federated Principal Component Analysis and Its Kernel Extension on Feature-

John 9 Sep 18, 2022
The Illinois repository for Climatehack (https://climatehack.ai/). We won 1st place!

Climatehack This is the repository for Illinois's Climatehack Team. We earned first place on the leaderboard with a final score of 0.87992. An overvie

Jatin Mathur 20 Jun 09, 2022
A model that attempts to learn and benefit from data collected on card counting.

A model that attempts to learn and benefit from data collected on card counting. A decision tree like model is built to win more often than loose and increase the bet of the player appropriately to c

1 Dec 17, 2021
A repo for Causal Imitation Learning under Temporally Correlated Noise

CausIL A repo for Causal Imitation Learning under Temporally Correlated Noise. Running Experiments To re-train an expert, run: python experts/train_ex

Gokul Swamy 5 Nov 01, 2022
a reimplementation of UnFlow in PyTorch that matches the official TensorFlow version

pytorch-unflow This is a personal reimplementation of UnFlow [1] using PyTorch. Should you be making use of this work, please cite the paper according

Simon Niklaus 134 Nov 20, 2022
Deep Multi-Magnification Network for multi-class tissue segmentation of whole slide images

Deep Multi-Magnification Network This repository provides training and inference codes for Deep Multi-Magnification Network published here. Deep Multi

Computational Pathology 12 Aug 06, 2022
This repo contains the pytorch implementation for Dynamic Concept Learner (accepted by ICLR 2021).

DCL-PyTorch Pytorch implementation for the Dynamic Concept Learner (DCL). More details can be found at the project page. Framework Grounding Physical

Zhenfang Chen 31 Jan 06, 2023
Text Generation by Learning from Demonstrations

Text Generation by Learning from Demonstrations The README was last updated on March 7, 2021. The repo is based on fairseq (v0.9.?). Paper arXiv Prere

38 Oct 21, 2022
mmdetection version of TinyBenchmark.

introduction This project is an mmdetection version of TinyBenchmark. TODO list: add TinyPerson dataset and evaluation add crop and merge for image du

34 Aug 27, 2022
A simplified framework and utilities for PyTorch

Here is Poutyne. Poutyne is a simplified framework for PyTorch and handles much of the boilerplating code needed to train neural networks. Use Poutyne

GRAAL/GRAIL 534 Dec 17, 2022
Official PyTorch code for CVPR 2020 paper "Deep Active Learning for Biased Datasets via Fisher Kernel Self-Supervision"

Deep Active Learning for Biased Datasets via Fisher Kernel Self-Supervision https://arxiv.org/abs/2003.00393 Abstract Active learning (AL) aims to min

Denis 29 Nov 21, 2022
Tensorflow implementation of ID-Unet: Iterative Soft and Hard Deformation for View Synthesis.

ID-Unet: Iterative-view-synthesis(CVPR2021 Oral) Tensorflow implementation of ID-Unet: Iterative Soft and Hard Deformation for View Synthesis. Overvie

17 Aug 23, 2022
Offline Reinforcement Learning with Implicit Q-Learning

Offline Reinforcement Learning with Implicit Q-Learning This repository contains the official implementation of Offline Reinforcement Learning with Im

Ilya Kostrikov 126 Jan 06, 2023
Code for CPM-2 Pre-Train

CPM-2 Pre-Train Pre-train CPM-2 此分支为110亿非 MoE 模型的预训练代码,MoE 模型的预训练代码请切换到 moe 分支 CPM-2技术报告请参考link。 0 模型下载 请在智源资源下载页面进行申请,文件介绍如下: 文件名 描述 参数大小 100000.tar

Tsinghua AI 136 Dec 28, 2022
Geometry-Aware Learning of Maps for Camera Localization (CVPR2018)

Geometry-Aware Learning of Maps for Camera Localization This is the PyTorch implementation of our CVPR 2018 paper "Geometry-Aware Learning of Maps for

NVIDIA Research Projects 321 Nov 26, 2022
Customizable RecSys Simulator for OpenAI Gym

gym-recsys: Customizable RecSys Simulator for OpenAI Gym Installation | How to use | Examples | Citation This package describes an OpenAI Gym interfac

Xingdong Zuo 14 Dec 08, 2022
Convert Python 3 code to CUDA code.

Py2CUDA Convert python code to CUDA. Usage To convert a python file say named py_file.py to CUDA, run python generate_cuda.py --file py_file.py --arch

Yuval Rosen 3 Jul 14, 2021