A minimalist implementation of score-based diffusion models

Overview

sdeflow-light

This is a minimalist codebase for training score-based diffusion models (supporting MNIST and CIFAR-10) used in the following paper:

"A Variational Perspective on Diffusion-Based Generative Models and Score Matching" by Chin-Wei Huang, Jae Hyun Lim and Aaron Courville [arXiv]

Also see the concurrent work by Yang Song & Conor Durkan, who used the same idea to obtain state-of-the-art likelihood estimates.

Experiments on Swissroll

Here's a Colab notebook that contains an example of training a model on the Swissroll dataset.


In this notebook, you'll see how to train the model using the score matching loss, how to evaluate the ELBO of the plug-in reverse SDE, and how to sample from it. It also includes a snippet to sample from a family of plug-in reverse SDEs (parameterized by λ) mentioned in Appendix C of the paper.

Below are the trajectories for λ=0 (the reverse SDE used in Song et al.) and λ=1 (the equivalent ODE) when we plug in the learned score / drift function; this corresponds to Figure 5 of the paper.

[Figure: sample trajectories for λ=0 and λ=1]
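As a concrete illustration, here is a minimal sketch (hypothetical names, not the repo's API) of Euler-Maruyama sampling from this family. It assumes the marginal-preserving interpolation dX = ((1 - λ/2) g a - f) dt + sqrt(1 - λ) g dBt, which recovers the reverse SDE at λ=0 and the probability-flow ODE at λ=1; the exact form used in Appendix C may be written differently. a_net stands in for the learned drift, and f and g for the fixed forward-SDE coefficients (assumed to broadcast against x).

import torch

# Hedged sketch: Euler-Maruyama integration of the plug-in reverse SDE family
#   dX = ((1 - lam/2) * g * a - f) dt + sqrt(1 - lam) * g dB_t
# from t = T down to t = 0. lam = 0 gives the reverse SDE, lam = 1 the ODE.
# a_net, f, g are hypothetical stand-ins, not the repo's actual interfaces.
@torch.no_grad()
def sample_plugin_reverse_sde(a_net, f, g, x_T, T=1.0, n_steps=1000, lam=0.0):
    x = x_T
    dt = T / n_steps
    for i in range(n_steps):
        t = torch.full((x.shape[0],), T - i * dt, device=x.device)
        drift = (1 - lam / 2) * g(t) * a_net(x, t) - f(x, t)
        x = x + drift * dt                                # deterministic part
        if lam < 1:                                       # no noise for the ODE
            x = x + (1 - lam) ** 0.5 * g(t) * dt ** 0.5 * torch.randn_like(x)
    return x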

Experiments on MNIST and CIFAR-10

This repository contains one main training loop (train_img.py). The model is trained to minimize the denoising score matching loss by calling the .dsm(x) loss function, and is evaluated by calling .elbo_random_t_slice(x), which computes the following ELBO (with Y following the inference SDE defined in the parameterization below, started at Y_0 = x):

log p(x) >= E[ log pi(Y_T) - ∫_0^T ( (1/2) ||a(Y_s, s)||^2 + div( g(s) a(Y_s, s) - f(Y_s, s) ) ) ds | Y_0 = x ]

where the divergence (sum of the diagonal entries of the Jacobian) is estimated using the Hutchinson trace estimator.
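For concreteness, here is a minimal sketch of that estimator (drift_fn is a hypothetical stand-in for the drift whose divergence is needed, not a function from this repo): for a random probe v with zero mean and identity covariance, E[v^T J v] equals the trace of the Jacobian J, so one vector-Jacobian product per probe replaces computing the full Jacobian.

import torch

# Hedged sketch of the Hutchinson trace estimator for div_x drift_fn(x, t).
# drift_fn is a hypothetical stand-in (e.g. for g*a - f), not the repo's API.
def hutchinson_divergence(drift_fn, x, t, n_probes=1):
    x = x.requires_grad_(True)
    out = drift_fn(x, t)
    est = torch.zeros(x.shape[0], device=x.device)
    for _ in range(n_probes):
        v = torch.randint_like(x, 2) * 2 - 1              # Rademacher +/-1 probe
        (vjp,) = torch.autograd.grad(out, x, grad_outputs=v,
                                     create_graph=True)   # computes v^T J
        est = est + (vjp * v).flatten(1).sum(1)           # v^T J v per sample
    return est / n_probes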

It's a minimalist codebase in the sense that we do not use a fancy optimizer (we only use Adam with its default setup) or learning-rate scheduling. We use the modified U-net architecture from Denoising Diffusion Probabilistic Models by Ho et al.

A key difference from Song et al. is that instead of parameterizing the score function s, here we parameterize the drift term a, where the two are related by a = gs and g is the diffusion coefficient. That is, the U-net outputs a.

Parameterization: Our original generative & inference SDEs are

  • dX = mu dt + sigma dBt
  • dY = (-mu + sigma*a) ds + sigma dBs

We reparameterize them as

  • dX = (ga - f) dt + g dBt
  • dY = f ds + g dBs

by letting mu = ga - f and sigma = g (since f and g are fixed, we only have one degree of freedom, which is a). Alternatively, one can parameterize s (e.g. using the U-net) and just let a = gs.
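To make the drift parameterization concrete, here is a hedged sketch of a denoising score matching loss written directly in terms of a (hypothetical names; the repo's actual .dsm(x) may differ). With a Gaussian perturbation kernel x_t = mean_t(x_0) + sigma_t * eps, the conditional score is -eps / sigma_t, so with a = gs the likelihood-weighted objective becomes (1/2) E || a(x_t, t) + g(t) eps / sigma_t ||^2.

import torch

# Hedged sketch (hypothetical names, not the repo's .dsm) of denoising score
# matching in the drift parameterization a = g * s. mean_t and sigma_t are the
# mean scaling and noise level of the forward perturbation kernel; the view()
# calls assume image-shaped inputs (B, C, H, W).
def dsm_loss(a_net, x0, mean_t, sigma_t, g, T=1.0):
    t = torch.rand(x0.shape[0], device=x0.device) * T     # t ~ Uniform(0, T)
    eps = torch.randn_like(x0)
    sig = sigma_t(t).view(-1, 1, 1, 1)
    gt = g(t).view(-1, 1, 1, 1)
    xt = mean_t(x0, t) + sig * eps                        # perturbed input
    target = -gt * eps / sig                              # g * conditional score
    return 0.5 * ((a_net(xt, t) - target) ** 2).flatten(1).sum(1).mean()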

How it works

Here's an example command line for running an experiment:

python train_img.py --dataroot=[DATAROOT] --saveroot=[SAVEROOT] --expname=[EXPNAME] \
    --dataset=cifar --print_every=2000 --sample_every=2000 --checkpoint_every=2000 --num_steps=1000 \
    --batch_size=128 --lr=0.0001 --num_iterations=100000 --real=True --debias=False

Setting --debias to False uses uniform sampling for the time variable, whereas setting it to True uses the non-uniform sampling strategy described in the paper to debias the gradient estimate; the sketch below illustrates the idea.
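The idea can be sketched as importance sampling over the time variable: draw t from a proposal proportional to the per-time magnitude of the loss and reweight the integrand, which leaves the expectation unchanged while reducing variance. weight_fn below is a hypothetical placeholder for that proposal; the paper specifies the actual strategy implemented by --debias=True.

import torch

# Hedged sketch of non-uniform time sampling for variance reduction.
# weight_fn is a hypothetical (unnormalized) proposal density; returns times t
# and importance weights iw such that E[loss(t) * iw] matches the
# uniform-time expectation up to grid discretization.
def sample_time_debiased(weight_fn, n, T=1.0, n_grid=1000, device="cpu"):
    grid = torch.linspace(1e-5, T, n_grid, device=device) # avoid t = 0 blow-up
    w = weight_fn(grid)
    idx = torch.multinomial(w / w.sum(), n, replacement=True)
    return grid[idx], w.mean() / w[idx]                   # t, uniform/proposal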

Below are the test-set bits-per-dim and the corresponding standard error recorded during training (orange for --debias=True and blue for --debias=False).

[Plots: test bits-per-dim vs. training iterations]

Here are some samples (debiased on the right).

[Images: samples from models trained with --debias=False (left) and --debias=True (right)]

It takes about 14 hours to finish 100k iterations on a V100 GPU.
