A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

BoMb-OT

Python3 implementation of the papers "On Transportation of Mini-batches: A Hierarchical Approach" and "Improving Mini-batch Optimal Transport via Partial Transportation".

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation was written by Khai Nguyen and Dang Nguyen. The README is still being updated.

Requirements

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

Scalable implementations of the batch of mini-batches (BoMb) scheme and the conventional averaging scheme for four mini-batch transportation types, optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), and sliced optimal transport, applied to:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow
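
To make the difference between the two schemes concrete, here is a minimal toy sketch in plain Python. It is not the code from this repository: the function names are hypothetical, the data are 1D so that the mini-batch OT cost can be computed by sorting, and the transport between the k mini-batches is brute-forced over permutations (valid here because both sides carry uniform weights over the same number of mini-batches, so an optimal plan can be taken to be a permutation).

```python
from itertools import permutations


def ot_1d(xs, ys):
    """Exact OT cost between two equal-size, uniformly weighted 1D samples
    with squared-distance ground cost: sort both and match in order."""
    assert len(xs) == len(ys)
    return sum((x - y) ** 2 for x, y in zip(sorted(xs), sorted(ys))) / len(xs)


def averaging_loss(src_batches, tgt_batches):
    """Conventional scheme: average the OT cost over the k paired mini-batches."""
    k = len(src_batches)
    return sum(ot_1d(x, y) for x, y in zip(src_batches, tgt_batches)) / k


def bomb_loss(src_batches, tgt_batches):
    """Batch of mini-batches: transport the k source mini-batches onto the
    k target mini-batches, where moving mini-batch i to mini-batch j costs
    their OT cost. Brute force over permutations, so only for small k."""
    k = len(src_batches)
    cost = [[ot_1d(x, y) for y in tgt_batches] for x in src_batches]
    best = min(
        sum(cost[i][perm[i]] for i in range(k))
        for perm in permutations(range(k))
    )
    return best / k


# Two source and two target mini-batches that are paired "the wrong way".
src = [[0.0, 1.0], [10.0, 11.0]]
tgt = [[10.0, 11.0], [0.0, 1.0]]
print(averaging_loss(src, tgt))  # 100.0
print(bomb_loss(src, tgt))       # 0.0
```

In this toy case the averaging scheme is penalized for the arbitrary pairing of mini-batches, while the batch of mini-batches scheme re-matches them. Since the identity pairing is one of the candidate permutations, the second value can never exceed the first.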

Deep Domain Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the generator and the classifier.

train_digits.py : this file is the entry-point script for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval between two consecutive test phases

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches
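
For readers unfamiliar with the role of the regularization coefficient (`--epsilon` / `--reg` above), the following is a minimal plain-Python sketch of the Sinkhorn iterations for entropic OT. It is an illustration under simplifying assumptions (tiny dense problem, fixed iteration count, no log-domain stabilization) and not the solver used by this repository; the function name `sinkhorn` and its parameters are hypothetical.

```python
import math


def sinkhorn(a, b, cost, epsilon, n_iters=200):
    """Entropic OT plan between histograms a and b for a given cost matrix.
    epsilon is the entropic regularization coefficient."""
    n, m = len(a), len(b)
    # Gibbs kernel K_ij = exp(-C_ij / epsilon)
    K = [[math.exp(-cost[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(n_iters):
        # Alternately rescale rows and columns to match the marginals.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


a = [0.5, 0.5]
b = [0.5, 0.5]
cost = [[0.0, 1.0], [1.0, 0.0]]
sharp = sinkhorn(a, b, cost, epsilon=0.1)    # close to the unregularized plan
blurry = sinkhorn(a, b, cost, epsilon=10.0)  # close to the product a x b
print(sharp[0][0], blurry[0][0])
```

A small coefficient keeps the plan close to the unregularized solution (mass stays on the diagonal in this example), while a large coefficient spreads the mass toward the independent coupling.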

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Domain Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create the datasets.

evaluate.py : this file is used to evaluate a model trained on the VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the generator and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval between two consecutive test phases

--s_dset_path : path to source dataset

--stratify_source : use stratified sampling

--t_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches
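
As an illustration of what `--stratify_source` refers to, here is a minimal sketch of stratified mini-batch sampling in plain Python. It assumes "stratified" means drawing the same number of source samples from every class, which may differ from what train.py actually does; the function name `stratified_batch` and its arguments are hypothetical.

```python
import random
from collections import defaultdict


def stratified_batch(labels, per_class, rng):
    """Return dataset indices for one mini-batch containing exactly
    `per_class` samples of every class (assumes each class has at least
    that many samples)."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    batch = []
    for label in sorted(by_class):
        # Sample without replacement inside each class.
        batch.extend(rng.sample(by_class[label], per_class))
    return batch


labels = [0] * 5 + [1] * 3 + [2] * 4
batch = stratified_batch(labels, per_class=2, rng=random.Random(0))
print([labels[i] for i in batch])  # [0, 0, 1, 1, 2, 2]
```

Balancing the classes in each source mini-batch keeps every class represented when the mini-batch is transported to the target mini-batch.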

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator for the CelebA and CIFAR10 datasets, and include helper functions that compute the losses of the corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py : this file loads saved models and generates 10,000 images for computing the FID score.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.
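
For reference, the FID score computed by fid_score.py is, in its standard definition, the Fréchet distance between two Gaussians fitted to Inception features of real and generated images. The symbols below are notation chosen for this note, not names from the code: $\mu_r, \Sigma_r$ are the feature mean and covariance of the real images, and $\mu_g, \Sigma_g$ those of the generated images.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \mathrm{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Lower values indicate that the generated images are closer to the real ones in feature space.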

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual number of training epochs is this value multiplied by k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches
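
The role of `--L` in the sliced approach can be illustrated with a short plain-Python sketch of a Monte Carlo sliced Wasserstein estimate. This is a simplified illustration, not the implementation in this repository: it assumes two point sets of equal size with uniform weights and a squared-distance cost, and the name `sliced_ot` and its parameters are hypothetical.

```python
import math
import random


def sliced_ot(xs, ys, n_projections, rng):
    """Average, over random directions, of the 1D squared-cost OT between
    the projected point sets. n_projections plays the role of --L."""
    dim = len(xs[0])
    total = 0.0
    for _ in range(n_projections):
        # Random unit direction: normalize a standard Gaussian vector.
        theta = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(t * t for t in theta))
        theta = [t / norm for t in theta]
        # Project both sets onto the direction; 1D OT is solved by sorting.
        px = sorted(sum(t * c for t, c in zip(theta, x)) for x in xs)
        py = sorted(sum(t * c for t, c in zip(theta, y)) for y in ys)
        total += sum((p - q) ** 2 for p, q in zip(px, py)) / len(px)
    return total / n_projections


rng = random.Random(0)
xs = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
ys = [(x + 2.0, y + 1.0) for x, y in xs]  # xs shifted by (2, 1)
print(sliced_ot(xs, ys, n_projections=50, rng=rng))
```

More projections reduce the variance of the estimate at a cost that grows linearly in the number of projections.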

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data
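
As a quick check of the `--epochs` rule from the terminology list, the command above requests 100 epochs at k = 1 and sets k = 2. The helper below is a hypothetical one-liner written for this note, not a function from the repository.

```python
def actual_epochs(epochs_at_k1, k):
    """Number of epochs actually run: the --epochs value multiplied by --k."""
    return epochs_at_k1 * k


print(actual_epochs(100, 2))  # 200
```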

Train on CelebA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: use K-means clustering to compress images

--palette: show color palette

--source: path to the source image

--target: path to the target image
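
To show what compressing an image with K-means (the `--cluster` option) amounts to, here is a minimal plain-Python sketch of Lloyd's algorithm on RGB pixels. It is an illustration only: the repository may rely on a library implementation, and the function name `kmeans_colors`, its arguments, and the explicit initial centroids are assumptions made for this example.

```python
def kmeans_colors(pixels, init_centroids, n_iters=10):
    """Lloyd's algorithm on RGB tuples.
    Returns (centroids, assign), where assign[i] is the cluster of pixels[i]."""
    centroids = [tuple(c) for c in init_centroids]
    k = len(centroids)
    assign = [0] * len(pixels)

    def sq_dist(p, q):
        return sum((a - b) ** 2 for a, b in zip(p, q))

    for _ in range(n_iters):
        # Assignment step: each pixel goes to its nearest centroid.
        assign = [
            min(range(k), key=lambda c: sq_dist(p, centroids[c]))
            for p in pixels
        ]
        # Update step: each centroid becomes the mean of its members.
        for c in range(k):
            members = [p for p, a in zip(pixels, assign) if a == c]
            if members:  # keep the old centroid if a cluster is empty
                centroids[c] = tuple(
                    sum(ch) / len(members) for ch in zip(*members)
                )
    return centroids, assign


pixels = [(250, 0, 0), (255, 5, 0), (0, 0, 250), (5, 0, 255)]
centroids, assign = kmeans_colors(pixels, [pixels[0], pixels[2]])
compressed = [centroids[a] for a in assign]  # each pixel replaced by its centroid
print(assign)     # [0, 0, 1, 1]
print(centroids)  # [(252.5, 2.5, 0.0), (2.5, 0.0, 252.5)]
```

After compression the image is described by a small set of representative colors, so the transport problem is solved between far fewer points than there are pixels.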

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are grateful to their authors for open-sourcing their code.
