toy_gradlogp
This repo implements some toy examples of the following score matching algorithms in PyTorch:
- `ssm-vr`: sliced score matching with variance reduction
- `ssm`: sliced score matching
- `deen`: deep energy estimator networks
- `dsm`: denoising score matching
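For orientation, the four objectives can be sketched in a few lines of PyTorch for a scalar energy model `E(x)` whose score is `s(x) = -grad_x E(x)`. This is a minimal sketch, not the repo's code. The helper names (`sample_slices`, `score`, `ssm_loss`, `dsm_loss`, `deen_loss`), the default `sigma`, and the slicing-noise options (mirroring `--noise`, which the CLI spells `radermacher`) are assumptions. Inputs are assumed to be batches of shape `(B, D)`.

```python
import math

import torch


def sample_slices(x, noise="gaussian"):
    """Hypothetical helper: random directions v with E[v v^T] = I."""
    if noise == "gaussian":
        return torch.randn_like(x)
    if noise == "rademacher":
        return torch.randint_like(x, 2) * 2 - 1  # entries in {-1, +1}
    if noise == "sphere":
        v = torch.randn_like(x)
        # uniform on the sphere of radius sqrt(D), so E[v v^T] = I
        return v / v.norm(dim=-1, keepdim=True) * math.sqrt(x.shape[-1])
    raise ValueError(f"unknown noise type: {noise}")


def score(energy_net, x):
    """s(x) = -grad_x E(x); create_graph=True keeps the loss differentiable."""
    return -torch.autograd.grad(energy_net(x).sum(), x, create_graph=True)[0]


def ssm_loss(energy_net, x, n_slices=1, noise="gaussian", variance_reduction=False):
    """Sliced score matching (ssm); variance_reduction=True gives ssm-vr."""
    x = x.detach().repeat(n_slices, 1).requires_grad_(True)  # (n_slices * B, D)
    v = sample_slices(x, noise)
    s = score(energy_net, x)
    if variance_reduction:
        first = 0.5 * (s ** 2).sum(dim=-1)  # closed-form expectation over v
    else:
        first = 0.5 * (v * s).sum(dim=-1) ** 2
    # v^T (ds/dx) v computed with a vector-Jacobian product
    jv = torch.autograd.grad((s * v).sum(), x, create_graph=True)[0]
    second = (jv * v).sum(dim=-1)
    return (first + second).mean()


def dsm_loss(energy_net, x, sigma=0.1):
    """Denoising score matching with a Gaussian perturbation of scale sigma."""
    x = x.detach()
    x_tilde = (x + sigma * torch.randn_like(x)).requires_grad_(True)
    target = -(x_tilde - x) / sigma ** 2  # score of the Gaussian perturbation kernel
    return 0.5 * ((score(energy_net, x_tilde) - target) ** 2).sum(dim=-1).mean()


def deen_loss(energy_net, x, sigma=0.1):
    """DEEN: estimate x from y = x + noise as y - sigma^2 * grad_y E(y)."""
    x = x.detach()
    y = (x + sigma * torch.randn_like(x)).requires_grad_(True)
    grad_e = torch.autograd.grad(energy_net(y).sum(), y, create_graph=True)[0]
    return ((x - y + sigma ** 2 * grad_e) ** 2).sum(dim=-1).mean()
```

With the same noise draw (e.g. after reseeding), `deen_loss` evaluates to `2 * sigma**4 * dsm_loss` in this formulation, so DEEN is a rescaled denoising objective. Likewise `ssm-vr` differs from `ssm` only by replacing `0.5 * (v·s)^2` with its expectation `0.5 * ||s||^2`, which is valid whenever `E[v v^T] = I`.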
Installation
Basic requirements:
- Python >= 3.6
- TensorFlow >= 2.3.0
- PyTorch >= 1.8.0
Install from PyPI
pip install toy_gradlogp
Or install the latest version from this repo:
pip install git+https://github.com/Ending2015a/toy_gradlogp.git
Examples
The example scripts are located in `toy_gradlogp/run/`.
Train an energy model
Run `ssm-vr` on the `2spirals` dataset (don't forget to add `--gpu` to enable GPU):
python -m toy_gradlogp.run.train_energy --gpu --loss ssm-vr --data 2spirals
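The `2spirals` dataset is a standard 2-D toy distribution. Below is a sketch of a two-spirals sampler in the form commonly used for 2-D density-estimation toy benchmarks; the function name `sample_2spirals` and its constants (540-degree arms, uniform jitter of 0.5, scaling by 1/3, `noise_std`) are assumptions, and the repo's own generator may differ.

```python
import math

import torch


def sample_2spirals(n, noise_std=0.1):
    """Hypothetical sampler: n points (n even) from two interleaved 2-D spirals."""
    half = n // 2
    # position along the arm; sqrt(U) skews samples toward the outer part of the spiral
    t = torch.sqrt(torch.rand(half, 1)) * 540 * (2 * math.pi) / 360
    arm = torch.cat([-torch.cos(t) * t, torch.sin(t) * t], dim=1)
    arm = arm + torch.rand(half, 2) * 0.5  # uniform jitter
    x = torch.cat([arm, -arm], dim=0) / 3  # second spiral is the point reflection
    return x + noise_std * torch.randn_like(x)
```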
To see the full list of options, use the `--help` flag:
python -m toy_gradlogp.run.train_energy --help
usage: train_energy.py [-h] [--logdir LOGDIR]
                       [--data {8gaussians,2spirals,checkerboard,rings}]
                       [--loss {ssm-vr,ssm,deen,dsm}]
                       [--noise {radermacher,sphere,gaussian}] [--lr LR]
                       [--size SIZE] [--eval_size EVAL_SIZE]
                       [--batch_size BATCH_SIZE] [--n_epochs N_EPOCHS]
                       [--n_slices N_SLICES] [--n_steps N_STEPS] [--eps EPS]
                       [--gpu] [--log_freq LOG_FREQ] [--eval_freq EVAL_FREQ]
                       [--vis_freq VIS_FREQ]

optional arguments:
  -h, --help            show this help message and exit
  --logdir LOGDIR
  --data {8gaussians,2spirals,checkerboard,rings}
                        dataset
  --loss {ssm-vr,ssm,deen,dsm}
                        loss type
  --noise {radermacher,sphere,gaussian}
                        noise type
  --lr LR               learning rate
  --size SIZE           dataset size
  --eval_size EVAL_SIZE
                        dataset size for evaluation
  --batch_size BATCH_SIZE
                        training batch size
  --n_epochs N_EPOCHS   number of epochs to train
  --n_slices N_SLICES   number of slices for sliced score matching
  --n_steps N_STEPS     number of steps for langevin dynamics
  --eps EPS             noise scale for langevin dynamics
  --gpu                 enable gpu
  --log_freq LOG_FREQ   logging frequency (unit: epoch)
  --eval_freq EVAL_FREQ
                        evaluation frequency (unit: epoch)
  --vis_freq VIS_FREQ   visualization frequency (unit: epoch)
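The `--n_steps` and `--eps` options control Langevin dynamics, which draws samples using only the gradient of the energy. A minimal unadjusted Langevin sketch follows. Treating `eps` as the noise standard deviation with step size `eps**2`, i.e. x ← x − (eps²/2)·∇E(x) + eps·z with z ~ N(0, I), is an assumption about how the repo uses the option, and `langevin_sample` is a hypothetical helper.

```python
import torch


def langevin_sample(energy_net, x0, n_steps=100, eps=0.1):
    """Hypothetical helper: unadjusted Langevin dynamics targeting p(x) ∝ exp(-E(x))."""
    x = x0.detach().clone()
    for _ in range(n_steps):
        x.requires_grad_(True)
        grad_e = torch.autograd.grad(energy_net(x).sum(), x)[0]
        # step of size eps**2 along the score (-grad E) plus Gaussian noise of std eps
        x = (x - 0.5 * eps ** 2 * grad_e + eps * torch.randn_like(x)).detach()
    return x
```

A small `eps` mixes slowly but has little discretization bias; a larger one mixes faster at the cost of a slightly inflated stationary variance, since no Metropolis correction is applied.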
Results
Tip: regions of higher density have lower energy, since the model defines p(x) ∝ exp(−E(x)).
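Because an energy model defines p(x) ∝ exp(−E(x)), the lowest-energy regions are exactly the highest-density ones. To view a trained 2-D model as a density, one option is to evaluate it on a grid and normalize numerically. The helper `energy_to_density`, the grid range `lim`, and the resolution `n` are assumptions for illustration, not options of the repo; `energy_net` is assumed to return one scalar per input point.

```python
import torch


def energy_to_density(energy_net, lim=4.0, n=101):
    """Hypothetical helper: exp(-E) on an n x n grid over [-lim, lim]^2, normalized."""
    xs = torch.linspace(-lim, lim, n)
    grid = torch.cartesian_prod(xs, xs)  # (n * n, 2), first coordinate varies slowest
    with torch.no_grad():
        energy = energy_net(grid)
    cell_area = (2 * lim / (n - 1)) ** 2
    # softmax(-E) = exp(-E) / sum(exp(-E)); dividing by the cell area gives a density
    density = torch.softmax(-energy, dim=0) / cell_area
    return xs, density.reshape(n, n)
```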
8gaussians
Algorithm | Results |
---|---|
ssm-vr | (figure) |
ssm | (figure) |
deen | (figure) |
dsm | (figure) |
2spirals
Algorithm | Results |
---|---|
ssm-vr | (figure) |
ssm | (figure) |
deen | (figure) |
dsm | (figure) |
checkerboard
Algorithm | Results |
---|---|
ssm-vr | (figure) |
ssm | (figure) |
deen | (figure) |
dsm | (figure) |
rings
Algorithm | Results |
---|---|
ssm-vr | (figure) |
ssm | (figure) |
deen | (figure) |
dsm | (figure) |