Release for Improved Denoising Diffusion Probabilistic Models

Overview

improved-diffusion

This is the codebase for Improved Denoising Diffusion Probabilistic Models.

Usage

This section of the README walks through how to train and sample from a model.

Installation

Clone this repository and navigate to it in your terminal. Then run:

pip install -e .

This should install the improved_diffusion python package that the scripts depend on.

Preparing Data

The training code reads images from a directory of image files. In the datasets folder, we have provided instructions/scripts for preparing these directories for ImageNet, LSUN bedrooms, and CIFAR-10.

For creating your own dataset, simply dump all of your images into a directory with ".jpg", ".jpeg", or ".png" extensions. If you wish to train a class-conditional model, name the files like "mylabel1_XXX.jpg", "mylabel2_YYY.jpg", etc., so that the data loader knows that "mylabel1" and "mylabel2" are the labels. Subdirectories will automatically be enumerated as well, so the images can be organized into a recursive structure (although the directory names will be ignored, and the underscore prefixes are used as names).

The images will automatically be scaled and center-cropped by the data-loading pipeline. Simply pass --data_dir path/to/images to the training script, and it will take care of the rest.

Training

To train your model, you should first decide some hyperparameters. We will split up our hyperparameters into three groups: model architecture, diffusion process, and training flags. Here are some reasonable defaults for a baseline:

MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3"
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule linear"
TRAIN_FLAGS="--lr 1e-4 --batch_size 128"

Here are some changes we experiment with, and how to set them in the flags:

  • Learned sigmas: add --learn_sigma True to MODEL_FLAGS
  • Cosine schedule: change --noise_schedule linear to --noise_schedule cosine
  • Reweighted VLB: add --use_kl True to DIFFUSION_FLAGS and add --schedule_sampler loss-second-moment to TRAIN_FLAGS.
  • Class-conditional: add --class_cond True to MODEL_FLAGS.

Once you have setup your hyper-parameters, you can run an experiment like so:

python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

You may also want to train in a distributed manner. In this case, run the same command with mpiexec:

mpiexec -n $NUM_GPUS python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

When training in a distributed manner, you must manually divide the --batch_size argument by the number of ranks. In lieu of distributed training, you may use --microbatch 16 (or --microbatch 1 in extreme memory-limited cases) to reduce memory usage.

The logs and saved models will be written to a logging directory determined by the OPENAI_LOGDIR environment variable. If it is not set, then a temporary directory will be created in /tmp.

Sampling

The above training script saves checkpoints to .pt files in the logging directory. These checkpoints will have names like ema_0.9999_200000.pt and model200000.pt. You will likely want to sample from the EMA models, since those produce much better samples.

Once you have a path to your model, you can generate a large batch of samples like so:

python scripts/image_sample.py --model_path /path/to/model.pt $MODEL_FLAGS $DIFFUSION_FLAGS

Again, this will save results to a logging directory. Samples are saved as a large npz file, where arr_0 in the file is a large batch of samples.

Just like for training, you can run image_sample.py through MPI to use multiple GPUs and machines.

You can change the number of sampling steps using the --timestep_respacing argument. For example, --timestep_respacing 250 uses 250 steps to sample. Passing --timestep_respacing ddim250 is similar, but uses the uniform stride from the DDIM paper rather than our stride.

To sample using DDIM, pass --use_ddim True.

Owner
OpenAI
OpenAI
The blancmange curve can be visually built up out of triangle wave functions if the infinite sum is approximated by finite sums of the first few terms.

Blancmange-curve The blancmange curve can be visually built up out of triangle wave functions if the infinite sum is approximated by finite sums of th

Shankar Mahadevan L 1 Nov 30, 2021
This is the Code Institute student template for Gitpod.

Welcome AnaG0307, This is the Code Institute student template for Gitpod. We have preinstalled all of the tools you need to get started. It's perfectl

0 Feb 02, 2022
This is an online course where you can learn and master the skill of low-level performance analysis and tuning.

Performance Ninja Class This is an online course where you can learn to find and fix low-level performance issues, for example CPU cache misses and br

Denis Bakhvalov 1.2k Dec 30, 2022
🎉 🎉 PyComp - Java Code compiler written in python.

🎉 🎉 PyComp Java Code compiler written in python. This is yet another compiler meant for babcock students project which was created using pure python

Alumona Benaiah 5 Nov 30, 2022
Converts a base copy of Pokemon BDSP's masterdatas into a more readable and editable Pokemon Showdown Format.

Showdown-BDSP-Converter Converts a base copy of Pokemon BDSP's masterdatas into a more readable and editable Pokemon Showdown Format. Download the lat

Alden Mo 2 Jan 02, 2022
Automatically load and dump your dataclasses 📂🙋

file dataclasses Installation By default, filedataclasses comes with support for JSON files only. To support other formats like YAML and TOML, filedat

Alon 1 Dec 30, 2021
A Web app to Cross-Seed torrents in Deluge/qBittorrent/Transmission

SeedCross A Web app to Cross-Seed torrents in Deluge/qBittorrent/Transmission based on CrossSeedAutoDL Require Jackett Deluge/qBittorrent/Transmission

ccf2012 76 Dec 19, 2022
WGGCommute - Adding Commute Times to WG-Gesucht Listings

WGGCommute - Adding Commute Times to WG-Gesucht Listings This is a barebones implementation of a chrome extension that can be used to add commute time

Jannis 2 Jul 20, 2022
Python Script to add OpenGapps, Magisk, libhoudini translation library and libndk translation library to waydroid !

Waydroid Extras Script Script to add gapps and other stuff to waydroid ! Installation/Usage "lzip" is required for this script to work, install it usi

Casu Al Snek 331 Jan 02, 2023
Social reading and reviewing, decentralized with ActivityPub

BookWyrm Social reading and reviewing, decentralized with ActivityPub Contents Joining BookWyrm Contributing About BookWyrm What it is and isn't The r

BookWyrm 1.4k Jan 08, 2023
the classic version Of torrentleechx #Unmaintained #Archived

TorrentleechX-Classic Old Modified Version Repo #Unmaintained #Archived for support join here working example group Leech Here For Any Issues/Imroveme

XcodersHub 18 Jan 30, 2022
A site that went kinda viral that lets you put Bernie Sanders in places

Bernie In Places An app that accidentally went viral! Read the story in WIRED here Install First, create a python virtual environment, and install all

310 Aug 22, 2022
A variant caller for the GBA gene using WGS data

Gauchian: WGS-based GBA variant caller Gauchian is a targeted variant caller for the GBA gene based on a whole-genome sequencing (WGS) BAM file. Gauch

Illumina 16 Oct 13, 2022
This script can be used to get unlimited Gb for WARP.

Warp-Unlimited-GB This script can be used to get unlimited Gb for WARP. How to use Change the value of the 'referrer' to warp id of yours You can down

Anix Sam Saji 1 Feb 14, 2022
This is the improvised version of Dobot Magician which can be implemented for Dobot M1

pydobotM1 This is the edited driver for Dobot M1 version of the original pydobot library intended for use with the Dobot Magician. Here's what you nee

Shaik Abdullah 2 Jul 11, 2022
Un Assistente Vocale scritto in Python e altamente personalizzabile

Un Assistente Vocale scritto in Python e altamente personalizzabile

Marco 2 May 06, 2022
Tutorials for on-ramping to StarkNet

Full-Stack StarkNet Repo containing the code for a short tutorial series I wrote while diving into StarkNet and learning Cairo. Aims to onramp existin

Sam Barnes 71 Dec 07, 2022
ArinjoyTheDev 1 Jul 17, 2022
This repository contains all the data analytics projects that I've worked on in python.

93_Python_Data_Analytics_Projects This repository contains all the data analytics projects that I've worked on in python. No. Name 01 001_Cervical_Can

Milaan Parmar / Милан пармар / _米兰 帕尔马 267 Jan 06, 2023
A Python version of Canvacord

A copy of canvacord made in python! Installation Run any of these commands in terminal: Mac / Linux pip install canvacord Windows python -m pip insta

10 Mar 28, 2022