Distributional Sliced-Wasserstein distance code

Distributional Sliced-Wasserstein distance

This is a PyTorch implementation of the paper "Distributional Sliced-Wasserstein and Applications to Generative Modeling". The work was done during the residency at VinAI Research, Hanoi, Vietnam.

Requirements

  • Python 3.6
  • PyTorch 1.3
  • torchvision
  • numpy
  • tqdm

Train on MNIST and FMNIST

python mnist.py \
    --datadir='./' \
    --outdir='./result' \
    --batch-size=512 \
    --seed=16 \
    --p=2 \
    --lr=0.0005 \
    --dataset='MNIST' \
    --model-type='DSWD' \
    --latent-size=32

model-type is one of (SWD|MSWD|DSWD|GSWD|DGSWD|JSWD|JMSWD|JDSWD|JGSWD|JDGSWD|MGSWNN|JMGSWNN|MGSWD|JMGSWD)

Option for sliced distances (number of projections used to approximate the distance):

--num-projection=1000
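To illustrate what the number of projections controls, here is a minimal standard-library sketch of a Monte Carlo sliced Wasserstein estimate. It is not the repository's PyTorch code: the function names are hypothetical, and it assumes both point sets have the same size so that the one-dimensional optimal coupling is sorted-to-sorted matching.

```python
import math
import random


def random_direction(dim, rng):
    """Draw a direction uniformly from the unit sphere (normalized Gaussian)."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(c * c for c in v))
    return [c / norm for c in v]


def project(points, theta):
    """Inner product of every point with the direction theta."""
    return [sum(a * b for a, b in zip(x, theta)) for x in points]


def sliced_wasserstein(xs, ys, num_projection=1000, p=2, seed=0):
    """Monte Carlo estimate of the sliced p-Wasserstein distance.

    Assumption: len(xs) == len(ys), so each 1-D transport cost is the
    mean of |a - b|^p over the sorted projections.
    """
    if len(xs) != len(ys):
        raise ValueError("this sketch needs equally sized point sets")
    rng = random.Random(seed)
    dim = len(xs[0])
    total = 0.0
    for _ in range(num_projection):
        theta = random_direction(dim, rng)
        px = sorted(project(xs, theta))
        py = sorted(project(ys, theta))
        total += sum(abs(a - b) ** p for a, b in zip(px, py)) / len(px)
    # Average over directions, then take the p-th root.
    return (total / num_projection) ** (1.0 / p)
```

More projections reduce the variance of the estimate at a cost that grows linearly in num_projection. For a point set and a copy shifted by 3 along one axis in two dimensions, the estimate with p=2 should land near 3/sqrt(2), since the squared first coordinate of a uniform direction has mean 1/2.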

Option for Max Sliced-Wasserstein distance and Distributional distances (number of gradient steps used to find the max slice or the optimal push-forward function):

--niter=10
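The role of the gradient-step count can be sketched for the max-sliced case as follows. This is an illustration under stated assumptions, not the repository's optimizer: it handles p=2 only, uses a closed-form gradient with the sorted pairing held fixed at each step, and takes plain projected gradient-ascent steps on the unit sphere. The parameters lr and theta0 are hypothetical.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def normalize(v):
    norm = math.sqrt(dot(v, v))
    return [c / norm for c in v]


def projected_cost_and_grad(xs, ys, theta):
    """Squared 2-Wasserstein cost of the 1-D projections and its gradient in theta."""
    n, dim = len(xs), len(theta)
    # Sorting the projections gives the optimal 1-D pairing.
    order_x = sorted(range(n), key=lambda i: dot(xs[i], theta))
    order_y = sorted(range(n), key=lambda j: dot(ys[j], theta))
    cost, grad = 0.0, [0.0] * dim
    for i, j in zip(order_x, order_y):
        diff = [a - b for a, b in zip(xs[i], ys[j])]
        proj = dot(diff, theta)
        cost += proj * proj / n
        for k in range(dim):
            grad[k] += 2.0 * proj * diff[k] / n
    return cost, grad


def max_sliced_wasserstein(xs, ys, niter=10, lr=0.1, theta0=None):
    """Approximate the max-sliced 2-Wasserstein distance with niter ascent steps."""
    theta = normalize(theta0 if theta0 is not None else [1.0] * len(xs[0]))
    for _ in range(niter):
        _, grad = projected_cost_and_grad(xs, ys, theta)
        # Ascent step followed by projection back onto the unit sphere.
        theta = normalize([t + lr * g for t, g in zip(theta, grad)])
    cost, _ = projected_cost_and_grad(xs, ys, theta)
    return math.sqrt(cost), theta
```

A larger niter lets the direction move closer to the slice that best separates the two samples. Each step costs one sort per sample, so the value trades accuracy of the maximization against training time.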

Option for Distributional Sliced-Wasserstein Distance and Distributional Generalized Sliced-Wasserstein Distance (regularization strength):

--lam=10
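As we read the paper, the regularization keeps the learned distribution over directions from collapsing onto a single slice by penalizing the mean absolute cosine similarity between pairs of directions. The sketch below shows one way such a penalized objective could be assembled. The function names are hypothetical, and the exact form used in the repository (for example, whether self-pairs are included) is an assumption here.

```python
import math


def mean_abs_cosine(directions):
    """Average |cos(theta, theta')| over all ordered pairs, self-pairs included."""
    norms = [math.sqrt(sum(c * c for c in d)) for d in directions]
    total = 0.0
    for d1, n1 in zip(directions, norms):
        for d2, n2 in zip(directions, norms):
            inner = sum(a * b for a, b in zip(d1, d2))
            total += abs(inner) / (n1 * n2)
    return total / (len(directions) ** 2)


def dsw_objective(projected_costs, directions, lam=10.0, p=2):
    """Quantity to maximize over the directions (a sketch).

    projected_costs[i] is assumed to hold the 1-D Wasserstein cost W_p^p
    of the two samples projected onto directions[i].
    """
    mean_cost = sum(projected_costs) / len(projected_costs)
    return mean_cost ** (1.0 / p) - lam * mean_abs_cosine(directions)
```

With a large lam the penalty dominates and the directions are pushed to spread out, which moves the behavior toward the plain sliced distance. With a small lam the directions are free to concentrate near the most discriminative slice.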

Options for Generalized Sliced-Wasserstein Distance (using the circular defining function for the Generalized Radon Transform):

--r=1000 \
--g='circular'
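For reference, the circular defining function from the generalized sliced-Wasserstein literature replaces the linear projection of x onto theta with the distance from x to the point r * theta. A minimal sketch, with hypothetical function names and no claim to mirror the repository's implementation:

```python
import math


def circular(x, theta, r=1000.0):
    """Circular defining function: Euclidean distance from x to r * theta."""
    return math.sqrt(sum((a - r * t) ** 2 for a, t in zip(x, theta)))


def generalized_projection(points, theta, r=1000.0):
    """One-dimensional 'slice' of a point set under the circular function."""
    return [circular(x, theta, r) for x in points]
```

The resulting one-dimensional values can be sorted and compared exactly as the linear projections are in the ordinary sliced distance.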

Train on CELEBA, CIFAR10, and LSUN

python main.py \
    --datadir='./' \
    --outdir='./result' \
    --batch-size=512 \
    --seed=16 \
    --p=2 \
    --lr=0.0005 \
    --model-type='DSWD' \
    --dataset='CELEBA' \
    --latent-size=100

model-type is one of (SWD|MSWD|DSWD|GSWD|DGSWD|CRAMER)

Option for sliced distances (number of projections used to approximate the distance):

--num-projection=1000

Option for Max Sliced-Wasserstein distance and Distributional distances (number of gradient steps used to find the max slice or the optimal push-forward function):

--niter=1

Option for Distributional Sliced-Wasserstein Distance and Distributional Generalized Sliced-Wasserstein Distance (regularization strength):

--lam=1

Options for Generalized Sliced-Wasserstein Distance (using the circular defining function for the Generalized Radon Transform):

--r=1000 \
--g='circular'

Some generated images

MNIST generated images

[Image: generated MNIST samples]

CELEBA generated images

[Image: generated CELEBA samples]

LSUN generated images

[Image: generated LSUN samples]
