InvTorch: memory-efficient models with invertible functions

Related tags

Deep Learninginvtorch
Overview

InvTorch: Memory-Efficient Invertible Functions

This module extends the functionality of torch.utils.checkpoint.checkpoint to work with invertible functions. So, not only the intermediate activations will be released from memory. The input tensors get deallocated and recomputed later using the inverse function only in the backward pass. This is useful in extreme situations where more compute is traded with memory. However, there are few caveats to consider which are detailed here.

Installation

InvTorch has minimal dependencies. It only requires PyTorch version 1.10.0 or later.

conda install pytorch==1.10.0 torchvision torchaudio cudatoolkit=11.3 -c pytorch
pip install invtorch

Basic Usage

The main module that we are interested in is InvertibleModule which inherits from torch.nn.Module. Subclass it to implement your own invertible code.

import torch
from torch import nn
from invtorch import InvertibleModule


class InvertibleLinear(InvertibleModule):
    def __init__(self, in_features, out_features):
        super().__init__(invertible=True, checkpoint=True)
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def function(self, inputs):
        outputs = inputs @ self.weight.T + self.bias
        requires_grad = self.do_require_grad(inputs, self.weight, self.bias)
        return outputs.requires_grad_(requires_grad)

    def inverse(self, outputs):
        return (outputs - self.bias) @ self.weight.T.pinverse()

Structure

You can immediately notice few differences to the regular PyTorch module here. There is no longer a need to define forward(). Instead, it is replaced with function(*inputs). Additionally, it is necessary to define its inverse function as inverse(*outputs). Both methods can only take one or more positional arguments and return a torch.Tensor or a tuple of outputs which can have anything including tensors.

Requires Gradient

function() must manually call .requires_grad_(True/False) on all output tensors. The forward pass is run in no_grad mode and there is no way to detect which output need gradients without tracing. It is possible to infer this from requires_grad values of the inputs and self.parameters(). The above code uses do_require_grad() which returns True if any input did require gradient.

Example

Now, this model is ready to be instantiated and used directly.

x = torch.randn(10, 3)
model = InvertibleLinear(3, 5)
print('Is invertible:', model.check_inverse(x))

y = model(x)
print('Output requires_grad:', y.requires_grad)
print('Input was freed:', x.storage().size() == 0)

y.backward(torch.randn_like(y))
print('Input was restored:', x.storage().size() != 0)

Checkpoint and Invertible Modes

InvertibleModule has two flags which control the mode of operation; checkpoint and invertible. If checkpoint was set to False, or when working in no_grad mode, or no input or parameter has requires_grad set to True, it acts exactly as a normal PyTorch module. Otherwise, the model is either invertible or an ordinary checkpoint depending on whether invertible is set to True or False, respectively. Those, flags can be changed at any time during operation without any repercussions.

Limitations

Under the hood, InvertibleModule uses invertible_checkpoint(); a low-level implementation which allows it to function. There are few considerations to keep in mind when working with invertible checkpoints and non-materialized tensors. Please, refer to the documentation in the code for more details.

Overriding forward()

Although forward() is now doing important things to ensure the validity of the results when calling invertible_checkpoint(), it can still be overridden. The main reason of doing so is to provide a more user-friendly interface; function signature and output format. For example, function() could return extra outputs that are not needed in the module outputs but are essential for correctly computing the inverse(). In such case, define forward() to wrap outputs = super().forward(*inputs) more cleanly.

TODOs

Here are few feature ideas that could be implemented to enrich the utility of this package:

  • Add more basic operations and modules
  • Add coupling and interleave -based invertible operations
  • Add more checks to help the user in debugging more features
  • Allow picking some inputs to not be freed in invertible mode
  • Context-manager to temporarily change the mode of operation
  • Implement dynamic discovery for outputs that requires_grad
  • Develop an automatic mode optimization for a network for various objectives
You might also like...
A memory-efficient implementation of DenseNets

efficient_densenet_pytorch A PyTorch =1.0 implementation of DenseNets, optimized to save GPU memory. Recent updates Now works on PyTorch 1.0! It uses

Official and maintained implementation of the paper
Official and maintained implementation of the paper "OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data" [BMVC 2021].

OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data Christoph Reich, Tim Prangemeier, Özdemir Cetin & Heinz Koeppl | Pr

Memory Efficient Attention (O(sqrt(n)) for Jax and PyTorch

Memory Efficient Attention This is unofficial implementation of Self-attention Does Not Need O(n^2) Memory for Jax and PyTorch. Implementation is almo

Implementation of Memory-Efficient Neural Networks with Multi-Level Generation, ICCV 2021
Implementation of Memory-Efficient Neural Networks with Multi-Level Generation, ICCV 2021

Memory-Efficient Multi-Level In-Situ Generation (MLG) By Jiaqi Gu, Hanqing Zhu, Chenghao Feng, Mingjie Liu, Zixuan Jiang, Ray T. Chen and David Z. Pan

Memory-efficient optimum einsum using opt_einsum planning and PyTorch kernels.

opt-einsum-torch There have been many implementations of Einstein's summation. numpy's numpy.einsum is the least efficient one as it only runs in sing

Lowest memory consumption and second shortest runtime in NTIRE 2022 challenge on Efficient Super-Resolution

FMEN Lowest memory consumption and second shortest runtime in NTIRE 2022 on Efficient Super-Resolution. Our paper: Fast and Memory-Efficient Network T

XtremeDistil framework for distilling/compressing massive multilingual neural network models to tiny and efficient models for AI at scale

XtremeDistilTransformers for Distilling Massive Multilingual Neural Networks ACL 2020 Microsoft Research [Paper] [Video] Releasing [XtremeDistilTransf

Learning recognition/segmentation models without end-to-end training. 40%-60% less GPU memory footprint. Same training time. Better performance.
Learning recognition/segmentation models without end-to-end training. 40%-60% less GPU memory footprint. Same training time. Better performance.

InfoPro-Pytorch The Information Propagation algorithm for training deep networks with local supervision. (ICLR 2021) Revisiting Locally Supervised Lea

Efficient-GlobalPointer - Pytorch Efficient GlobalPointer
Efficient-GlobalPointer - Pytorch Efficient GlobalPointer

引言 感谢苏神带来的模型,原文地址:https://spaces.ac.cn/archives/8877 如何运行 对应模型EfficientGlobalPoi

Releases(v0.5.0)
Owner
Modar M. Alfadly
Deep learning researcher interested in understanding neural networks
Modar M. Alfadly
Keras Realtime Multi-Person Pose Estimation - Keras version of Realtime Multi-Person Pose Estimation project

This repository has become incompatible with the latest and recommended version of Tensorflow 2.0 Instead of refactoring this code painfully, I create

M Faber 769 Dec 08, 2022
Deep Learning Datasets Maker is a QGIS plugin to make datasets creation easier for raster and vector data.

Deep Learning Dataset Maker Deep Learning Datasets Maker is a QGIS plugin to make datasets creation easier for raster and vector data. How to use Down

deepbands 25 Dec 15, 2022
Pytorch implementation of 'Fingerprint Presentation Attack Detector Using Global-Local Model'

RTK-PAD This is an official pytorch implementation of 'Fingerprint Presentation Attack Detector Using Global-Local Model', which is accepted by IEEE T

6 Aug 01, 2022
HINet: Half Instance Normalization Network for Image Restoration

HINet: Half Instance Normalization Network for Image Restoration Liangyu Chen, Xin Lu, Jie Zhang, Xiaojie Chu, Chengpeng Chen Paper: https://arxiv.org

303 Dec 31, 2022
A fast, dataset-agnostic, deep visual search engine for digital art history

imgs.ai imgs.ai is a fast, dataset-agnostic, deep visual search engine for digital art history based on neural network embeddings. It utilizes modern

Fabian Offert 5 Dec 14, 2022
This program was designed to detect whether someone is wearing a facemask through a live video stream.

This program was designed to detect whether someone is wearing a facemask through a live video stream. A custom lightweight CNN trained with TensorFlow on a public dataset provided by Kaggle is used

0 Apr 02, 2022
DEMix Layers for Modular Language Modeling

DEMix This repository contains modeling utilities for "DEMix Layers: Disentangling Domains for Modular Language Modeling" (Gururangan et. al, 2021). T

Suchin 43 Nov 11, 2022
GPU-accelerated PyTorch implementation of Zero-shot User Intent Detection via Capsule Neural Networks

GPU-accelerated PyTorch implementation of Zero-shot User Intent Detection via Capsule Neural Networks This repository implements a capsule model Inten

Joel Huang 15 Dec 24, 2022
Read number plates with https://platerecognizer.com/

HASS-plate-recognizer Read vehicle license plates with https://platerecognizer.com/ which offers free processing of 2500 images per month. You will ne

Robin 69 Dec 30, 2022
Source code for EquiDock: Independent SE(3)-Equivariant Models for End-to-End Rigid Protein Docking (ICLR 2022)

Source code for EquiDock: Independent SE(3)-Equivariant Models for End-to-End Rigid Protein Docking (ICLR 2022) Please cite "Independent SE(3)-Equivar

Octavian Ganea 154 Jan 02, 2023
The official implementation of the Hybrid Self-Attention NEAT algorithm

PUREPLES - Pure Python Library for ES-HyperNEAT About This is a library of evolutionary algorithms with a focus on neuroevolution, implemented in pure

Adrian Westh 91 Dec 12, 2022
A python implementation of Yolov5 to detect fire or smoke in the wild in Jetson Xavier nx and Jetson nano

yolov5-fire-smoke-detect-python A python implementation of Yolov5 to detect fire or smoke in the wild in Jetson Xavier nx and Jetson nano You can see

20 Dec 15, 2022
PyTorch Implementation of Meta-StyleSpeech : Multi-Speaker Adaptive Text-to-Speech Generation

StyleSpeech - PyTorch Implementation PyTorch Implementation of Meta-StyleSpeech : Multi-Speaker Adaptive Text-to-Speech Generation. Status (2021.06.13

Keon Lee 140 Dec 21, 2022
2021:"Bridging Global Context Interactions for High-Fidelity Image Completion"

TFill arXiv | Project This repository implements the training, testing and editing tools for "Bridging Global Context Interactions for High-Fidelity I

Chuanxia Zheng 111 Jan 08, 2023
Pytorch for Segmentation

Pytorch for Semantic Segmentation This repo has been deprecated currently and I will not maintain it. Meanwhile, I strongly recommend you can refer to

ycszen 411 Nov 22, 2022
An algorithm study of the 6th iOS 10 set of Boost Camp Web Mobile

알고리즘 스터디 🔥 부스트캠프 웹모바일 6기 iOS 10조의 알고리즘 스터디 입니다. 개인적인 사정 등으로 S034, S055만 참가하였습니다. 스터디 목적 상진: 코테 합격 + 부캠끝나고 아침에 일어나기 위해 필요한 사이클 기완: 꾸준하게 자리에 앉아 공부하기 +

2 Jan 11, 2022
SberSwap Video Swap base on deep learning

SberSwap Video Swap base on deep learning

Sber AI 431 Jan 03, 2023
LSUN Dataset Documentation and Demo Code

LSUN Please check LSUN webpage for more information about the dataset. Data Release All the images in one category are stored in one lmdb database fil

Fisher Yu 426 Jan 02, 2023
EgoNN: Egocentric Neural Network for Point Cloud Based 6DoF Relocalization at the City Scale

EgonNN: Egocentric Neural Network for Point Cloud Based 6DoF Relocalization at the City Scale Paper: EgoNN: Egocentric Neural Network for Point Cloud

19 Sep 20, 2022
Unbalanced Feature Transport for Exemplar-based Image Translation (CVPR 2021)

UNITE and UNITE+ Unbalanced Feature Transport for Exemplar-based Image Translation (CVPR 2021) Unbalanced Intrinsic Feature Transport for Exemplar-bas

Fangneng Zhan 183 Nov 09, 2022