A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

Related tags

MiscellaneousBoMb-OT
Overview

BoMb-OT

Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via Partial Transportation.

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation is made by Khai Nguyen and Dang Nguyen. README is on updating process.

Requirement

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

The scalable implementation of the batch of mini-batches scheme and the conventional averaging scheme of mini-batch transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), sliced optimal transport for:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the genertor and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval of two continuous test phase

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create dataset.

evaluate.py : this file is used to evaluate model trained on VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the genertor and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval of two continuous test phase

--s_dset_path : path to source dataset

--stratify_source : use stratify sampling

--s_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator on CelebA and CIFAR10 datasets, and include some self-function to compute losses of corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: read saved models to produce 10000 images to calculate FID.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual running epochs is calculated by multiplying this value by the value of k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: K mean clustering to compress images

--palette: show color palette

--source: Path to the source image

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful for their open sources.

Owner
Khai Ba Nguyen
I am currently an AI Resident at VinAI Research, Vietnam.
Khai Ba Nguyen
Msgpack serialization/deserialization library for Python, written in Rust using PyO3 and rust-msgpack. Reboot of orjson. msgpack.org[Python]

ormsgpack ormsgpack is a fast msgpack library for Python. It is a fork/reboot of orjson It serializes faster than msgpack-python and deserializes a bi

Aviram Hassan 139 Dec 30, 2022
Python tools for working with Orbit Ephemeris Messages (OEMs).

Python Orbit Ephemeris Message tools Python tools for working with Orbit Ephemeris Messages (OEMs). Development Status Installation The oem package is

Brad Sease 4 Apr 06, 2022
Wordle Solver

Wordle Solver Installation Install the following onto your computer: Python 3.10.x Download Page Run pip install -r requirements.txt Instructions To r

John Bucknam 1 Feb 15, 2022
Digdata presented 'BrandX' as a clothing brand that wants to know the best places to set up a 'pop up' store.

Digdata presented 'BrandX' as a clothing brand that wants to know the best places to set up a 'pop up' store. I used the dataset given to write a program that ranks these places.

Mahmoud 1 Dec 11, 2021
Provide Prometheus url_sd compatible API Endpoint with data from Netbox

netbox-plugin-prometheus-sd Provide Prometheus http_sd compatible API Endpoint with data from Netbox. HTTP SD is a new feature in Prometheus and not a

Felix Peters 66 Dec 19, 2022
Py4J enables Python programs to dynamically access arbitrary Java objects

Py4J Py4J enables Python programs running in a Python interpreter to dynamically access Java objects in a Java Virtual Machine. Methods are called as

Barthelemy Dagenais 1k Jan 02, 2023
Usando Multi Player Perceptron e Regressão Logistica para classificação de SPAM

Relatório dos procedimentos executados e resultados obtidos. Objetivos Treinar um modelo para classificação de SPAM usando o dataset train_data. Class

André Mediote 1 Feb 02, 2022
A Python software implementation of the Intel 4004 processor

Pyntel4004 A Python software implementation of the Intel 4004 processor. General Information Two pass assembler using the original mnemonics, directiv

alshapton 5 Oct 01, 2022
Margin Calculator - Personally tailored investment tool

Margin Calculator - Personally tailored investment tool

1 Jul 19, 2022
Decentralized intelligent voting application.

DiVA Decentralized intelligent voting application. Hack the North 2021. Inspiration Following the previous US election, many voters were fearful that

Ali Shariatmadari 4 Jun 05, 2022
p5 is a Python package based on the core ideas of Processing.

p5 p5 is a Python library that provides high level drawing functionality to help you quickly create simulations and interactive art using Python. It c

p5py 645 Jan 04, 2023
Transform a Google Drive server into a VFX pipeline ready server

Google Drive VFX Server VFX Pipeline About The Project Quick tutorial to setup a Google Drive Server for multiple machines access, and VFX Pipeline on

Valentin Beaumont 17 Jun 27, 2022
An improved version of the common ˙pacman -S˙

BetterPacmanLook An improved version of the common pacman -S. Installation I know that this is probably one of the worst solutions and i will be worki

1 Nov 06, 2021
A simple assembly- and brainfuck-inspired stack-based language

asm-stackfuck A simple assembly- and brainfuck-inspired stack-based language. The language has a few goals: Be stack-based Look like assembly Have a s

Nils Trinity 1 Feb 06, 2022
Open source book about making Python packages.

Python packages Tomas Beuzen & Tiffany Timbers Python packages are a core element of the Python programming language and are how you create organized,

Python Packages 169 Jan 06, 2023
Completed task 1 and task 2 at LetsGrowMore as a data science intern.

LetsGrowMore-Internship Completed task 1 and task 2 at LetsGrowMore as a data science intern. Task 1- Task 2- Creating a Decision Tree classifier and

Sanjyot Panure 1 Jan 16, 2022
Minutaria is a basic educational Python timer used to learn python and software testing libraries.

minutaria minutaria is a basic educational Python timer. The project is educational, it aims to teach myself programming, python programming, python's

1 Jul 16, 2021
YunoHost is an operating system aiming to simplify as much as possible the administration of a server.

YunoHost is an operating system aiming to simplify as much as possible the administration of a server. This repository corresponds to the core code, written mostly in Python and Bash.

YunoHost 1.5k Jan 09, 2023
Little tool in python to watch anime from the terminal (the better way to watch anime)

anipy-cli Little tool in python to watch anime from the terminal (the better way to watch anime) Has a resume playback function when picking from Hist

sdao 97 Dec 29, 2022
A free and powerful system for awareness and research of the American judicial system.

CourtListener Started in 2009, CourtListener.com is the main initiative of Free Law Project. The goal of CourtListener.com is to provide high quality

Free Law Project 332 Dec 25, 2022