A minimalist implementation of score-based diffusion model

Overview

sdeflow-light

This is a minimalist codebase for training score-based diffusion models (supporting MNIST and CIFAR-10) used in the following paper

"A Variational Perspective on Diffusion-Based Generative Models and Score Matching" by Chin-Wei Huang, Jae Hyun Lim and Aaron Courville [arXiv]

Also see the concurrent work by Yang Song & Conor Durkan where they used the same idea to obtain state-of-the-art likelihood estimates.

Experiments on Swissroll

Here's a Colab notebook which contains an example for training a model on the Swissroll dataset.

Open In Colab

In this notebook, you'll see how to train the model using score matching loss, how to evaluate the ELBO of the plug-in reverse SDE, and how to sample from it. It also includes a snippet to sample from a family of plug-in reverse SDEs (parameterized by λ) mentioned in Appendix C of the paper.

Below are the trajectories of λ=0 (the reverse SDE used in Song et al.) and λ=1 (equivalent ODE) when we plug in the learned score / drift function. This corresponds to Figure 5 of the paper. drawing drawing

Experiments on MNIST and CIFAR-10

This repository contains one main training loop (train_img.py). The model is trained to minimize the denoising score matching loss by calling the .dsm(x) loss function, and evaluated using the following ELBO, by calling .elbo_random_t_slice(x)

score-elbo

where the divergence (sum of the diagonal entries of the Jacobian) is estimated using the Hutchinson trace estimator.

It's a minimalist codebase in the sense that we do not use fancy optimizer (we only use Adam with the default setup) or learning rate scheduling. We use the modified U-net architecture from Denoising Diffusion Probabilistic Models by Jonathan Ho.

A key difference from Song et al. is that instead of parameterizing the score function s, here we parameterize the drift term a (where they are related by a=gs and g is the diffusion coefficient). That is, a is the U-net.

Parameterization: Our original generative & inference SDEs are

  • dX = mu dt + sigma dBt
  • dY = (-mu + sigma*a) ds + sigma dBs

We reparameterize it as

  • dX = (ga - f) dt + g dBt
  • dY = f ds + g dBs

by letting mu = ga - f, and sigma = g. (since f and g are fixed, we only have one degree of freedom, which is a). Alternatively, one can parameterize s (e.g. using the U-net), and just let a=gs.

How it works

Here's an example command line for running an experiment

python train_img.py --dataroot=[DATAROOT] --saveroot=[SAVEROOT] --expname=[EXPNAME] \
    --dataset=cifar --print_every=2000 --sample_every=2000 --checkpoint_every=2000 --num_steps=1000 \
    --batch_size=128 --lr=0.0001 --num_iterations=100000 --real=True --debias=False

Setting --debias to be False uses uniform sampling for the time variable, whereas setting it to be True uses a non-uniform sampling strategy to debias the gradient estimate described in the paper. Below are the bits-per-dim and the corresponding standard error of the test set recorded during training (orange for --debias=True and blue for --debias=False).

drawing drawing

Here are some samples (debiased on the right)

drawing drawing

It takes about 14 hrs to finish 100k iterations on a V100 GPU.

Owner
Chin-Wei Huang
Chin-Wei Huang
The official PyTorch implementation of recent paper - SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training

This repository is the official PyTorch implementation of SAINT. Find the paper on arxiv SAINT: Improved Neural Networks for Tabular Data via Row Atte

Gowthami Somepalli 284 Dec 21, 2022
Experiments on continual learning from a stream of pretrained models.

Ex-model CL Ex-model continual learning is a setting where a stream of experts (i.e. model's parameters) is available and a CL model learns from them

Antonio Carta 6 Dec 04, 2022
PyExplainer: A Local Rule-Based Model-Agnostic Technique (Explainable AI)

PyExplainer PyExplainer is a local rule-based model-agnostic technique for generating explanations (i.e., why a commit is predicted as defective) of J

AI Wizards for Software Management (AWSM) Research Group 14 Nov 13, 2022
TSIT: A Simple and Versatile Framework for Image-to-Image Translation

TSIT: A Simple and Versatile Framework for Image-to-Image Translation This repository provides the official PyTorch implementation for the following p

Liming Jiang 255 Nov 23, 2022
Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting

Autoformer (NeurIPS 2021) Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting Time series forecasting is a c

THUML @ Tsinghua University 847 Jan 08, 2023
Unofficial implementation of Perceiver IO: A General Architecture for Structured Inputs & Outputs

Perceiver IO Unofficial implementation of Perceiver IO: A General Architecture for Structured Inputs & Outputs Usage import torch from src.perceiver.

Timur Ganiev 111 Nov 15, 2022
《LXMERT: Learning Cross-Modality Encoder Representations from Transformers》(EMNLP 2020)

The Most Important Thing. Our code is developed based on: LXMERT: Learning Cross-Modality Encoder Representations from Transformers

53 Dec 16, 2022
A collection of SOTA Image Classification Models in PyTorch

A collection of SOTA Image Classification Models in PyTorch

sithu3 85 Dec 30, 2022
😊 Python module for face feature changing

PyWarping Python module for face feature changing Installation pip install pywarping If you get an error: No such file or directory: 'cmake': 'cmake',

Dopevog 10 Sep 10, 2021
A facial recognition doorbell system using a Raspberry Pi

Facial Recognition Doorbell This project expands on the person-detecting doorbell system to allow it to identify faces, and announce names accordingly

rydercalmdown 22 Apr 15, 2022
"Reinforcement Learning for Bandit Neural Machine Translation with Simulated Human Feedback"

This is code repo for our EMNLP 2017 paper "Reinforcement Learning for Bandit Neural Machine Translation with Simulated Human Feedback", which implements the A2C algorithm on top of a neural encoder-

Khanh Nguyen 131 Oct 21, 2022
网络协议2天集训

网络协议2天集训 抓包工具安装 Wireshark wireshark下载地址 Tcpdump CentOS yum install tcpdump -y Ubuntu apt-get install tcpdump -y k8s抓包测试环境 查看虚拟网卡veth pair 查看

120 Dec 12, 2022
Code for the paper "Implicit Representations of Meaning in Neural Language Models"

Implicit Representations of Meaning in Neural Language Models Preliminaries Create and set up a conda environment as follows: conda create -n state-pr

Belinda Li 39 Nov 03, 2022
Official pytorch implementation of "Feature Stylization and Domain-aware Contrastive Loss for Domain Generalization" ACMMM 2021 (Oral)

Feature Stylization and Domain-aware Contrastive Loss for Domain Generalization This is an official implementation of "Feature Stylization and Domain-

22 Sep 22, 2022
Code for the upcoming CVPR 2021 paper

The Temporal Opportunist: Self-Supervised Multi-Frame Monocular Depth Jamie Watson, Oisin Mac Aodha, Victor Prisacariu, Gabriel J. Brostow and Michael

Niantic Labs 496 Dec 30, 2022
simple_pytorch_example project is a toy example of a python script that instantiates and trains a PyTorch neural network on the FashionMNIST dataset

simple_pytorch_example project is a toy example of a python script that instantiates and trains a PyTorch neural network on the FashionMNIST dataset

Ramón Casero 1 Jan 07, 2022
Optimize Trading Strategies Using Freqtrade

Optimize trading strategy using Freqtrade Short demo on building, testing and optimizing a trading strategy using Freqtrade. The DevBootstrap YouTube

DevBootstrap 139 Jan 01, 2023
A curated list of resources for Image and Video Deblurring

A curated list of resources for Image and Video Deblurring

Subeesh Vasu 1.7k Jan 01, 2023
Official PyTorch implementation for paper "Efficient Two-Stage Detection of Human–Object Interactions with a Novel Unary–Pairwise Transformer"

UPT: Unary–Pairwise Transformers This repository contains the official PyTorch implementation for the paper Frederic Z. Zhang, Dylan Campbell and Step

Frederic Zhang 109 Dec 20, 2022
Camera-caps - Examine the camera capabilities for V4l2 cameras

camera-caps This is a graphical user interface over the v4l2-ctl command line to

Jetsonhacks 25 Dec 26, 2022