High-fidelity performance metrics for generative models in PyTorch

Overview

High-fidelity performance metrics for generative models in PyTorch

Documentation Status TestStatus PyPiVersion PyPiDownloads Twitter Follow

This repository provides precise, efficient, and extensible implementations of the popular metrics for generative model evaluation, including:

  • Inception Score (ISC)
  • Fréchet Inception Distance (FID)
  • Kernel Inception Distance (KID)
  • Perceptual Path Length (PPL)

Precision: Unlike many other reimplementations, the values produced by torch-fidelity match reference implementations up to machine precision. This allows using torch-fidelity for reporting metrics in papers instead of scattered and slow reference implementations. Read more about precision

Efficiency: Feature sharing between different metrics saves recomputation time, and an additional caching level avoids recomputing features and statistics whenever possible. High efficiency allows using torch-fidelity in the training loop, for example at the end of every epoch. Read more about efficiency

Extensibility: Going beyond 2D image generation is easy due to high modularity and abstraction of the metrics from input data, models, and feature extractors. For example, one can swap out InceptionV3 feature extractor for a one accepting 3D scan volumes, such as used in MRI. Read more about extensibility

TLDR; fast and reliable GAN evaluation in PyTorch

Installation

pip install torch-fidelity

See also: Installing the latest GitHub code

Usage Examples with Command Line

Below are three examples of using torch-fidelity to evaluate metrics from the command line. See more examples in the documentation.

Simple

Inception Score of CIFAR-10 training split:

> fidelity --gpu 0 --isc --input1 cifar10-train

inception_score_mean: 11.23678
inception_score_std: 0.09514061

Medium

Inception Score of a directory of images stored in ~/images/:

> fidelity --gpu 0 --isc --input1 ~/images/

Pro

Efficient computation of ISC and PPL for input1, and FID and KID between a generative model stored in ~/generator.onnx and CIFAR-10 training split:

> fidelity \
  --gpu 0 \
  --isc \
  --fid \
  --kid \
  --ppl \
  --input1 ~/generator.onnx \ 
  --input1-model-z-type normal \
  --input1-model-z-size 128 \
  --input1-model-num-samples 50000 \ 
  --input2 cifar10-train 

See also: Other usage examples

Quick Start with Python API

When it comes to tracking the performance of generative models as they train, evaluating metrics after every epoch becomes prohibitively expensive due to long computation times. torch_fidelity tackles this problem by making full use of caching to avoid recomputing common features and per-metric statistics whenever possible. Computing all metrics for 50000 32x32 generated images and cifar10-train takes only 2 min 26 seconds on NVIDIA P100 GPU, compared to >10 min if using original codebases. Thus, computing metrics 20 times over the whole training cycle makes overall training time just one hour longer.

In the following example, assume unconditional image generation setting with CIFAR-10, and the generative model generator, which takes a 128-dimensional standard normal noise vector.

First, import the module:

import torch_fidelity

Add the following lines at the end of epoch evaluation:

wrapped_generator = torch_fidelity.GenerativeModelModuleWrapper(generator, 128, 'normal', 0)

metrics_dict = torch_fidelity.calculate_metrics(
    input1=wrapped_generator, 
    input2='cifar10-train', 
    cuda=True, 
    isc=True, 
    fid=True, 
    kid=True, 
    verbose=False,
)

The resulting dictionary with computed metrics can logged directly to tensorboard, wandb, or console:

print(metrics_dict)

Output:

{
    'inception_score_mean': 11.23678, 
    'inception_score_std': 0.09514061, 
    'frechet_inception_distance': 18.12198,
    'kernel_inception_distance_mean': 0.01369556, 
    'kernel_inception_distance_std': 0.001310059
}

See also: Full API reference

Example of Integration with the Training Loop

Refer to sngan_cifar10.py for a complete training example.

Evolution of fixed generator latents in the example:

Evolution of fixed generator latents

A generator checkpoint resulting from training the example can be downloaded here.

Citation

Citation is recommended to reinforce the evaluation protocol in works relying on torch-fidelity. To ensure reproducibility when citing this repository, use the following BibTeX:

@misc{obukhov2020torchfidelity,
  author={Anton Obukhov and Maximilian Seitzer and Po-Wei Wu and Semen Zhydenko and Jonathan Kyl and Elvis Yu-Jing Lin},
  year=2020,
  title={High-fidelity performance metrics for generative models in PyTorch},
  url={https://github.com/toshas/torch-fidelity},
  publisher={Zenodo},
  version={v0.3.0},
  doi={10.5281/zenodo.4957738},
  note={Version: 0.3.0, DOI: 10.5281/zenodo.4957738}
}
Owner
Vikram Voleti
PhD student at Mila, University of Montreal
Vikram Voleti
A code copied from google-research which named motion-imitation was rewrited with PyTorch

motor-system Introduction A code copied from google-research which named motion-imitation was rewrited with PyTorch. More details can get from this pr

NewEra 6 Jan 08, 2022
On the Variance of the Adaptive Learning Rate and Beyond

RAdam On the Variance of the Adaptive Learning Rate and Beyond We are in an early-release beta. Expect some adventures and rough edges. Table of Conte

Liyuan Liu 2.5k Dec 27, 2022
Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.

PyTorch Implementation of Differentiable ODE Solvers This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpr

Ricky Chen 4.4k Jan 04, 2023
PyTorch extensions for fast R&D prototyping and Kaggle farming

Pytorch-toolbelt A pytorch-toolbelt is a Python library with a set of bells and whistles for PyTorch for fast R&D prototyping and Kaggle farming: What

Eugene Khvedchenya 1.3k Jan 05, 2023
Pretrained ConvNets for pytorch: NASNet, ResNeXt, ResNet, InceptionV4, InceptionResnetV2, Xception, DPN, etc.

Pretrained models for Pytorch (Work in progress) The goal of this repo is: to help to reproduce research papers results (transfer learning setups for

Remi 8.7k Dec 31, 2022
Over9000 optimizer

Optimizers and tests Every result is avg of 20 runs. Dataset LR Schedule Imagenette size 128, 5 epoch Imagewoof size 128, 5 epoch Adam - baseline OneC

Mikhail Grankin 405 Nov 27, 2022
A collection of extensions and data-loaders for few-shot learning & meta-learning in PyTorch

Torchmeta A collection of extensions and data-loaders for few-shot learning & meta-learning in PyTorch. Torchmeta contains popular meta-learning bench

Tristan Deleu 1.7k Jan 06, 2023
GPU-accelerated PyTorch implementation of Zero-shot User Intent Detection via Capsule Neural Networks

GPU-accelerated PyTorch implementation of Zero-shot User Intent Detection via Capsule Neural Networks This repository implements a capsule model Inten

Joel Huang 15 Dec 24, 2022
A PyTorch implementation of L-BFGS.

PyTorch-LBFGS: A PyTorch Implementation of L-BFGS Authors: Hao-Jun Michael Shi (Northwestern University) and Dheevatsa Mudigere (Facebook) What is it?

Hao-Jun Michael Shi 478 Dec 27, 2022
PyTorch implementations of normalizing flow and its variants.

PyTorch implementations of normalizing flow and its variants.

Tatsuya Yatagawa 55 Dec 01, 2022
PyTorch toolkit for biomedical imaging

farabio is a minimal PyTorch toolkit for out-of-the-box deep learning support in biomedical imaging. For further information, see Wikis and Docs.

San Askaruly 47 Dec 28, 2022
An optimizer that trains as fast as Adam and as good as SGD.

AdaBound An optimizer that trains as fast as Adam and as good as SGD, for developing state-of-the-art deep learning models on a wide variety of popula

LoLo 2.9k Dec 27, 2022
higher is a pytorch library allowing users to obtain higher order gradients over losses spanning training loops rather than individual training steps.

higher is a library providing support for higher-order optimization, e.g. through unrolled first-order optimization loops, of "meta" aspects of these

Facebook Research 1.5k Jan 03, 2023
Differentiable SDE solvers with GPU support and efficient sensitivity analysis.

PyTorch Implementation of Differentiable SDE Solvers This library provides stochastic differential equation (SDE) solvers with GPU support and efficie

Google Research 1.2k Jan 04, 2023
Reformer, the efficient Transformer, in Pytorch

Reformer, the Efficient Transformer, in Pytorch This is a Pytorch implementation of Reformer https://openreview.net/pdf?id=rkgNKkHtvB It includes LSH

Phil Wang 1.8k Jan 06, 2023
The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.

News March 3: v0.9.97 has various bug fixes and improvements: Bug fixes for NTXentLoss Efficiency improvement for AccuracyCalculator, by using torch i

Kevin Musgrave 5k Jan 02, 2023
S3-plugin is a high performance PyTorch dataset library to efficiently access datasets stored in S3 buckets.

S3-plugin is a high performance PyTorch dataset library to efficiently access datasets stored in S3 buckets.

Amazon Web Services 138 Jan 03, 2023
A PyTorch implementation of Learning to learn by gradient descent by gradient descent

Intro PyTorch implementation of Learning to learn by gradient descent by gradient descent. Run python main.py TODO Initial implementation Toy data LST

Ilya Kostrikov 300 Dec 11, 2022
Fast and Easy-to-use Distributed Graph Learning for PyTorch Geometric

Fast and Easy-to-use Distributed Graph Learning for PyTorch Geometric

Quiver Team 221 Dec 22, 2022
Pytorch bindings for Fortran

Pytorch bindings for Fortran

Dmitry Alexeev 46 Dec 29, 2022