PyTorch Implementation of DSB for Score Based Generative Modeling. Experiments managed using Hydra.

Overview

Diffusion Schrödinger Bridge with Applications to Score-Based Generative Modeling

This repository contains the implementation for the paper Diffusion Schrödinger Bridge with Applications to Score-Based Generative Modeling.

If using this code, please cite the paper:

    @article{de2021diffusion,
              title={Diffusion Schr$\backslash$" odinger Bridge with Applications to Score-Based Generative Modeling},
              author={De Bortoli, Valentin and Thornton, James and Heng, Jeremy and Doucet, Arnaud},
              journal={arXiv preprint arXiv:2106.01357},
              year={2021}
            }

Contributors

  • Valentin De Bortoli
  • James Thornton
  • Jeremy Heng
  • Arnaud Doucet

What is a Schrödinger bridge?

The Schrödinger Bridge (SB) problem is a classical problem appearing in applied mathematics, optimal control and probability; see [1, 2, 3]. In the discrete-time setting, it takes the following (dynamic) form. Consider as reference density p(x0:N) describing the process adding noise to the data. We aim to find p*(x0:N) such that p*(x0) = pdata(x0) and p*(xN) = pprior(xN) and minimize the Kullback-Leibler divergence between p* and p. In this work we introduce Diffusion Schrodinger Bridge (DSB), a new algorithm which uses score-matching approaches [4] to approximate the Iterative Proportional Fitting algorithm, an iterative method to find the solutions of the SB problem. DSB can be seen as a refinement of existing score-based generative modeling methods [5, 6].

Schrodinger bridge

Installation

This project can be installed from its git repository.

  1. Obtain the sources by:

    git clone https://github.com/anon284/schrodinger_bridge.git

or, if git is unavailable, download as a ZIP from GitHub https://github.com/.

  1. Install:

    conda env create -f conda.yaml

    conda activate bridge

  2. Download data examples:

    • CelebA: python data.py --data celeba --data_dir './data/'
    • MNIST: python data.py --data mnist --data_dir './data/'

How to use this code?

  1. Train Networks:
  • 2d: python main.py dataset=2d model=Basic num_steps=20 num_iter=5000
  • mnist python main.py dataset=stackedmnist num_steps=30 model=UNET num_iter=5000 data_dir=<insert filepath of data dir <local paths/data/>
  • celeba python main.py dataset=celeba num_steps=50 model=UNET num_iter=5000 data_dir=<insert filepath of data dir <local paths/data/>

Checkpoints and sampled images will be saved to a newly created directory. If GPU has insufficient memory, then reduce cache size. 2D dataset should train on CPU. MNIST and CelebA was ran on 2 high-memory V100 GPUs.

References

.. [1] Hans Föllmer Random fields and diffusion processes In: École d'été de Probabilités de Saint-Flour 1985-1987

.. [2] Christian Léonard A survey of the Schrödinger problem and some of its connections with optimal transport In: Discrete & Continuous Dynamical Systems-A 2014

.. [3] Yongxin Chen, Tryphon Georgiou and Michele Pavon Optimal Transport in Systems and Control In: Annual Review of Control, Robotics, and Autonomous Systems 2020

.. [4] Aapo Hyvärinen and Peter Dayan Estimation of non-normalized statistical models by score matching In: Journal of Machine Learning Research 2005

.. [5] Yang Song and Stefano Ermon Generative modeling by estimating gradients of the data distribution In: Advances in Neural Information Processing Systems 2019

.. [6] Jonathan Ho, Ajay Jain and Pieter Abbeel Denoising diffusion probabilistic models In: Advances in Neural Information Processing Systems 2020

Owner
James Thornton
James Thornton
QueryInst: Parallelly Supervised Mask Query for Instance Segmentation

QueryInst is a simple and effective query based instance segmentation method driven by parallel supervision on dynamic mask heads, which outperforms previous arts in terms of both accuracy and speed.

Hust Visual Learning Team 386 Jan 08, 2023
Pytorch implementation of 'Fingerprint Presentation Attack Detector Using Global-Local Model'

RTK-PAD This is an official pytorch implementation of 'Fingerprint Presentation Attack Detector Using Global-Local Model', which is accepted by IEEE T

6 Aug 01, 2022
an implementation of Video Frame Interpolation via Adaptive Separable Convolution using PyTorch

This work has now been superseded by: https://github.com/sniklaus/revisiting-sepconv sepconv-slomo This is a reference implementation of Video Frame I

Simon Niklaus 985 Jan 08, 2023
PyTorch3D is FAIR's library of reusable components for deep learning with 3D data

Introduction PyTorch3D provides efficient, reusable components for 3D Computer Vision research with PyTorch. Key features include: Data structure for

Facebook Research 6.8k Jan 01, 2023
ColossalAI-Examples - Examples of training models with hybrid parallelism using ColossalAI

ColossalAI-Examples This repository contains examples of training models with Co

HPC-AI Tech 185 Jan 09, 2023
a spacial-temporal pattern detection system for home automation

Argos a spacial-temporal pattern detection system for home automation. Based on OpenCV and Tensorflow, can run on raspberry pi and notify HomeAssistan

Angad Singh 133 Jan 05, 2023
C3D is a modified version of BVLC caffe to support 3D ConvNets.

C3D C3D is a modified version of BVLC caffe to support 3D convolution and pooling. The main supporting features include: Training or fine-tuning 3D Co

Meta Archive 1.1k Nov 14, 2022
CNN designed for pansharpening

PROGRESSIVE BAND-SEPARATED CONVOLUTIONAL NEURAL NETWORK FOR MULTISPECTRAL PANSHARPENING This repository contains main code for the paper PROGRESSIVE B

SerendipitysX 3 Dec 29, 2021
Aligning Latent and Image Spaces to Connect the Unconnectable

About This repo contains the official implementation of the Aligning Latent and Image Spaces to Connect the Unconnectable paper. It is a GAN model whi

Ivan Skorokhodov 203 Jan 03, 2023
A certifiable defense against adversarial examples by training neural networks to be provably robust

DiffAI v3 DiffAI is a system for training neural networks to be provably robust and for proving that they are robust. The system was developed for the

SRI Lab, ETH Zurich 202 Dec 13, 2022
Final term project for Bayesian Machine Learning Lecture (XAI-623)

Mixquality_AL Final Term Project For Bayesian Machine Learning Lecture (XAI-623) Youtube Link The presentation is given in YoutubeLink Problem Formula

JeongEun Park 3 Jan 18, 2022
use machine learning to recognize gesture on raspberrypi

Raspberrypi_Gesture-Recognition use machine learning to recognize gesture on raspberrypi 說明 利用 tensorflow lite 訓練手部辨識模型 分辨 "剪刀"、"石頭"、"布" 之手勢 再將訓練模型匯入

1 Dec 10, 2021
Axel - 3D printed robotic hands and they controll with Raspberry Pi and Arduino combo

Axel It's our graduation project about 3D printed robotic hands and they control

0 Feb 14, 2022
[Machine Learning Engineer Basic Guide] 부스트캠프 AI Tech - Product Serving 자료

Boostcamp-AI-Tech-Product-Serving 부스트캠프 AI Tech - Product Serving 자료 Repository 구조 part1(MLOps 개론, Model Serving, 머신러닝 프로젝트 라이프 사이클은 별도의 코드가 없으며, part

Sung Yun Byeon 269 Dec 21, 2022
Instance-Dependent Partial Label Learning

Instance-Dependent Partial Label Learning Installation pip install -r requirements.txt Run the Demo benchmark-random mnist python -u main.py --gpu 0 -

17 Dec 29, 2022
prior-based-losses-for-medical-image-segmentation

Repository for papers: Benchmark: Effect of Prior-based Losses on Segmentation Performance: A Benchmark Midl: A Surprisingly Effective Perimeter-based

Rosana EL JURDI 9 Sep 07, 2022
Transparent Transformer Segmentation

Transparent Transformer Segmentation Introduction This repository contains the data and code for IJCAI 2021 paper Segmenting transparent object in the

谢恩泽 140 Jan 02, 2023
NHL 94 AI contests

nhl94-ai The end goals of this project is to: Train Models that play NHL 94 Support AI vs AI contests in NHL 94 Provide an improved AI opponent for NH

Mathieu Poliquin 2 Dec 06, 2021
Code for “ACE-HGNN: Adaptive Curvature ExplorationHyperbolic Graph Neural Network”

ACE-HGNN: Adaptive Curvature Exploration Hyperbolic Graph Neural Network This repository is the implementation of ACE-HGNN in PyTorch. Environment pyt

9 Nov 28, 2022
This repository contains the code for our fast polygonal building extraction from overhead images pipeline.

Polygonal Building Segmentation by Frame Field Learning We add a frame field output to an image segmentation neural network to improve segmentation qu

Nicolas Girard 186 Jan 04, 2023