Runtime type annotations for the shape, dtype etc. of PyTorch Tensors.

Overview

torchtyping

Type annotations for a tensor's shape, dtype, names, ...

Turn this:

def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)

    return x.unsqueeze(-1) * y.unsqueeze(-2)

into this:

def batch_outer_product(x:   TensorType["batch", "x_channels"],
                        y:   TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:

    return x.unsqueeze(-1) * y.unsqueeze(-2)

with programmatic checking that the shape (dtype, ...) specification is met.

Bye-bye bugs! Say hello to enforced, clear documentation of your code.

If (like me) you find yourself littering your code with comments like # x has shape (batch, hidden_state) or statements like assert x.shape == y.shape , just to keep track of what shape everything is, then this is for you.


Installation

pip install torchtyping

Requires Python 3.7+ and PyTorch 1.7.0+.

Usage

torchtyping allows for type annotating:

  • shape: size, number of dimensions;
  • dtype (float, integer, etc.);
  • layout (dense, sparse);
  • names of dimensions as per named tensors;
  • arbitrary number of batch dimensions with ...;
  • ...plus anything else you like, as torchtyping is highly extensible.

If typeguard is (optionally) installed then at runtime the types can be checked to ensure that the tensors really are of the advertised shape, dtype, etc.

# EXAMPLE

from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # use before @typechecked

@typechecked
def func(x: TensorType["batch"],
         y: TensorType["batch"]) -> TensorType["batch"]:
    return x + y

func(rand(3), rand(3))  # works
func(rand(3), rand(1))
# TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3.

typeguard also has an import hook that can be used to automatically test an entire module, without needing to manually add @typeguard.typechecked decorators.

If you're not using typeguard then torchtyping.patch_typeguard() can be omitted altogether, and torchtyping just used for documentation purposes. If you're not already using typeguard for your regular Python programming, then strongly consider using it. It's a great way to squash bugs. Both typeguard and torchtyping also integrate with pytest, so if you're concerned about any performance penalty then they can be enabled during tests only.

API

torchtyping.TensorType[shape, dtype, layout, details]

The core of the library.

Each of shape, dtype, layout, details are optional.

  • The shape argument can be any of:
    • An int: the dimension must be of exactly this size. If it is -1 then any size is allowed.
    • A str: the size of the dimension passed at runtime will be bound to this name, and all tensors checked that the sizes are consistent.
    • A ...: An arbitrary number of dimensions of any sizes.
    • A str: int pair (technically it's a slice), combining both str and int behaviour. (Just a str on its own is equivalent to str: -1.)
    • A str: ... pair, in which case the multiple dimensions corresponding to ... will be bound to the name specified by str, and again checked for consistency between arguments.
    • None, which when used in conjunction with is_named below, indicates a dimension that must not have a name in the sense of named tensors.
    • A None: int pair, combining both None and int behaviour. (Just a None on its own is equivalent to None: -1.)
    • A typing.Any: Any size is allowed for this dimension (equivalent to -1).
    • Any tuple of the above. For example.TensorType["batch": ..., "length": 10, "channels", -1]. If you just want to specify the number of dimensions then use for example TensorType[-1, -1, -1] for a three-dimensional tensor.
  • The dtype argument can be any of:
    • torch.float32, torch.float64 etc.
    • int, bool, float, which are converted to their corresponding PyTorch types. float is specifically interpreted as torch.get_default_dtype(), which is usually float32.
  • The layout argument can be either torch.strided or torch.sparse_coo, for dense and sparse tensors respectively.
  • The details argument offers a way to pass an arbitrary number of additional flags that customise and extend torchtyping. Two flags are built-in by default. torchtyping.is_named causes the names of tensor dimensions to be checked, and torchtyping.is_float can be used to check that arbitrary floating point types are passed in. (Rather than just a specific one as with e.g. TensorType[torch.float32].) For discussion on how to customise torchtyping with your own details, see the further documentation.

Check multiple things at once by just putting them all together inside a single []. For example TensorType["batch": ..., "length", "channels", float, is_named].

torchtyping.patch_typeguard()

torchtyping integrates with typeguard to perform runtime type checking. torchtyping.patch_typeguard() should be called at the global level, and will patch typeguard to check TensorTypes.

This function is safe to run multiple times. (It does nothing after the first run).

  • If using @typeguard.typechecked, then torchtyping.patch_typeguard() should be called any time before using @typeguard.typechecked. For example you could call it at the start of each file using torchtyping.
  • If using typeguard.importhook.install_import_hook, then torchtyping.patch_typeguard() should be called any time before defining the functions you want checked. For example you could call torchtyping.patch_typeguard() just once, at the same time as the typeguard import hook. (The order of the hook and the patch doesn't matter.)
  • If you're not using typeguard then torchtyping.patch_typeguard() can be omitted altogether, and torchtyping just used for documentation purposes.
pytest --torchtyping-patch-typeguard

torchtyping offers a pytest plugin to automatically run torchtyping.patch_typeguard() before your tests. pytest will automatically discover the plugin, you just need to pass the --torchtyping-patch-typeguard flag to enable it. Packages can then be passed to typeguard as normal, either by using @typeguard.typechecked, typeguard's import hook, or the pytest flag --typeguard-packages="your_package_here".

Further documentation

See the further documentation for:

  • FAQ;
    • Including flake8 and mypy compatibility;
  • How to write custom extensions to torchtyping;
  • Resources and links to other libraries and materials on this topic;
  • More examples.
Owner
Patrick Kidger
Maths+ML PhD student at Oxford. Neural ODEs+SDEs+CDEs, time series, rough analysis. (Also ice skating, martial arts and scuba diving!)
Patrick Kidger
Research code for Arxiv paper "Camera Motion Agnostic 3D Human Pose Estimation"

GMR(Camera Motion Agnostic 3D Human Pose Estimation) This repo provides the source code of our arXiv paper: Seong Hyun Kim, Sunwon Jeong, Sungbum Park

Seong Hyun Kim 1 Feb 07, 2022
COVID-Net Open Source Initiative

The COVID-Net models provided here are intended to be used as reference models that can be built upon and enhanced as new data becomes available

Linda Wang 1.1k Dec 26, 2022
PyTorch reimplementation of Diffusion Models

PyTorch pretrained Diffusion Models A PyTorch reimplementation of Denoising Diffusion Probabilistic Models with checkpoints converted from the author'

Patrick Esser 265 Jan 01, 2023
Akshat Surolia 2 May 11, 2022
95.47% on CIFAR10 with PyTorch

Train CIFAR10 with PyTorch I'm playing with PyTorch on the CIFAR10 dataset. Prerequisites Python 3.6+ PyTorch 1.0+ Training # Start training with: py

5k Dec 30, 2022
A PyTorch implementation of "Capsule Graph Neural Network" (ICLR 2019).

CapsGNN ⠀⠀ A PyTorch implementation of Capsule Graph Neural Network (ICLR 2019). Abstract The high-quality node embeddings learned from the Graph Neur

Benedek Rozemberczki 1.2k Jan 02, 2023
Large-Scale Unsupervised Object Discovery

Large-Scale Unsupervised Object Discovery Huy V. Vo, Elena Sizikova, Cordelia Schmid, Patrick Pérez, Jean Ponce [PDF] We propose a novel ranking-based

17 Sep 19, 2022
Invertible conditional GANs for image editing

Invertible Conditional GANs This is the implementation of the IcGAN model proposed in our paper: Invertible Conditional GANs for image editing. Novemb

Guim 278 Dec 12, 2022
Medical Insurance Cost Prediction using Machine earning

Medical-Insurance-Cost-Prediction-using-Machine-learning - Here in this project, I will use regression analysis to predict medical insurance cost for people in different regions, and based on several

1 Dec 27, 2021
Spatial Action Maps for Mobile Manipulation (RSS 2020)

spatial-action-maps Update: Please see our new spatial-intention-maps repository, which extends this work to multi-agent settings. It contains many ne

Jimmy Wu 27 Nov 30, 2022
Music library streaming app written in Flask & VueJS

djtaytay This is a little toy app made to explore Vue, brush up on my Python, and make a remote music collection accessable through a web interface. I

Ryan Tasson 6 May 27, 2022
AntiFuzz: Impeding Fuzzing Audits of Binary Executables

AntiFuzz: Impeding Fuzzing Audits of Binary Executables Get the paper here: https://www.usenix.org/system/files/sec19-guler.pdf Usage: The python scri

Chair for Sys­tems Se­cu­ri­ty 88 Dec 21, 2022
BiSeNet based on pytorch

BiSeNet BiSeNet based on pytorch 0.4.1 and python 3.6 Dataset Download CamVid dataset from Google Drive or Baidu Yun(6xw4). Pretrained model Download

367 Dec 26, 2022
QT Py Media Knob using rotary encoder & neopixel ring

QTPy-Knob QT Py USB Media Knob using rotary encoder & neopixel ring The QTPy-Knob features: Media knob for volume up/down/mute with "qtpy-knob.py" Cir

Tod E. Kurt 56 Dec 30, 2022
Resources related to our paper "CLIN-X: pre-trained language models and a study on cross-task transfer for concept extraction in the clinical domain"

CLIN-X (CLIN-X-ES) & (CLIN-X-EN) This repository holds the companion code for the system reported in the paper: "CLIN-X: pre-trained language models a

Bosch Research 4 Dec 05, 2022
Realistic lighting in ursina!

Ursina Lighting Realistic lighting in ursina! If you want to have realistic lighting in ursina, import the UrsinaLighting.py in your project and use t

17 Jul 07, 2022
Pytorch code for paper "Image Compressed Sensing Using Non-local Neural Network" TMM 2021.

NL-CSNet-Pytorch Pytorch code for paper "Image Compressed Sensing Using Non-local Neural Network" TMM 2021. Note: this repo only shows the strategy of

WenxueCui 7 Nov 07, 2022
The official repository for "Intermediate Layers Matter in Momentum Contrastive Self Supervised Learning" paper.

Intermdiate layer matters - SSL The official repository for "Intermediate Layers Matter in Momentum Contrastive Self Supervised Learning" paper. Downl

Aakash Kaku 35 Sep 19, 2022
Implementation for our ICCV 2021 paper: Dual-Camera Super-Resolution with Aligned Attention Modules

DCSR: Dual Camera Super-Resolution Implementation for our ICCV 2021 oral paper: Dual-Camera Super-Resolution with Aligned Attention Modules paper | pr

Tengfei Wang 110 Dec 20, 2022
CTF challenges and write-ups for MicroCTF 2021.

MicroCTF 2021 Qualifications About This repository contains CTF challenges and official write-ups for MicroCTF 2021 Qualifications. License Distribute

Shellmates 12 Dec 27, 2022