Distributional Sliced-Wasserstein distance code

This is a PyTorch implementation of the paper "Distributional Sliced-Wasserstein and Applications to Generative Modeling". The work was done during a residency at VinAI Research, Hanoi, Vietnam.

Requirements

  • Python 3.6
  • PyTorch 1.3
  • torchvision
  • numpy
  • tqdm

Train on MNIST and FMNIST

python mnist.py \
    --datadir='./' \
    --outdir='./result' \
    --batch-size=512 \
    --seed=16 \
    --p=2 \
    --lr=0.0005 \
    --dataset='MNIST' \
    --model-type='DSWD' \
    --latent-size=32

--model-type is one of (SWD|MSWD|DSWD|GSWD|DGSWD|JSWD|JMSWD|JDSWD|JGSWD|JDGSWD|MGSWNN|JMGSWNN|MGSWD|JMGSWD)
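
As a rough illustration of how the flags in the command above fit together, here is a hypothetical argparse sketch. The flag names and example values come from the command itself; the types, the defaults, and the `build_parser` helper are assumptions and may not match what `mnist.py` actually does.

```python
import argparse

# Model types listed above for mnist.py.
MODEL_TYPES = [
    "SWD", "MSWD", "DSWD", "GSWD", "DGSWD", "JSWD", "JMSWD", "JDSWD",
    "JGSWD", "JDGSWD", "MGSWNN", "JMGSWNN", "MGSWD", "JMGSWD",
]


def build_parser():
    """Hypothetical parser mirroring the flags shown in the command (types assumed)."""
    parser = argparse.ArgumentParser(description="DSW training (sketch)")
    parser.add_argument("--datadir", default="./")
    parser.add_argument("--outdir", default="./result")
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument("--seed", type=int, default=16)
    parser.add_argument("--p", type=int, default=2)
    parser.add_argument("--lr", type=float, default=0.0005)
    parser.add_argument("--dataset", default="MNIST")
    parser.add_argument("--model-type", choices=MODEL_TYPES, default="DSWD")
    parser.add_argument("--latent-size", type=int, default=32)
    return parser


# Flags that are not passed fall back to the assumed defaults.
args = build_parser().parse_args(["--model-type=SWD", "--batch-size=256"])
print(args.model_type, args.batch_size, args.latent_size)
```

With `choices` set, an unsupported `--model-type` value is rejected at parse time rather than partway through training.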

Option for sliced distances (number of projections used to approximate the distance):

--num-projection=1000
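
To show what the number of projections controls, here is a minimal plain-Python sketch of a Monte Carlo sliced Wasserstein estimate. It is not the repository's code (which works on PyTorch tensors); the function name `sliced_wasserstein` and its arguments are assumptions made for illustration, and it assumes both samples have the same size.

```python
import math
import random


def sliced_wasserstein(xs, ys, num_projections=1000, p=2, seed=0):
    """Monte Carlo estimate of the sliced p-Wasserstein distance.

    xs, ys: equally sized lists of d-dimensional points (lists of floats).
    """
    rng = random.Random(seed)
    dim = len(xs[0])
    total = 0.0
    for _ in range(num_projections):
        # A normalized Gaussian vector is uniformly distributed on the unit sphere.
        theta = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(t * t for t in theta))
        theta = [t / norm for t in theta]
        # Project both samples onto theta and sort: in 1-D the optimal
        # transport plan pairs the i-th smallest points of each sample.
        px = sorted(sum(a * t for a, t in zip(x, theta)) for x in xs)
        py = sorted(sum(b * t for b, t in zip(y, theta)) for y in ys)
        total += sum(abs(a - b) ** p for a, b in zip(px, py)) / len(px)
    # Average over projections, then take the p-th root.
    return (total / num_projections) ** (1.0 / p)


rng = random.Random(1)
xs = [[rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)] for _ in range(64)]
ys = [[a + 3.0, b] for a, b in xs]  # xs shifted by 3 along the first axis

same = sliced_wasserstein(xs, xs, num_projections=50)
shifted = sliced_wasserstein(xs, ys, num_projections=200)
print(same)  # 0.0
print(shifted)
```

More projections reduce the variance of the estimate at a cost that grows linearly in the number of projections.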

Option for the Max Sliced-Wasserstein distance and the distributional distances (number of gradient steps used to find the max slice or the optimal push-forward function):

--niter=10
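
The role of the gradient-step count can be sketched for the max-sliced case as follows. This is a plain-Python illustration under stated assumptions: it uses projected gradient ascent with a fixed step size and p = 2, and the helper names (`normalize`, `sliced_w2_squared`, `max_sliced_w2`) and the `step` parameter are hypothetical. The repository may use a different optimizer and parameterization.

```python
import math
import random


def normalize(v):
    norm = math.sqrt(sum(c * c for c in v))
    return [c / norm for c in v]


def sliced_w2_squared(xs, ys, theta):
    """Squared 1-D W2 between the projections onto theta, plus its gradient."""
    px = [sum(a * t for a, t in zip(x, theta)) for x in xs]
    py = [sum(b * t for b, t in zip(y, theta)) for y in ys]
    # Rank-matched indices: the i-th smallest projection of xs is paired
    # with the i-th smallest projection of ys.
    ix = sorted(range(len(xs)), key=px.__getitem__)
    iy = sorted(range(len(ys)), key=py.__getitem__)
    n = len(xs)
    value, grad = 0.0, [0.0] * len(theta)
    for i, j in zip(ix, iy):
        diff = px[i] - py[j]
        value += diff * diff / n
        for k in range(len(theta)):
            # Derivative of diff^2 w.r.t. theta_k, holding the matching fixed.
            grad[k] += 2.0 * diff * (xs[i][k] - ys[j][k]) / n
    return value, grad


def max_sliced_w2(xs, ys, niter=10, step=0.1, seed=0):
    """Approximate the max-sliced W2 distance with niter ascent steps."""
    rng = random.Random(seed)
    theta = normalize([rng.gauss(0.0, 1.0) for _ in xs[0]])
    for _ in range(niter):
        _, grad = sliced_w2_squared(xs, ys, theta)
        # Ascent step, then project back onto the unit sphere.
        theta = normalize([t + step * g for t, g in zip(theta, grad)])
    value, _ = sliced_w2_squared(xs, ys, theta)
    return math.sqrt(value), theta


rng = random.Random(1)
xs = [[rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)] for _ in range(64)]
ys = [[a + 3.0, b] for a, b in xs]  # xs shifted by 3 along the first axis
distance, theta = max_sliced_w2(xs, ys, niter=10)
print(distance, theta)
```

In this toy example the two samples differ only by a shift along the first axis, so the direction found should align with that axis and the distance should approach the shift length of 3.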

Option for the Distributional Sliced-Wasserstein distance and the Distributional Generalized Sliced-Wasserstein distance (regularization strength):

--lam=10
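
One way to picture the regularization strength: assuming the regularizer penalizes the mean absolute cosine similarity between slicing directions (our reading of the paper's formulation, not verified against this code), a larger value pushes the learned directions apart. The sketch below is plain Python, and the function names are hypothetical.

```python
import math
from itertools import combinations


def mean_abs_cosine(directions):
    """Mean |cosine similarity| over all distinct pairs of directions."""
    def cosine(u, v):
        dot = sum(a * b for a, b in zip(u, v))
        norm_u = math.sqrt(sum(a * a for a in u))
        norm_v = math.sqrt(sum(b * b for b in v))
        return dot / (norm_u * norm_v)

    pairs = list(combinations(directions, 2))
    return sum(abs(cosine(u, v)) for u, v in pairs) / len(pairs)


def regularized_objective(sliced_distance, directions, lam=10.0):
    """Quantity to maximize over directions: distance minus a diversity penalty."""
    return sliced_distance - lam * mean_abs_cosine(directions)


spread = [[1.0, 0.0], [0.0, 1.0]]     # orthogonal directions: no penalty
collapsed = [[2.0, 0.0], [3.0, 0.0]]  # parallel directions: full penalty
print(regularized_objective(3.0, spread))     # 3.0
print(regularized_objective(3.0, collapsed))  # -7.0
```

With the penalty set to zero, nothing stops all directions from collapsing onto a single max slice; a large penalty keeps them spread out, closer to uniform slicing.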

Options for the Generalized Sliced-Wasserstein distance (using the circular defining function for the Generalized Radon Transform):

--r=1000 \
--g='circular'
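
For intuition, the circular defining function in the generalized sliced-Wasserstein literature replaces the linear projection ⟨x, θ⟩ with the distance from x to the point r·θ. The plain-Python sketch below assumes that definition and that `--r` plays the role of the radius; the function name `circular_projection` is hypothetical.

```python
import math


def circular_projection(x, theta, r=1000.0):
    """Distance from x to the point r * theta, where theta is a unit vector."""
    return math.sqrt(sum((a - r * t) ** 2 for a, t in zip(x, theta)))


print(circular_projection([1003.0, 4.0], [1.0, 0.0]))  # 5.0

# For a large radius the value is close to r - <x, theta>, so circular
# slicing behaves like a shifted, sign-flipped linear projection.
approx = circular_projection([1.0, 2.0], [1.0, 0.0], r=1e6)
print(approx)
```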

Train on CELEBA, CIFAR10, and LSUN

python main.py \
    --datadir='./' \
    --outdir='./result' \
    --batch-size=512 \
    --seed=16 \
    --p=2 \
    --lr=0.0005 \
    --model-type='DSWD' \
    --dataset='CELEBA' \
    --latent-size=100

--model-type is one of (SWD|MSWD|DSWD|GSWD|DGSWD|CRAMER)

Option for sliced distances (number of projections used to approximate the distance):

--num-projection=1000

Option for the Max Sliced-Wasserstein distance and the distributional distances (number of gradient steps used to find the max slice or the optimal push-forward function):

--niter=1

Option for the Distributional Sliced-Wasserstein distance and the Distributional Generalized Sliced-Wasserstein distance (regularization strength):

--lam=1

Options for the Generalized Sliced-Wasserstein distance (using the circular defining function for the Generalized Radon Transform):

--r=1000 \
--g='circular'

Some generated images

[Figure: MNIST generated images]

[Figure: CELEBA generated images]

[Figure: LSUN generated images]
