A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.


BoMb-OT

Python 3 implementation of the papers "On Transportation of Mini-batches: A Hierarchical Approach" and "Improving Mini-batch Optimal Transport via Partial Transportation".

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation was written by Khai Nguyen and Dang Nguyen. The README is still being updated.

Requirements

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

Scalable implementations of the batch of mini-batches scheme and of the conventional averaging scheme for several mini-batch transportation types (optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), and sliced optimal transport), applied to:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow
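To make the difference between the two schemes concrete, here is a minimal numpy/SciPy sketch. It assumes a squared Euclidean ground cost, uniform weights, equal-size mini-batches, and an averaging scheme that weights all k × k pairs of mini-batches equally. The function names are hypothetical and are not the repository API. The entropic variant mirrors the idea behind the --ebomb and --breg options: Sinkhorn replaces the outer exact OT problem.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def pairwise_sq_dist(x, y):
    # Squared Euclidean distances between the rows of x (n, d) and y (n, d).
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)


def exact_ot(x, y):
    # With n points per side and uniform weights 1/n, an optimal plan can be taken
    # to be a permutation scaled by 1/n, so OT reduces to an assignment problem.
    cost = pairwise_sq_dist(x, y)
    rows, cols = linear_sum_assignment(cost)
    return cost[rows, cols].mean()


def minibatch_cost_matrix(xs, ys):
    # Entry (i, j) is the OT cost between source mini-batch i and target mini-batch j.
    return np.array([[exact_ot(x, y) for y in ys] for x in xs])


def m_ot(xs, ys):
    # Averaging scheme (assumed form): every pair of mini-batches gets weight 1/k^2.
    return minibatch_cost_matrix(xs, ys).mean()


def bomb_ot(xs, ys):
    # Batch of mini-batches: a second OT problem between the k mini-batches
    # (uniform weights 1/k), using the mini-batch OT costs as the ground cost.
    big_c = minibatch_cost_matrix(xs, ys)
    rows, cols = linear_sum_assignment(big_c)
    return big_c[rows, cols].mean()


def ebomb_ot(xs, ys, breg, n_iter=1000):
    # Entropic variant: the outer problem is solved with Sinkhorn (regularization breg).
    # Plain-domain iterations; a log-domain solver is safer when breg is small.
    big_c = minibatch_cost_matrix(xs, ys)
    k = big_c.shape[0]
    a = np.full(k, 1.0 / k)
    kernel = np.exp(-big_c / breg)
    v = np.ones(k)
    for _ in range(n_iter):
        u = a / (kernel @ v)
        v = a / (kernel.T @ u)
    plan = u[:, None] * kernel * v[None, :]
    return (plan * big_c).sum()


rng = np.random.default_rng(0)
k, m, d = 4, 8, 2
xs = [rng.normal(size=(m, d)) for _ in range(k)]
ys = [rng.normal(loc=2.0, size=(m, d)) for _ in range(k)]
print(bomb_ot(xs, ys), ebomb_ot(xs, ys, breg=1.0), m_ot(xs, ys))
```

The uniform 1/k² coupling used by the averaging scheme is one feasible outer plan. The batch of mini-batches value is therefore never larger than the averaged value, and the entropic version lies between the two.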

Deep Domain Adaptation on digit datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of deep DA.

models.py : this file contains the architectures of the generator and the classifier.

train_digits.py : entry point for training deep DA models.

utils.py : this file implements utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval between two consecutive test phases

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : use the Batch of Mini-batches scheme

--ebomb : use the entropic Batch of Mini-batches scheme

--breg : OT regularization coefficient for the entropic Batch of Mini-batches scheme
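The --epsilon and --tau flags parametrize entropic unbalanced OT. Below is a minimal numpy sketch of the generalized Sinkhorn scaling iterations, assuming KL penalties of weight tau on both marginals and entropic regularization epsilon. The function name and iteration count are hypothetical, and methods.py may compute the loss differently.

```python
import numpy as np


def sinkhorn_unbalanced(a, b, cost, epsilon, tau, n_iter=2000):
    # Entropic UOT: KL marginal penalties weighted by tau, entropy weighted by epsilon.
    # Plain-domain iterations: may underflow when epsilon is much smaller than the costs.
    kernel = np.exp(-cost / epsilon)
    fi = tau / (tau + epsilon)  # exponent < 1 relaxes the marginal constraints
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = (a / (kernel @ v)) ** fi
        v = (b / (kernel.T @ u)) ** fi
    return u[:, None] * kernel * v[None, :]  # transport plan


rng = np.random.default_rng(0)
x, y = rng.random((5, 2)), rng.random((5, 2))
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
a = b = np.full(5, 0.2)
plan = sinkhorn_unbalanced(a, b, cost, epsilon=0.5, tau=1.0)
print(plan.sum())  # total transported mass; need not equal 1 for finite tau
```

As tau grows, fi approaches 1 and the updates reduce to the balanced Sinkhorn algorithm, so the plan's marginals approach a and b.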

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Domain Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create the datasets.

evaluate.py : this file is used to evaluate models trained on the VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architectures of the generator and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval between two consecutive test phases

--s_dset_path : path to source dataset

--stratify_source : use stratified sampling on the source dataset

--t_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : use the Batch of Mini-batches scheme

--ebomb : use the entropic Batch of Mini-batches scheme

--breg : OT regularization coefficient for the entropic Batch of Mini-batches scheme
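Equation 10 is not reproduced in this README. The sketch below assumes the DeepJDOT/JUMBOT-style ground cost, where --eta1 (α) weights the squared distance between source and target embeddings and --eta2 (λ_t) weights the cross-entropy between a source label and a target prediction. The function names are hypothetical, and the balanced exact solver stands in for whichever --ot_type is selected.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def jdot_ground_cost(z_s, y_s, z_t, p_t, eta1, eta2, eps=1e-12):
    # z_s, z_t: (m, d) source / target embeddings from the generator.
    # y_s: (m,) integer source labels; p_t: (m, c) predicted target class probabilities.
    feat = ((z_s[:, None, :] - z_t[None, :, :]) ** 2).sum(-1)
    # label[i, j] = cross-entropy of target prediction j w.r.t. source label y_s[i]
    label = -np.log(p_t[:, y_s].T + eps)
    return eta1 * feat + eta2 * label


def transport_loss(cost):
    # Balanced exact OT with uniform weights 1/m: the plan is a scaled permutation.
    rows, cols = linear_sum_assignment(cost)
    return cost[rows, cols].mean()


rng = np.random.default_rng(0)
m, d, c = 6, 4, 3
z_s, z_t = rng.normal(size=(m, d)), rng.normal(size=(m, d))
y_s = rng.integers(0, c, size=m)
logits = rng.normal(size=(m, c))
p_t = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
cost = jdot_ground_cost(z_s, y_s, z_t, p_t, eta1=0.1, eta2=1.0)
print(transport_loss(cost))
```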

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative Models (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the generator architectures for the CelebA and CIFAR10 datasets, together with methods that compute the losses of the corresponding baselines.

experiments.py : this file contains functions for generating images.

fid_score.py : this file is used to compute the FID score.

gen_images.py : loads saved models and generates 10000 images for computing the FID score.

inception.py : this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : entry points for training on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual number of training epochs is this value multiplied by k (e.g., --epochs 100 with --k 2 trains for 200 epochs).

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using the sliced approach

--bomb : use the Batch of Mini-batches scheme

--ebomb : use the entropic Batch of Mini-batches scheme

--breg : OT regularization coefficient for the entropic Batch of Mini-batches scheme
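--L sets the number of random projections used by the sliced approach. Below is a hedged numpy sketch of the Monte Carlo sliced Wasserstein distance between two equal-size mini-batches. The function name and the p = 2 default are assumptions, and the generator files may implement it differently (e.g., in PyTorch).

```python
import numpy as np


def sliced_wasserstein(x, y, n_projections, p=2, rng=None):
    # x, y: (m, d) mini-batches of equal size; n_projections plays the role of --L.
    rng = np.random.default_rng() if rng is None else rng
    theta = rng.normal(size=(n_projections, x.shape[1]))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)  # random unit directions
    # In 1D, optimal transport matches sorted samples, so sort each projection.
    x_proj = np.sort(x @ theta.T, axis=0)
    y_proj = np.sort(y @ theta.T, axis=0)
    # Average |difference|^p over samples and projections, then take the p-th root.
    return (np.abs(x_proj - y_proj) ** p).mean() ** (1.0 / p)


rng = np.random.default_rng(0)
x = rng.normal(size=(64, 2))
y = x + np.array([1.0, 0.0])
print(sliced_wasserstein(x, y, n_projections=50, rng=rng))
```

Each projection only requires a sort, so the cost grows as O(L m log m) rather than the cubic cost of exact OT on a mini-batch.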

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CelebA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data
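With --method POT and --mass 0.7, only 70% of the mass of each mini-batch is transported. For uniform weights and equal-size mini-batches, exact partial OT can be sketched with a dummy-point reformulation. Each side gets n - p dummy points that are free to match real points but forbidden from matching each other, and an assignment solver then finds the optimal p pairs. This is an illustrative sketch with hypothetical names, not the repository's or POT's implementation.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def partial_ot(x, y, mass):
    # x, y: (n, d) mini-batches with uniform weights 1/n; transport `mass` in (0, 1].
    n = x.shape[0]
    p = int(round(mass * n))  # number of real-to-real pairs (e.g. 0.7 * 100 = 70)
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    n_dummy = n - p
    size = n + n_dummy
    ext = np.zeros((size, size))  # real-to-dummy pairs cost nothing
    ext[:n, :n] = cost
    # A dummy-to-dummy pair costs more than any feasible matching, so it is never used;
    # this forces exactly p real points on each side to be matched to each other.
    ext[n:, n:] = n * cost.max() + 1.0
    rows, cols = linear_sum_assignment(ext)
    keep = (rows < n) & (cols < n)
    rows, cols = rows[keep], cols[keep]
    return cost[rows, cols].sum() / n, rows, cols


x = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([[0.0], [1.0], [10.0], [20.0]])
value, rows, cols = partial_ot(x, y, mass=0.5)
print(value, rows, cols)
```

Here the two far-away target points are left untransported, which illustrates why partial transport can be less sensitive to mismatched samples within a mini-batch.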

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : mini-batch size

--T : number of steps

--cluster : use K-means clustering to compress the images

--palette : show the color palette

--source : path to the source image

--target : path to the target image
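--m and --T describe a mini-batch procedure run for T steps. One way such a color transfer can be sketched assumes a mini-batch barycentric mapping. At each step, m random source colors are matched to m random target colors with exact OT, and each source color is finally replaced by the average of the target colors it was matched to. The function is hypothetical and main.py may differ. With --cluster, src and tgt would typically be K-means color centroids rather than every pixel.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def minibatch_color_transfer(src, tgt, m, T, rng=None):
    # src: (Ns, 3), tgt: (Nt, 3) colors; requires m <= Ns and m <= Nt.
    rng = np.random.default_rng() if rng is None else rng
    acc = np.zeros_like(src, dtype=float)  # sum of matched target colors
    counts = np.zeros(len(src))  # how often each source color was matched
    for _ in range(T):
        i = rng.choice(len(src), size=m, replace=False)
        j = rng.choice(len(tgt), size=m, replace=False)
        xs, yt = src[i], tgt[j]
        cost = ((xs[:, None, :] - yt[None, :, :]) ** 2).sum(-1)
        rows, cols = linear_sum_assignment(cost)  # exact OT on the mini-batch pair
        acc[i[rows]] += yt[cols]
        counts[i[rows]] += 1
    out = src.astype(float).copy()
    seen = counts > 0  # colors never sampled keep their original value
    out[seen] = acc[seen] / counts[seen, None]
    return out


rng = np.random.default_rng(0)
src = rng.random((50, 3))
tgt = rng.random((40, 3))
out = minibatch_color_transfer(src, tgt, m=10, T=200, rng=rng)
print(out.shape)
```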

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA, and the structure of ABC is largely based on SlicedABC. We are very grateful to their authors for open-sourcing their code.

Owner
Khai Ba Nguyen
I am currently an AI Resident at VinAI Research, Vietnam.