A minimal yet resourceful implementation of diffusion models (along with pretrained models + synthetic images for nine datasets)

Overview

Minimal implementation of diffusion models

A minimal implementation of diffusion models with the goal to democratize the use of synthetic data from these models.

Check out the experimental results section for quantitative numbers on the quality of synthetic data and the FAQs for a broader discussion. We experimented with nine commonly used datasets and released all assets, including models and synthetic data, for each of them.

Requirements: pip install scipy opencv-python. We assume torch and torchvision are already installed.

Structure

main.py  - Train or sample from a diffusion model.
unets.py - UNet based network architecture for diffusion model.
data.py  - Common datasets and their metadata.
──  scripts
     └── train.sh  - Training scripts for all datasets.
     └── sample.sh - Sampling scripts for all datasets.

Training

Use the following command to train the diffusion model on four GPUs.

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py \
  --arch UNet --dataset cifar10 --class-cond --epochs 500

We provide the exact script used for training in ./scripts/train.sh.

Sampling

We reuse main.py for sampling, but with the --sampling-only flag. Use the following command to sample 50K images from a pretrained diffusion model.

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py \
  --arch UNet --dataset cifar10 --class-cond --sampling-only --sampling-steps 250 \
  --num-sampled-images 50000 --pretrained-ckpt path_to_pretrained_model

We provide the exact script used for sampling in ./scripts/sample.sh.

How useful is synthetic data from diffusion models? 🤔

Takeaway: Across all datasets, training only on synthetic data suffices to achieve a competitive classification score on real data.

Goal: Our goal is to measure not only the photo-realism of synthetic images but also how well they cover the data distribution, i.e., how diverse the synthetic data is. Note that a generative model, commonly a GAN, can generate high-quality images but still fail to generate diverse images.

Choice of datasets: We use nine commonly used datasets in image recognition. The goal of using multiple datasets was to capture enough diversity in terms of the number of samples, the number of classes, and coarse vs. fine-grained classification. In addition, by using a common setup across datasets, we can test the success of diffusion models without any assumptions about the dataset.

Diffusion model: For each dataset, we train a class-conditional diffusion model. We choose a modest size network and train it for a limited number of hours on a 4xA4000 cluster, as highlighted by the training time in the table below. Next, we sample 50,000 synthetic images from the diffusion model.
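For readers new to diffusion models, the forward (noising) process that such a model learns to invert can be sketched as below. This is a generic DDPM-style illustration, not code from main.py: the linear beta schedule, the 1000 steps, and all function names are assumptions made for the example.

```python
import math
import random


def linear_beta_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances beta_1..beta_T (assumed schedule)."""
    if num_steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cumulative_alphas(betas):
    """alpha_bar_t = product over s <= t of (1 - beta_s)."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def add_noise(x0, t, alpha_bars, rng):
    """Draw x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    a = alpha_bars[t]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    xt = [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, eps)]
    return xt, eps


betas = linear_beta_schedule(1000)
alpha_bars = cumulative_alphas(betas)

rng = random.Random(0)
x0 = [0.5, -0.25, 1.0]  # hypothetical flattened "image"
xt, eps = add_noise(x0, t=999, alpha_bars=alpha_bars, rng=rng)
```

During training, a network is typically asked to predict eps from xt, the step t, and (for a class-conditional model) the class label. At the last step alpha_bar is close to zero, so xt is almost pure noise.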

Metric to measure synthetic data quality: We train a ResNet50 classifier on only real images and another one on only synthetic images and measure their accuracy on the validation set of real images. This metric is also referred to as classification accuracy score and it provides us a way to measure both quality and diversity of synthetic data in a unified manner across datasets.
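The protocol above can be sketched in a few lines. To keep the example self-contained, a toy nearest-centroid classifier stands in for ResNet50 and the data is made up; only the structure (train on one source, evaluate on real validation data) mirrors the metric described here.

```python
def fit_centroids(features, labels):
    """Toy stand-in for the classifier: one mean vector per class."""
    sums, counts = {}, {}
    for x, y in zip(features, labels):
        acc = sums.setdefault(y, [0.0] * len(x))
        for i, v in enumerate(x):
            acc[i] += v
        counts[y] = counts.get(y, 0) + 1
    return {y: [v / counts[y] for v in acc] for y, acc in sums.items()}


def predict(centroids, x):
    """Return the class whose centroid is closest to x (squared L2 distance)."""
    return min(
        centroids,
        key=lambda y: sum((a - b) ** 2 for a, b in zip(centroids[y], x)),
    )


def accuracy(centroids, features, labels):
    """Top-1 accuracy in percent."""
    correct = sum(predict(centroids, x) == y for x, y in zip(features, labels))
    return 100.0 * correct / len(labels)


# Hypothetical toy data: 2-D features, two classes.
real_train = ([[0.0, 0.1], [0.2, 0.0], [1.0, 0.9], [0.9, 1.1]], [0, 0, 1, 1])
synthetic_train = ([[0.1, 0.0], [0.0, 0.2], [1.1, 1.0], [0.8, 0.9]], [0, 0, 1, 1])
real_val = ([[0.1, 0.1], [0.9, 1.0]], [0, 1])

# Both classifiers are always evaluated on real validation data.
real_only = accuracy(fit_centroids(*real_train), *real_val)
synthetic_only = accuracy(fit_centroids(*synthetic_train), *real_val)
```

The synthetic-only number is the classification accuracy score; comparing it with the real-only number shows how much is lost (or gained) by replacing real training data with synthetic data.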

Released assets for each dataset: Pre-trained diffusion models, 50,000 synthetic images, and downstream classifiers trained on real-only or synthetic-only data.

Table 1: Training images and Classes refer to the number of training images and the number of classes in the dataset. Training time refers to the time taken to train the diffusion model. Real only is the test set accuracy of a ResNet-50 model trained on only real training images. Synthetic only is the test set accuracy of a ResNet-50 model trained on only 50K synthetic images.

| Dataset | Training images | Classes | Training time (hours) | Real only | Synthetic only |
|---|---|---|---|---|---|
| MNIST | 60,000 | 10 | 2.1 | 99.6 | 99.0 |
| MNIST-M | 60,000 | 10 | 5.3 | 99.3 | 97.3 |
| CIFAR-10 | 50,000 | 10 | 10.7 | 93.8 | 87.3 |
| Skin Cancer* | 33,126 | 2 | 19.1 | 69.7 | 64.1 |
| AFHQ | 14,630 | 3 | 8.6 | 97.9 | 98.7 |
| CelebA | 109,036 | 4 | 12.8 | 90.1 | 88.9 |
| Stanford Cars | 8,144 | 196 | 7.4 | 33.7 | 76.6 |
| Oxford Flowers | 2,040 | 102 | 6.0 | 29.3 | 76.3 |
| Traffic signs | 39,252 | 43 | 8.3 | 96.6 | 96.1 |

* Due to heavy class imbalance, we use AUROC to measure classification performance.
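To see why AUROC is preferred under heavy class imbalance, here is a minimal sketch that computes it from its pairwise-ranking definition. The function name and the toy scores are illustrative assumptions and do not come from the repository; the quadratic loop is fine for a demonstration but slow for large datasets.

```python
def auroc(scores, labels):
    """Probability that a random positive outscores a random negative.

    Ties between a positive and a negative count as half a win.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Hypothetical imbalanced toy set: 1 positive, 9 negatives.
labels = [1] + [0] * 9

# A classifier that gives every example the same score would be 90% accurate
# when thresholded to "negative", yet carries no ranking information.
constant_scores = [0.0] * 10
chance_level = auroc(constant_scores, labels)

# A classifier that ranks the single positive above every negative.
ranked_scores = [0.9] + [0.1 * i for i in range(9)]
perfect = auroc(ranked_scores, labels)
```

Here chance_level is 0.5 even though plain accuracy would look high, which is exactly the failure mode that makes accuracy misleading on imbalanced data.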

Note: Except for CIFAR10, MNIST, MNIST-M, and GTSRB, we use 64x64 image resolution for all datasets. The key reason to use a lower resolution was to reduce the computational resources needed to train the diffusion model.

Discussion: Across most datasets, training only on synthetic data achieves performance competitive with training on real data. This shows that the synthetic data 1) contains high-quality images, otherwise the model wouldn't have learned much from it, and 2) covers the distribution well, otherwise the model trained on synthetic data wouldn't do well on the whole test set. Moreover, synthetic data has a unique advantage: we can easily generate a very large amount of it. This advantage is clearly visible in the low-data regime (flowers and cars datasets), where training on synthetic data (50K images) achieves much better performance than training on real data, which has fewer than 10K images. A more principled investigation of sample complexity, i.e., performance vs. number of synthetic images, is available in one of my previous papers (fig. 9).
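The comparison in this discussion can be reproduced directly from the numbers in Table 1. The snippet below only re-tabulates those published values; the variable and function names are illustrative. Note that the Skin Cancer row reports AUROC rather than accuracy.

```python
# (dataset, training images, real-only score, synthetic-only score), from Table 1.
TABLE_1 = [
    ("MNIST", 60000, 99.6, 99.0),
    ("MNIST-M", 60000, 99.3, 97.3),
    ("CIFAR-10", 50000, 93.8, 87.3),
    ("Skin Cancer", 33126, 69.7, 64.1),  # AUROC, not accuracy
    ("AFHQ", 14630, 97.9, 98.7),
    ("CelebA", 109036, 90.1, 88.9),
    ("Stanford Cars", 8144, 33.7, 76.6),
    ("Oxford Flowers", 2040, 29.3, 76.3),
    ("Traffic signs", 39252, 96.6, 96.1),
]


def synthetic_minus_real(rows):
    """Gap in points between synthetic-only and real-only training."""
    return {name: round(syn - real, 1) for name, _, real, syn in rows}


gap = synthetic_minus_real(TABLE_1)

# Datasets with fewer than 10K real training images (the low-data regime).
low_data = [name for name, n_train, _, _ in TABLE_1 if n_train < 10000]

for name, delta in sorted(gap.items(), key=lambda kv: kv[1]):
    print(f"{name:15s} {delta:+.1f}")
```

The two low-data datasets are the ones with by far the largest positive gap, while the largest deficits are on CIFAR-10 and Skin Cancer.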

FAQs

Q. Why use diffusion models?
A. This question is very broad and has multiple answers. 1) They are easy to train. Unlike GANs, there are no training instabilities in the optimization process. 2) The mode coverage of diffusion models is excellent, while the generated images are still quite photorealistic. 3) The training pipeline is consistent across datasets, i.e., it makes no assumptions about the data. For all datasets above, the only parameter we changed was the amount of training time.

Q. Is synthetic data from diffusion models much different from other generative models, in particular GANs?
A. As mentioned in the previous answer, synthetic data from diffusion models has much higher coverage than data from GANs, while having similar image quality. Check out the previous paper by Prafulla Dhariwal and Alex Nichol, where they provide extensive results supporting this claim. In the regime of robust training, you can find a more quantitative comparison of diffusion models with multiple GANs in one of my previous papers.

Q. Why is classification accuracy on some datasets so low (e.g., flowers), even when training with real data?
A. For several reasons, the current classification numbers aren't meant to be competitive with the state of the art. 1) We don't tune any hyperparameters across datasets. For each dataset, we train a ResNet50 model with 0.1 learning rate, 1e-4 weight decay, 0.9 momentum, and cosine learning rate decay. 2) Instead of full resolution (commonly 224x224), we use low-resolution images (64x64), which makes classification harder.
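As an illustration of the cosine learning rate decay mentioned above, here is a minimal sketch starting from the stated base rate of 0.1. The number of epochs, the minimum learning rate of zero, and the function name are assumptions; the classifier training script may differ in details such as warmup or per-iteration updates.

```python
import math


def cosine_lr(step, total_steps, base_lr=0.1, min_lr=0.0):
    """Cosine-annealed learning rate at `step`, for step in [0, total_steps]."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Hypothetical 100-epoch run, one value per epoch boundary.
schedule = [cosine_lr(epoch, 100) for epoch in range(101)]
```

The rate starts at 0.1, passes through 0.05 at the halfway point, and decays smoothly to zero. In PyTorch, a comparable setup would pair torch.optim.SGD (lr=0.1, momentum=0.9, weight_decay=1e-4) with torch.optim.lr_scheduler.CosineAnnealingLR.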

Q. Using only synthetic data, how can one further improve the test accuracy on real data?
A. Diffusion models benefit tremendously from scaling of the training setup. One can do so by increasing the network width (base_width) and training the network for more epochs (2-4x).

References

This implementation was motivated by the original implementation of diffusion models by Jonathan Ho. I followed the recent PyTorch implementation by OpenAI for common design choices in diffusion models.

The experiments to test the potential of synthetic data from diffusion models are inspired by one of my previous works. We found that using synthetic data from the diffusion model alone surpasses the benefits of multiple algorithmic innovations in robust training, which is a simple-to-state yet extremely hard problem for neural networks. The next step is to repeat the Table 1 experiments, but this time with robust training.

Visualizing real and synthetic images

For each dataset, we plot real images on the left and synthetic images on the right. Each row corresponds to a unique class, and the classes for real and synthetic data are identical.

[Figures: side-by-side grids of real (left) and synthetic (right) images for MNIST, MNIST-M, CIFAR-10, GTSRB, Celeb-A, AFHQ, Cars, Flowers, and Melanoma (Skin cancer).]

Note: Real images for each dataset follow the same license as their respective dataset.

Owner
Vikash Sehwag
PhD candidate at Princeton University. Interested in problems at the intersection of Security, Privacy, and Machine learning.