A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

BoMb-OT

Python3 implementation of the papers "On Transportation of Mini-batches: A Hierarchical Approach" and "Improving Mini-batch Optimal Transport via Partial Transportation".

Please cite our papers whenever this repository is used to help produce published results or is incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation was written by Khai Nguyen and Dang Nguyen. The README is still being updated.

Requirements

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

Scalable implementations of two mini-batch schemes, the batch of mini-batches (BoMb) scheme and the conventional averaging scheme, for several transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), and sliced optimal transport. They are applied to:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow
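
The difference between the two schemes can be illustrated on toy data. The sketch below is not the repository's code. It assumes 1-D samples with uniform weights (so the OT cost between two equal-size mini-batches reduces to matching sorted points), it reads the conventional averaging scheme as weighting every pair of mini-batches equally, and it solves the outer transport problem between mini-batches by brute force over permutations, which is only practical for small k. All function names are hypothetical.

```python
import itertools
import numpy as np

def ot_cost_1d(x, y):
    """Exact OT cost (squared Euclidean ground cost) between two equal-size
    1-D samples with uniform weights: match the sorted points."""
    return float(np.mean((np.sort(x) - np.sort(y)) ** 2))

def minibatch_costs(xs, ys):
    """k x k matrix; entry (i, j) is the OT cost between source mini-batch i
    and target mini-batch j."""
    return np.array([[ot_cost_1d(x, y) for y in ys] for x in xs])

def averaged_m_ot(cost):
    """Averaging scheme (assumed reading): every pair gets weight 1 / k^2."""
    return float(cost.mean())

def bomb_ot(cost):
    """Batch of mini-batches: solve a second OT problem between the k source
    and k target mini-batches (uniform weights), using `cost` as ground cost.
    With uniform weights an optimal plan is a permutation, so brute force
    over permutations is exact for small k."""
    k = cost.shape[0]
    best = min(
        sum(cost[i, p[i]] for i in range(k))
        for p in itertools.permutations(range(k))
    )
    return float(best / k)

rng = np.random.default_rng(0)
k, m = 4, 8  # number of mini-batches and mini-batch size
xs = [rng.normal(0.0, 1.0, size=m) for _ in range(k)]
ys = [rng.normal(2.0, 1.0, size=m) for _ in range(k)]

C = minibatch_costs(xs, ys)
print("averaged m-OT:", averaged_m_ot(C))
print("BoMb-OT      :", bomb_ot(C))
```

Under these assumptions the BoMb-OT value can never exceed the averaged value, because the uniform weighting over all pairs is itself the average of all permutation plans.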

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the generator and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval between two consecutive test phases

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches
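
To make the role of the Sinkhorn regularization coefficient (--epsilon, --reg, --breg) concrete, here is a minimal NumPy sketch of the textbook Sinkhorn iteration. It is not taken from the repository, which lists POT and geomloss as dependencies; the function name, the toy data, and the iteration count are assumptions chosen for illustration.

```python
import numpy as np

def sinkhorn_plan(a, b, cost, epsilon, n_iters=500):
    """Entropic OT plan between histograms a and b via Sinkhorn scaling.
    epsilon is the regularization coefficient (hypothetical helper)."""
    K = np.exp(-cost / epsilon)      # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = a / (K @ v)              # rescale rows towards marginal a
        v = b / (K.T @ u)            # rescale columns towards marginal b
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
m = 5                                # mini-batch size
x = rng.normal(size=(m, 2))
y = rng.normal(size=(m, 2)) + 1.0
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
a = np.full(m, 1.0 / m)
b = np.full(m, 1.0 / m)

sharp = sinkhorn_plan(a, b, cost, epsilon=1.0)
blurry = sinkhorn_plan(a, b, cost, epsilon=100.0)
print("largest plan entry, epsilon=1  :", sharp.max())
print("largest plan entry, epsilon=100:", blurry.max())
```

A small coefficient tends to give a plan that concentrates mass on a few pairs, while a large one pushes the plan towards the independent coupling of the two marginals.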

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create the dataset.

evaluate.py : this file is used to evaluate a model trained on the VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the generator and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval between two consecutive test phases

--s_dset_path : path to source dataset

--stratify_source : use stratified sampling

--t_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches
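
The effect of the marginal penalization coefficient --tau can be sketched with the standard scaling iteration for entropic UOT with KL marginal penalties. This is an illustrative NumPy version, not the repository's implementation; the helper name and the toy data (a source mini-batch with one far-away outlier) are assumptions.

```python
import numpy as np

def unbalanced_sinkhorn_plan(a, b, cost, epsilon, tau, n_iters=1000):
    """Entropic UOT plan with KL penalties of weight tau on both marginals
    (hypothetical helper). As tau grows, the exponent tends to 1 and the
    update becomes the balanced Sinkhorn update."""
    K = np.exp(-cost / epsilon)
    power = tau / (tau + epsilon)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = (a / (K @ v)) ** power
        v = (b / (K.T @ u)) ** power
    return u[:, None] * K * v[None, :]

# 1-D toy mini-batches; the last source point is an outlier.
x = np.array([-0.2, -0.1, 0.1, 0.2, 10.0])
y = np.array([-0.3, -0.15, 0.0, 0.15, 0.3])
cost = (x[:, None] - y[None, :]) ** 2
a = np.full(5, 0.2)
b = np.full(5, 0.2)

relaxed = unbalanced_sinkhorn_plan(a, b, cost, epsilon=1.0, tau=0.5)
strict = unbalanced_sinkhorn_plan(a, b, cost, epsilon=1.0, tau=100.0)
print("mass transported from the outlier, tau=0.5:", relaxed[-1].sum())
print("mass transported from the outlier, tau=100:", strict[-1].sum())
```

With a small tau the plan is free to leave the outlier's mass almost untransported, whereas a large tau forces the marginals of the plan closer to the prescribed ones.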

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator for the CelebA and CIFAR10 datasets, and include helper functions that compute the losses of the corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py : this file loads saved models and generates 10000 images for computing the FID score.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual number of epochs is this value multiplied by k; for example, --epochs 100 with --k 2 runs for 200 epochs.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches
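
For the sliced method, --L sets the number of random projections. The following NumPy sketch shows a common Monte Carlo estimator of the squared sliced Wasserstein-2 distance between two equal-size mini-batches. It is an assumption about what such a loss typically looks like, not the code in this repository, and the function name is hypothetical.

```python
import numpy as np

def sliced_wasserstein2(x, y, n_projections, rng):
    """Monte Carlo estimate of the squared sliced Wasserstein-2 distance
    between two equal-size point clouds with uniform weights."""
    d = x.shape[1]
    # Draw n_projections random directions on the unit sphere.
    theta = rng.normal(size=(n_projections, d))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)
    # Project, then solve each 1-D OT problem by sorting.
    x_proj = np.sort(x @ theta.T, axis=0)
    y_proj = np.sort(y @ theta.T, axis=0)
    return float(np.mean((x_proj - y_proj) ** 2))

rng = np.random.default_rng(0)
m, d = 100, 32                       # mini-batch size and feature dimension
x = rng.normal(size=(m, d))
shift = np.full(d, 0.5)
y = x + shift                        # same cloud, translated

print("few projections :", sliced_wasserstein2(x, y, n_projections=10, rng=rng))
print("many projections:", sliced_wasserstein2(x, y, n_projections=1000, rng=rng))
```

More projections reduce the variance of the estimate at a cost that grows linearly in L.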

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CelebA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: K-means clustering to compress images

--palette: show color palette

--source: Path to the source image
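
How main.py performs the compression behind --cluster is not described here. As a rough illustration of the idea only, the sketch below runs a plain Lloyd-style K-means on RGB pixels and replaces each pixel by its cluster center. The function name, the number of colors, and the random test image are all hypothetical.

```python
import numpy as np

def kmeans_palette(pixels, n_colors, n_iters=20, seed=0):
    """Plain K-means on an (N, 3) array of RGB pixels (hypothetical helper).
    Returns the palette (cluster centers) and one label per pixel."""
    rng = np.random.default_rng(seed)
    start = rng.choice(len(pixels), size=n_colors, replace=False)
    centers = pixels[start].astype(float)

    def assign(centers):
        # Index of the nearest center for every pixel.
        dists = ((pixels[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        return dists.argmin(axis=1)

    for _ in range(n_iters):
        labels = assign(centers)
        for c in range(n_colors):
            members = pixels[labels == c]
            if len(members) > 0:     # keep the old center if a cluster is empty
                centers[c] = members.mean(axis=0)
    return centers, assign(centers)

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(16, 16, 3)).astype(float)
pixels = image.reshape(-1, 3)

palette, labels = kmeans_palette(pixels, n_colors=4)
compressed = palette[labels].reshape(image.shape)
print("distinct colors after compression:",
      len(np.unique(compressed.reshape(-1, 3), axis=0)))
```

Working on a small palette instead of every pixel keeps the number of points that enter each mini-batch transport problem manageable.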

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful to their authors for releasing the code as open source.
