PyTorch reimplementation of Diffusion Models

Overview


A PyTorch reimplementation of Denoising Diffusion Probabilistic Models with checkpoints converted from the author's TensorFlow implementation.

Quickstart

Running

pip install -e git+https://github.com/pesser/pytorch_diffusion.git#egg=pytorch_diffusion
pytorch_diffusion_demo

will start a Streamlit demo. It is recommended to run the demo with a GPU available.


Usage

Diffusion models with pretrained weights for cifar10, lsun_bedroom, lsun_cat, or lsun_church can be loaded as follows:

from pytorch_diffusion import Diffusion

diffusion = Diffusion.from_pretrained("lsun_church")
samples = diffusion.denoise(4)
diffusion.save(samples, "lsun_church_sample_{:02}.png")

Prefix the name with ema_ to load the exponentially moving averaged (EMA) weights, which produce better samples. The U-Net model used for denoising is available via diffusion.model and can also be instantiated on its own:

from pytorch_diffusion import Model

model = Model(resolution=32,
              in_channels=3,
              out_ch=3,
              ch=128,
              ch_mult=(1,2,2,2),
              num_res_blocks=2,
              attn_resolutions=(16,),
              dropout=0.1)

This configuration example corresponds to the model used on CIFAR-10.
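For orientation, here is a minimal sketch of calling the standalone model, assuming the usual DDPM interface in which the U-Net receives a noisy image batch together with one timestep per sample and predicts the added noise:

import torch

# Continuing from the Model(...) instantiation above (CIFAR-10 configuration).
x = torch.randn(4, 3, 32, 32)     # batch of noisy 32x32 RGB images
t = torch.randint(0, 1000, (4,))  # one diffusion timestep per sample
with torch.no_grad():
    eps = model(x, t)             # predicted noise, same shape as x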

Producing samples

If you installed directly from GitHub, you can find the cloned repository in <venv path>/src/pytorch_diffusion for virtual environments, and in <cwd>/src/pytorch_diffusion for global installs. There, you can run

python pytorch_diffusion/diffusion.py <name> <bs> <nb>

where <name> is one of cifar10, lsun_bedroom, lsun_cat, lsun_church, or one of these names prefixed with ema_, <bs> is the batch size, and <nb> is the number of batches. This will produce samples from the PyTorch models and save them to results/<name>/.
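For example, the following draws two batches of four samples each from the EMA LSUN church model and saves them to results/ema_lsun_church/:

python pytorch_diffusion/diffusion.py ema_lsun_church 4 2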

Results

Evaluating 50k samples with torch-fidelity gives

| Dataset            | EMA | Framework  | Model          | FID      |
|--------------------|-----|------------|----------------|----------|
| CIFAR10 Train      | no  | PyTorch    | cifar10        | 12.13775 |
| CIFAR10 Train      | no  | TensorFlow | tf_cifar10     | 12.30003 |
| CIFAR10 Train      | yes | PyTorch    | ema_cifar10    | 3.21213  |
| CIFAR10 Train      | yes | TensorFlow | tf_ema_cifar10 | 3.245872 |
| CIFAR10 Validation | no  | PyTorch    | cifar10        | 14.30163 |
| CIFAR10 Validation | no  | TensorFlow | tf_cifar10     | 14.44705 |
| CIFAR10 Validation | yes | PyTorch    | ema_cifar10    | 5.274105 |
| CIFAR10 Validation | yes | TensorFlow | tf_ema_cifar10 | 5.325035 |

To reproduce, generate 50k samples from the converted PyTorch models provided in this repo with

python pytorch_diffusion/diffusion.py <name> 500 100

and with

python -c "import convert as m; m.sample_tf(500, 100, which=['cifar10', 'ema_cifar10'])"

for the original TensorFlow models.
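The FID values can then be computed with torch-fidelity; as a sketch of one such call, assuming its fidelity command line and the registered cifar10-train reference input (the sample directory follows the results/<name>/ layout described above):

fidelity --gpu 0 --fid --input1 results/cifar10 --input2 cifar10-train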

Running conversions

The converted PyTorch checkpoints are provided for download. If you want to run the conversion yourself, you can follow the steps described below.

Setup

This section assumes your working directory is the root of this repository. Download the pretrained TensorFlow checkpoints. They should follow the original structure:

diffusion_models_release/
  diffusion_cifar10_model/
    model.ckpt-790000.data-00000-of-00001
    model.ckpt-790000.index
    model.ckpt-790000.meta
  diffusion_lsun_bedroom_model/
    ...
  ...

Set the environment variable TFROOT to the directory where you want to store the author's repository, e.g.

export TFROOT=".."

Clone the diffusion repository,

git clone https://github.com/hojonathanho/diffusion.git ${TFROOT}/diffusion

and install its required dependencies (pip install -r ${TFROOT}/diffusion/requirements.txt). Then add the following to your PYTHONPATH:

export PYTHONPATH=".:./scripts:${TFROOT}/diffusion:${TFROOT}/diffusion/scripts:${PYTHONPATH}"
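As an optional sanity check of the setup, the conversion module should now import without error (it pulls in the cloned TensorFlow code, so this assumes the dependencies above are installed):

python -c "import convert"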

Testing operations

To test the PyTorch implementations of the required operations against their TensorFlow counterparts under random initialization and random inputs, run

python -c "import convert as m; m.test_ops()"

Converting checkpoints

To load the pretrained TensorFlow models, copy their weights into the PyTorch models, check for equality on random inputs, and finally save the corresponding PyTorch checkpoints, run

python -c "import convert as m; m.transplant_cifar10()"
python -c "import convert as m; m.transplant_cifar10(ema=True)"
python -c "import convert as m; m.transplant_lsun_bedroom()"
python -c "import convert as m; m.transplant_lsun_bedroom(ema=True)"
python -c "import convert as m; m.transplant_lsun_cat()"
python -c "import convert as m; m.transplant_lsun_cat(ema=True)"
python -c "import convert as m; m.transplant_lsun_church()"
python -c "import convert as m; m.transplant_lsun_church(ema=True)"

PyTorch checkpoints will be saved in

diffusion_models_converted/
  diffusion_cifar10_model/
    model-790000.ckpt
  ema_diffusion_cifar10_model/
    model-790000.ckpt
  diffusion_lsun_bedroom_model/
    model-2388000.ckpt
  ema_diffusion_lsun_bedroom_model/
    model-2388000.ckpt
  diffusion_lsun_cat_model/
    model-1761000.ckpt
  ema_diffusion_lsun_cat_model/
    model-1761000.ckpt
  diffusion_lsun_church_model/
    model-4432000.ckpt
  ema_diffusion_lsun_church_model/
    model-4432000.ckpt
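A converted checkpoint can be loaded back into the standalone model; a minimal sketch, assuming each .ckpt file stores a plain state_dict matching the corresponding Model configuration (here, the CIFAR-10 settings from above):

import torch
from pytorch_diffusion import Model

# Assumption: the converted .ckpt holds a plain state_dict for the CIFAR-10 config.
model = Model(resolution=32, in_channels=3, out_ch=3, ch=128,
              ch_mult=(1,2,2,2), num_res_blocks=2,
              attn_resolutions=(16,), dropout=0.1)
state = torch.load("diffusion_models_converted/diffusion_cifar10_model/model-790000.ckpt",
                   map_location="cpu")
model.load_state_dict(state)
model.eval()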

Sample TensorFlow models

To produce N samples from each of the pretrained TensorFlow models, run

python -c "import convert as m; m.sample_tf(N)"

Pass a list of model names via the keyword argument which to specify the models to sample from, as in the example below. Samples will be saved in results/.
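For example, mirroring the FID reproduction call above but restricted to a single model:

python -c "import convert as m; m.sample_tf(500, 100, which=['ema_cifar10'])"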
