PyTorch reimplementation of Diffusion Models

Overview

PyTorch pretrained Diffusion Models

A PyTorch reimplementation of Denoising Diffusion Probabilistic Models with checkpoints converted from the author's TensorFlow implementation.

Quickstart

Running

pip install -e git+https://github.com/pesser/pytorch_diffusion.git#egg=pytorch_diffusion
pytorch_diffusion_demo

will start a Streamlit demo. It is recommended to run the demo with a GPU available.

demo

Usage

Diffusion models with pretrained weights for cifar10, lsun-bedroom, lsun_cat or lsun_church can be loaded as follows:

from pytorch_diffusion import Diffusion

diffusion = Diffusion.from_pretrained("lsun_church")
samples = diffusion.denoise(4)
diffusion.save(samples, "lsun_church_sample_{:02}.png")

Prefix the name with ema_ to load the averaged weights that produce better results. The U-Net model used for denoising is available via diffusion.model and can also be instantiated on its own:

from pytorch_diffusion import Model

model = Model(resolution=32,
              in_channels=3,
              out_ch=3,
              ch=128,
              ch_mult=(1,2,2,2),
              num_res_blocks=2,
              attn_resolutions=(16,),
              dropout=0.1)

This configuration example corresponds to the model used on CIFAR-10.

Producing samples

If you installed directly from github, you can find the cloned repository in <venv path>/src/pytorch_diffusion for virtual environments, and <cwd>/src/pytorch_diffusion for global installs. There, you can run

python pytorch_diffusion/diffusion.py <name> <bs> <nb>

where <name> is one of cifar10, lsun-bedroom, lsun_cat, lsun_church, or one of these names prefixed with ema_, <bs> is the batch size and <nb> the number of batches. This will produce samples from the PyTorch models and save them to results/<name>/.

Results

Evaluating 50k samples with torch-fidelity gives

Dataset EMA Framework Model FID
CIFAR10 Train no PyTorch cifar10 12.13775
TensorFlow tf_cifar10 12.30003
yes PyTorch ema_cifar10 3.21213
TensorFlow tf_ema_cifar10 3.245872
CIFAR10 Validation no PyTorch cifar10 14.30163
TensorFlow tf_cifar10 14.44705
yes PyTorch ema_cifar10 5.274105
TensorFlow tf_ema_cifar10 5.325035

To reproduce, generate 50k samples from the converted PyTorch models provided in this repo with

`python pytorch_diffusion/diffusion.py <Model> 500 100`

and with

python -c "import convert as m; m.sample_tf(500, 100, which=['cifar10', 'ema_cifar10'])"

for the original TensorFlow models.

Running conversions

The converted pytorch checkpoints are provided for download. If you want to convert them on your own, you can follow the steps described here.

Setup

This section assumes your working directory is the root of this repository. Download the pretrained TensorFlow checkpoints. It should follow the original structure,

diffusion_models_release/
  diffusion_cifar10_model/
    model.ckpt-790000.data-00000-of-00001
    model.ckpt-790000.index
    model.ckpt-790000.meta
  diffusion_lsun_bedroom_model/
    ...
  ...

Set the environment variable TFROOT to the directory where you want to store the author's repository, e.g.

export TFROOT=".."

Clone the diffusion repository,

git clone https://github.com/hojonathanho/diffusion.git ${TFROOT}/diffusion

and install their required dependencies (pip install ${TFROOT}/requirements.txt). Then add the following to your PYTHONPATH:

export PYTHONPATH=".:./scripts:${TFROOT}/diffusion:${TFROOT}/diffusion/scripts:${PYTHONPATH}"

Testing operations

To test the pytorch implementations of the required operations against their TensorFlow counterparts under random initialization and random inputs, run

python -c "import convert as m; m.test_ops()"

Converting checkpoints

To load the pretrained TensorFlow models, copy the weights into the pytorch models, check for equality on random inputs and finally save the corresponding pytorch checkpoints, run

python -c "import convert as m; m.transplant_cifar10()"
python -c "import convert as m; m.transplant_cifar10(ema=True)"
python -c "import convert as m; m.transplant_lsun_bedroom()"
python -c "import convert as m; m.transplant_lsun_bedroom(ema=True)"
python -c "import convert as m; m.transplant_lsun_cat()"
python -c "import convert as m; m.transplant_lsun_cat(ema=True)"
python -c "import convert as m; m.transplant_lsun_church()"
python -c "import convert as m; m.transplant_lsun_church(ema=True)"

Pytorch checkpoints will be saved in

diffusion_models_converted/
  diffusion_cifar10_model/
    model-790000.ckpt
  ema_diffusion_cifar10_model/
    model-790000.ckpt
  diffusion_lsun_bedroom_model/
    model-2388000.ckpt
  ema_diffusion_lsun_bedroom_model/
    model-2388000.ckpt
  diffusion_lsun_cat_model/
    model-1761000.ckpt
  ema_diffusion_lsun_cat_model/
    model-1761000.ckpt
  diffusion_lsun_church_model/
    model-4432000.ckpt
  ema_diffusion_lsun_church_model/
    model-4432000.ckpt

Sample TensorFlow models

To produce N samples from each of the pretrained TensorFlow models, run

python -c "import convert as m; m.sample_tf(N)"

Pass a list of model names as keyword argument which to specify which models to sample from. Samples will be saved in results/.

Owner
Patrick Esser
Patrick Esser
DCGAN LSGAN WGAN-GP DRAGAN PyTorch

Recommendation Our GAN based work for facial attribute editing - AttGAN. News 8 April 2019: We re-implement these GANs by Tensorflow 2! The old versio

Zhenliang He 408 Nov 30, 2022
Official implementation of "Dynamic Anchor Learning for Arbitrary-Oriented Object Detection" (AAAI2021).

DAL This project hosts the official implementation for our AAAI 2021 paper: Dynamic Anchor Learning for Arbitrary-Oriented Object Detection [arxiv] [c

ming71 215 Nov 28, 2022
Omniscient Video Super-Resolution

Omniscient Video Super-Resolution This is the official code of OVSR (Omniscient Video Super-Resolution, ICCV 2021). This work is based on PFNL. Datase

36 Oct 27, 2022
Keras Image Embeddings using Contrastive Loss

Image to Embedding projection in vector space. Implementation in keras and tensorflow of batch all triplet loss for one-shot/few-shot learning.

Shravan Anand K 5 Mar 21, 2022
Computer Vision Paper Reviews with Key Summary of paper, End to End Code Practice and Jupyter Notebook converted papers

Computer-Vision-Paper-Reviews Computer Vision Paper Reviews with Key Summary along Papers & Codes. Jonathan Choi 2021 The repository provides 100+ Pap

Jonathan Choi 2 Mar 17, 2022
Semantic Edge Detection with Diverse Deep Supervision

Semantic Edge Detection with Diverse Deep Supervision This repository contains the code for our IJCV paper: "Semantic Edge Detection with Diverse Deep

Yun Liu 12 Dec 31, 2022
Code for the paper: Fighting Fake News: Image Splice Detection via Learned Self-Consistency

Fighting Fake News: Image Splice Detection via Learned Self-Consistency [paper] [website] Minyoung Huh *12, Andrew Liu *1, Andrew Owens1, Alexei A. Ef

minyoung huh (jacob) 174 Dec 09, 2022
This project contains an implemented version of Face Detection using OpenCV and Mediapipe. This is a code snippet and can be used in projects.

Live-Face-Detection Project Description: In this project, we will be using the live video feed from the camera to detect Faces. It will also detect so

Hassan Shahzad 3 Oct 02, 2021
Original code for "Zero-Shot Domain Adaptation with a Physics Prior"

Zero-Shot Domain Adaptation with a Physics Prior [arXiv] [sup. material] - ICCV 2021 Oral paper, by Attila Lengyel, Sourav Garg, Michael Milford and J

Attila Lengyel 40 Dec 21, 2022
YKKDetector For Python

YKKDetector OpenCVを利用した機械学習データをもとに、VRChatのスクリーンショットなどからYKKさん(もとい「幽狐族のお姉様」)を検出できるソフトウェアです。 マニュアル こちらから実行環境のセットアップから解説する詳細なマニュアルをご覧いただけます。 ライセンス 本ソフトウェア

あんふぃとらいと 5 Dec 07, 2021
Data Consistency for Magnetic Resonance Imaging

Data Consistency for Magnetic Resonance Imaging Data Consistency (DC) is crucial for generalization in multi-modal MRI data and robustness in detectin

Dimitris Karkalousos 19 Dec 12, 2022
atmaCup #11 の Public 4th / Pricvate 5th Solution のリポジトリです。

#11 atmaCup 2021-07-09 ~ 2020-07-21 に行われた #11 [初心者歓迎! / 画像編] atmaCup のリポジトリです。結果は Public 4th / Private 5th でした。 フレームワークは PyTorch で、実装は pytorch-image-m

Tawara 12 Apr 07, 2022
We are More than Our JOints: Predicting How 3D Bodies Move

We are More than Our JOints: Predicting How 3D Bodies Move Citation This repo contains the official implementation of our paper MOJO: @inproceedings{Z

72 Oct 20, 2022
Dynamic Graph Event Detection

DyGED Dynamic Graph Event Detection Get Started pip install -r requirements.txt TODO Paper link to arxiv, and how to cite. Twitter Weather dataset tra

Mert Koşan 3 May 09, 2022
Attention-based CNN-LSTM and XGBoost hybrid model for stock prediction

Attention-based CNN-LSTM and XGBoost hybrid model for stock prediction Requirements The code has been tested running under Python 3.7.4, with the foll

zshicode 84 Jan 01, 2023
Boostcamp AI Tech 3rd / Basic Paper reading w.r.t Embedding

Boostcamp AI Tech 3rd : Basic Paper Reading w.r.t Embedding TL;DR 1992년부터 2018년도까지 이루어진 word/sentence embedding의 중요한 줄기를 이루는 기초 논문 스터디를 진행하고자 합니다. 논

Soyeon Kim 14 Nov 14, 2022
Multi-modal Content Creation Model Training Infrastructure including the FACT model (AI Choreographer) implementation.

AI Choreographer: Music Conditioned 3D Dance Generation with AIST++ [ICCV-2021]. Overview This package contains the model implementation and training

Google Research 365 Dec 30, 2022
TargetAllDomainObjects - A python wrapper to run a command on against all users/computers/DCs of a Windows Domain

TargetAllDomainObjects A python wrapper to run a command on against all users/co

Podalirius 19 Dec 13, 2022
(NeurIPS 2021) Realistic Evaluation of Transductive Few-Shot Learning

Realistic evaluation of transductive few-shot learning Introduction This repo contains the code for our NeurIPS 2021 submitted paper "Realistic evalua

Olivier Veilleux 14 Dec 13, 2022