A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

Related tags

MiscellaneousBoMb-OT
Overview

BoMb-OT

Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via Partial Transportation.

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation is made by Khai Nguyen and Dang Nguyen. README is on updating process.

Requirement

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

The scalable implementation of the batch of mini-batches scheme and the conventional averaging scheme of mini-batch transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), sliced optimal transport for:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the genertor and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval of two continuous test phase

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create dataset.

evaluate.py : this file is used to evaluate model trained on VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the genertor and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval of two continuous test phase

--s_dset_path : path to source dataset

--stratify_source : use stratify sampling

--s_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator on CelebA and CIFAR10 datasets, and include some self-function to compute losses of corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: read saved models to produce 10000 images to calculate FID.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual running epochs is calculated by multiplying this value by the value of k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: K mean clustering to compress images

--palette: show color palette

--source: Path to the source image

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful for their open sources.

Owner
Khai Ba Nguyen
I am currently an AI Resident at VinAI Research, Vietnam.
Khai Ba Nguyen
pybicyclewheel calulates the required spoke length for bicycle wheels

pybicyclewheel pybicyclewheel calulates the required spoke length for bicycle wheels. (under construcion) - homepage further readings wikipedia bicyc

karl 0 Aug 24, 2022
Santa's kitchen helper for python

Santa's Kitchen Helper Introduction/Overview Contents UX User Stories Design Wireframes Color Scheme Typography Imagery Features Exisiting Features Fe

Paul Browne 4 May 31, 2022
SimilarWeb for Team ACT v.0.0.1

SimilarWeb for Team ACT v.0.0.1 This module has been built to provide a better environment specifically for Similarweb in Team ACT. This module itself

Sunkyeong Lee 0 Dec 29, 2021
APC Power Usage is an application which shows power consuption overtime for UPS units manufactured by APC.

APC Power Usage Introduction APC Power Usage is an application which shows power consuption overtime for UPS units manufactured by APC. Screenshoots G

Stefan Kondinski 3 Oct 08, 2021
Print 'text color' and 'text format' on Term with Python

term-printer Print 'text color' and 'text format' on Term with Python ※ It may not work depending on the OS and shell used. PIP $ pip install term-pri

ななといつ 10 Nov 12, 2022
Cross-platform .NET Core pre-commit hooks

dotnet-core-pre-commit Cross-platform .NET Core pre-commit hooks How to use Add this to your .pre-commit-config.yaml - repo: https://github.com/juan

Juan Odicio 5 Jul 20, 2021
Open Source Management System for Botanic Garden Collections.

BotGard 3.0 Open Source Management System for Botanic Garden Collections built and maintained by netzkolchose.de in cooperation with the Botanical Gar

netzkolchose.de 1 Dec 15, 2021
COVID-19 case tracker in Dash

covid_dashy_personal This is a personal project to build a simple COVID-19 tracker for Australia with Dash. Key functions of this dashy will be to Dis

Jansen Zhang 1 Nov 30, 2021
GCP Scripts and API Client Toolss

GCP Scripts and API Client Toolss Script Authentication The scripts and CLI assume GCP Application Default Credentials are set. Credentials can be set

3 Feb 21, 2022
Prints values and types during compilation!

Compile-Time Printer Compile-Time Printer prints values and types at compile-time in C++. Teaser test.cpp compile-time-printer

43 Dec 26, 2022
A lightweight solution for local Particle development.

neopo A lightweight solution for local Particle development. Features Builds Particle projects locally without any overhead. Compatible with Particle

Nathan Robinson 19 Jan 01, 2023
March-madness - March Madness results 1985-2021

march-madness Results for all 2,268 NCAA Division I Men's Basketball Tournament games since the modern format was introduced in 1985. Includes years,

Darik Harter 2 Feb 26, 2022
Izy - Python functions and classes that make python even easier than it is

izy Python functions and classes that make it even easier! You will wonder why t

5 Jul 04, 2022
Automatically remove user join messages when the user leaves the server.

CleanLeave Automatically remove user join messages when the user leaves the server. Installation You will need to install poetry to run this bot local

11 Sep 19, 2022
A clipboard where a user can add and retrieve multiple items to and from (resp) from the clipboard cache.

A clipboard where a user can add and retrieve multiple items to and from (resp) from the clipboard cache.

Gaurav Bhattacharjee 2 Feb 07, 2022
This tool don't used illegal ativity

ETHICALTOOL This tool for only educational purposes don't used illegal ativity @onlinehacking this tool for pkg update && pkg upgrade && pkg install g

Mrkarthick 4 Dec 23, 2021
A simple app that helps to train quick calculations.

qtcounter A simple app that helps to train quick calculations. Usage Manual Clone the repo in a folder using git clone https://github.com/Froloket64/q

0 Nov 27, 2021
Hitchhikers-guide - The Hitchhiker's Guide to Data Science for Social Good

Welcome to the Hitchhiker's Guide to Data Science for Social Good. What is the Data Science for Social Good Fellowship? The Data Science for Social Go

Data Science for Social Good 907 Jan 01, 2023
ticguide: quick + painless TESS observing information

ticguide: quick + painless TESS observing information Complementary to the TESS observing tool tvguide (see also WTV), which tells you if your target

Ashley Chontos 5 Nov 05, 2022
E5 自动续期

请选择跳转 新版本系统 (2021-2-9采用): 以后更新都在AutoApi,采用v0.0版本号覆盖式更新 AutoApi : 最新版 保留1到2个稳定的简易版,防止萌新大范围报错 AutoApi'X' : 稳定版1 ( 即本版AutpApiP ) AutoApiP ( 即v5.0,稳定版 ) —

95 Feb 15, 2021