A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

Related tags

MiscellaneousBoMb-OT
Overview

BoMb-OT

Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via Partial Transportation.

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation is made by Khai Nguyen and Dang Nguyen. README is on updating process.

Requirement

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

The scalable implementation of the batch of mini-batches scheme and the conventional averaging scheme of mini-batch transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), sliced optimal transport for:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the genertor and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval of two continuous test phase

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create dataset.

evaluate.py : this file is used to evaluate model trained on VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the genertor and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval of two continuous test phase

--s_dset_path : path to source dataset

--stratify_source : use stratify sampling

--s_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator on CelebA and CIFAR10 datasets, and include some self-function to compute losses of corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: read saved models to produce 10000 images to calculate FID.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual running epochs is calculated by multiplying this value by the value of k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: K mean clustering to compress images

--palette: show color palette

--source: Path to the source image

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful for their open sources.

Owner
Khai Ba Nguyen
I am currently an AI Resident at VinAI Research, Vietnam.
Khai Ba Nguyen
India Today Astrology App

India Today Astrology App Introduction This repository contains the code for the Backend setup of the India Today Astrology app as a part of their rec

Pranjal Pratap Dubey 4 May 07, 2022
The Doodle Master seeks to turn your UI mockups into real code.

Doodle Master The Doodle Master seeks to turn your UI mockups into real code. Currently this repository just serves to demonstrate a Proof Of Concept

Karanbir Chahal 2.4k Dec 09, 2022
Convert-Decimal-to-Binary-Octal-and-Hexadecimal

Convert-Decimal-to-Binary-Octal-and-Hexadecimal We have a number in a decimal number, and we have to convert it into a binary, octal, and hexadecimal

Maanyu M 2 Oct 08, 2021
TrackGen - The simplest tropical cyclone track map generator

TrackGen - The simplest tropical cyclone track map generator Usage Each line is a point to be plotted on the map Each field gives information about th

TrackGen 6 Jul 20, 2022
WhyNotWin11 - Detection Script to help identify why your PC isn't Windows 11 Release Ready

WhyNotWin11 - Detection Script to help identify why your PC isn't Windows 11 Release Ready

Robert C. Maehl 5.9k Dec 31, 2022
Demo of a WAM Prolog implementation in Python

Prol: WAM demo This is a simplified Warren Abstract Machine (WAM) implementation for Prolog, that showcases the main instructions, compiling, register

Bruno Kim Medeiros Cesar 62 Dec 26, 2022
Pequenos programas variados que estou praticando e implementando, leia o Read.me!

my-small-programs Pequenos programas variados que estou praticando e implementando! Arquivo: automacao Automacao de processos de rotina com código Pyt

Léia Rafaela 43 Nov 22, 2022
万能通用对象池,可以池化任意自定义类型的对象。

pip install universal_object_pool 此包能够将一切任意类型的python对象池化,是万能池,适用范围远大于单一用途的mysql连接池 http连接池等。 框架使用对象池包,自带实现了4个对象池。可以直接开箱用这四个对象池,也可以作为例子学习对象池用法。

12 Dec 15, 2022
CPython extension implementing Shared Transactional Memory with native-looking interface

CPython extension implementing Shared Transactional Memory with native-looking interface

21 Jul 22, 2022
Python Control Systems Library

The Python Control Systems Library is a Python module that implements basic operations for analysis and design of feedback control systems.

Control Systems Library for Python 1.3k Jan 06, 2023
A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

BoMb-OT Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via

Khai Ba Nguyen 18 Nov 14, 2022
Adjust the white point, gamma or make your XDR display darker without losing HDR peak luminance or the ability to adjust display brightness

XDR Tuner Adjust the white point, gamma or make your XDR display darker without losing HDR peak luminance or the ability to adjust display brightness

François Simond 16 Dec 28, 2022
Project aims to map out common user behavior on the computer

User-Behavior-Mapping-Tool Project aims to map out common user behavior on the computer. Most of the code is based on the research by kacos2000 found

trustedsec 136 Dec 23, 2022
Ultimate Microsoft Edge Uninstaller!

Ultimate Microsoft Edge Uninstaller

1 Feb 08, 2022
Addon for Blender 2.8+ that automatically creates NLA tracks for all animations. Useful for GLTF export.

PushDownAll An addon for Blender 2.8+ that runs Push Down on all animations, creating NLA tracks for each. This is useful if you have an object with m

Cory Petkovsek 16 Oct 06, 2022
propuestas electorales de los candidatos a constituyentes, Chile 2021

textos-constituyentes propuestas electorales de los candidatos a constituyentes, Chile 2021 Programas descargados desde https://elecciones2021.servel.

Sergio Lucero 6 Nov 19, 2021
A curses based mpd client with basic functionality and album art.

Miniplayer A curses based mpd client with basic functionality and album art. After installation, the player can be opened from the terminal with minip

Tristan Ferrua 102 Dec 24, 2022
Minos-python - A framework which helps you create reactive microservices in Python

minos-python Summary [TODO] Packages minos-microservice-aggregate minos-microser

Minos Framework 380 Jan 04, 2023
1 May 12, 2022
Programmatic startup/shutdown of ASGI apps.

asgi-lifespan Programmatically send startup/shutdown lifespan events into ASGI applications. When used in combination with an ASGI-capable HTTP client

Florimond Manca 129 Dec 27, 2022