Torch-mutable-modules - Use in-place and assignment operations on PyTorch module parameters with support for autograd

Overview

Torch Mutable Modules

Use in-place and assignment operations on PyTorch module parameters with support for autograd.

Publish to PyPI Run tests PyPI version Number of downloads from PyPI per month Python version support Code Style: Black

Why does this exist?

PyTorch does not allow in-place operations on module parameters (usually desirable):

linear_layer = torch.nn.Linear(1, 1)
linear_layer.weight.data += 69
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Valid, but will NOT store grad_fn=<AddBackward0>
linear_layer.weight += 420
# ^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

In some cases, however, it is useful to be able to modify module parameters in-place. For example, if we have a neural network (net_1) that predicts the parameter values to another neural network (net_2), we need to be able to modify the weights of net_2 in-place and backpropagate the gradients to net_1.

# create a parameter predictor network (net_1)
net_1 = torch.nn.Linear(1, 2)

# predict the weights and biases of net_2 using net_1
p_weight_and_bias = net_1(input_0).unsqueeze(2)
p_weight, p_bias = p_weight_and_bias[:, 0], p_weight_and_bias[:, 1]

# create a mutable network (net_2)
net_2 = to_mutable_module(torch.nn.Linear(1, 1))

# hot-swap the weights and biases of net_2 with the predicted values
net_2.weight = p_weight
net_2.bias = p_bias

# compute the output and backpropagate the gradients to net_1
output = net_2(input_1)
loss = criterion(output, label)
loss.backward()
optimizer.step()

This library provides a way to easily convert PyTorch modules into mutable modules with the to_mutable_module function.

Installation

You can install torch-mutable-modules from PyPI.

pip install torch-mutable-modules

To upgrade an existing installation of torch-mutable-modules, use the following command:

pip install --upgrade --no-cache-dir torch-mutable-modules

Importing

You can use wildcard imports or import specific functions directly:

# import all functions
from torch_mutable_modules import *

# ... or import the function manually
from torch_mutable_modules import to_mutable_module

Usage

To convert an existing PyTorch module into a mutable module, use the to_mutable_module function:

converted_module = to_mutable_module(
    torch.nn.Linear(1, 1)
) # type of converted_module is still torch.nn.Linear

converted_module.weight *= 0
convreted_module.weight += 69
convreted_module.weight # tensor([[69.]], grad_fn=<AddBackward0>)

You can also declare your own PyTorch module classes as mutable, and all child modules will be recursively converted into mutable modules:

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)
    
    def forward(self, x):
        return self.linear(x)

my_module = to_mutable_module(MyModule())
my_module.linear.weight *= 0
my_module.linear.weight += 69
my_module.linear.weight # tensor([[69.]], grad_fn=<AddBackward0>)

Usage with CUDA

To create a module on the GPU, simply pass a PyTorch module that is already on the GPU to the to_mutable_module function:

converted_module = to_mutable_module(
    torch.nn.Linear(1, 1).cuda()
) # converted_module is now a mutable module on the GPU

Moving a module to the GPU with .to() and .cuda() after instanciation is NOT supported. Instead, hot-swap the module parameter tensors with their CUDA counterparts.

# both of these are valid
converted_module.weight = converted_module.weight.cuda()
converted_module.bias = converted_module.bias.to("cuda")

Detailed examples

Please check out example.py to see more detailed example usages of the to_mutable_module function.

Contributing

Please feel free to submit issues or pull requests!

You might also like...
A machine learning library for spiking neural networks. Supports training with both torch and jax pipelines, and deployment to neuromorphic hardware.
A machine learning library for spiking neural networks. Supports training with both torch and jax pipelines, and deployment to neuromorphic hardware.

Rockpool Rockpool is a Python package for developing signal processing applications with spiking neural networks. Rockpool allows you to build network

Implements Stacked-RNN in numpy and torch with manual forward and backward functions

Recurrent Neural Networks Implements simple recurrent network and a stacked recurrent network in numpy and torch respectively. Both flavours implement

A torch.Tensor-like DataFrame library supporting multiple execution runtimes and Arrow as a common memory format

TorchArrow (Warning: Unstable Prototype) This is a prototype library currently under heavy development. It does not currently have stable releases, an

A complete end-to-end demonstration in which we collect training data in Unity and use that data to train a deep neural network to predict the pose of a cube. This model is then deployed in a simulated robotic pick-and-place task.
A complete end-to-end demonstration in which we collect training data in Unity and use that data to train a deep neural network to predict the pose of a cube. This model is then deployed in a simulated robotic pick-and-place task.

Object Pose Estimation Demo This tutorial will go through the steps necessary to perform pose estimation with a UR3 robotic arm in Unity. You’ll gain

Python implementation of MULTIseq barcode alignment using fuzzy string matching and GMM barcode assignment

Python implementation of MULTIseq barcode alignment using fuzzy string matching and GMM barcode assignment.

 MM1 and MMC Queue Simulation using python - Results and parameters in excel and csv files
MM1 and MMC Queue Simulation using python - Results and parameters in excel and csv files

implementation of MM1 and MMC Queue on randomly generated data and evaluate simulation results then compare with analytical results and draw a plot curve for them, simulate some integrals and compare results and run monte carlo algorithm with them

Torch-based tool for quantizing high-dimensional vectors using additive codebooks

Trainable multi-codebook quantization This repository implements a utility for use with PyTorch, and ideally GPUs, for training an efficient quantizer

Torch implementation of
Torch implementation of "Enhanced Deep Residual Networks for Single Image Super-Resolution"

NTIRE2017 Super-resolution Challenge: SNU_CVLab Introduction This is our project repository for CVPR 2017 Workshop (2nd NTIRE). We, Team SNU_CVLab, (B

Automatic number plate recognition using tech:  Yolo, OCR, Scene text detection, scene text recognation, flask, torch
Automatic number plate recognition using tech: Yolo, OCR, Scene text detection, scene text recognation, flask, torch

Automatic Number Plate Recognition Automatic Number Plate Recognition (ANPR) is the process of reading the characters on the plate with various optica

Releases(v1.1.2)
Owner
Kento Nishi
17-year-old programmer at Lynbrook High School, with strong interests in AI/Machine Learning. Open source developer and researcher at the Four Eyes Lab.
Kento Nishi
Pretrained models for Jax/Haiku; MobileNet, ResNet, VGG, Xception.

Pre-trained image classification models for Jax/Haiku Jax/Haiku Applications are deep learning models that are made available alongside pre-trained we

Alper Baris CELIK 14 Dec 20, 2022
Normalization Calibration (NorCal) for Long-Tailed Object Detection and Instance Segmentation

NorCal Normalization Calibration (NorCal) for Long-Tailed Object Detection and Instance Segmentation On Model Calibration for Long-Tailed Object Detec

Tai-Yu (Daniel) Pan 24 Dec 25, 2022
for a paper about leveraging discourse markers for training new models

TSLM-DISCOURSE-MARKERS Scope This repository contains: (1) Code to extract discourse markers from wikipedia (TSA). (1) Code to extract significant dis

International Business Machines 6 Nov 02, 2022
Graph parsing approach to structured sentiment analysis.

Fine-grained Sentiment Analysis as Dependency Graph Parsing This repository contains the code and datasets described in following paper: Fine-grained

Jeremy Barnes 36 Dec 12, 2022
An implementation for `Text2Event: Controllable Sequence-to-Structure Generation for End-to-end Event Extraction`

Text2Event An implementation for Text2Event: Controllable Sequence-to-Structure Generation for End-to-end Event Extraction Please contact Yaojie Lu (@

Roger 153 Jan 07, 2023
SymPy-powered, Wolfram|Alpha-like answer engine totally in your browser, without backend computation

SymPy Beta SymPy Beta is a fork of SymPy Gamma. The purpose of this project is to run a SymPy-powered, Wolfram|Alpha-like answer engine totally in you

Liumeo 25 Dec 21, 2022
Reverse engineer your pytorch vision models, in style

🔍 Rover Reverse engineer your CNNs, in style Rover will help you break down your CNN and visualize the features from within the model. No need to wri

Mayukh Deb 32 Sep 24, 2022
2021-MICCAI-Progressively Normalized Self-Attention Network for Video Polyp Segmentation

2021-MICCAI-Progressively Normalized Self-Attention Network for Video Polyp Segmentation Authors: Ge-Peng Ji*, Yu-Cheng Chou*, Deng-Ping Fan, Geng Che

Ge-Peng Ji (Daniel) 85 Dec 30, 2022
Numerical Methods with Python, Numpy and Matplotlib

Numerical Bric-a-Brac Collections of numerical techniques with Python and standard computational packages (Numpy, SciPy, Numba, Matplotlib ...). Diffe

Vincent Bonnet 10 Dec 20, 2021
Improving Calibration for Long-Tailed Recognition (CVPR2021)

MiSLAS Improving Calibration for Long-Tailed Recognition Authors: Zhisheng Zhong, Jiequan Cui, Shu Liu, Jiaya Jia [arXiv] [slide] [BibTeX] Introductio

DV Lab 116 Dec 20, 2022
PyTorch Personal Trainer: My framework for deep learning experiments

Alex's PyTorch Personal Trainer (ptpt) (name subject to change) This repository contains my personal lightweight framework for deep learning projects

Alex McKinney 8 Jul 14, 2022
Steer OpenAI's Jukebox with Music Taggers

TagBox Steer OpenAI's Jukebox with Music Taggers! The closest thing we have to VQGAN+CLIP for music! Unsupervised Source Separation By Steering Pretra

Ethan Manilow 34 Nov 02, 2022
This is a Machine Learning Based Hand Detector Project, It Uses Machine Learning Models and Modules Like Mediapipe, Developed By Google!

Machine Learning Hand Detector This is a Machine Learning Based Hand Detector Project, It Uses Machine Learning Models and Modules Like Mediapipe, Dev

Popstar Idhant 3 Feb 25, 2022
TensorFlow implementation of "TokenLearner: What Can 8 Learned Tokens Do for Images and Videos?"

TokenLearner: What Can 8 Learned Tokens Do for Images and Videos? Source: Improving Vision Transformer Efficiency and Accuracy by Learning to Tokenize

Aritra Roy Gosthipaty 23 Dec 24, 2022
A PyTorch Implementation of Gated Graph Sequence Neural Networks (GGNN)

A PyTorch Implementation of GGNN This is a PyTorch implementation of the Gated Graph Sequence Neural Networks (GGNN) as described in the paper Gated G

Ching-Yao Chuang 427 Dec 13, 2022
[WACV 2022] Contextual Gradient Scaling for Few-Shot Learning

CxGrad - Official PyTorch Implementation Contextual Gradient Scaling for Few-Shot Learning Sanghyuk Lee, Seunghyun Lee, and Byung Cheol Song In WACV 2

Sanghyuk Lee 4 Dec 05, 2022
A light weight data augmentation tool for training CNNs and Viola Jones detectors

hey-daug A light weight data augmentation tool for training CNNs and Viola Jones detectors (Haar Cascades). This tool inflates your data by up to six

Jaiyam Sharma 2 Nov 23, 2019
ChatBot-Pytorch - A GPT-2 ChatBot implemented using Pytorch and Huggingface-transformers

ChatBot-Pytorch A GPT-2 ChatBot implemented using Pytorch and Huggingface-transf

ParZival 42 Dec 09, 2022
Artificial Intelligence playing minesweeper 🤖

AI playing Minesweeper ✨ Minesweeper is a single-player puzzle video game. The objective of the game is to clear a rectangular board containing hidden

Vaibhaw 8 Oct 17, 2022
Ego4d dataset repository. Download the dataset, visualize, extract features & example usage of the dataset

Ego4D EGO4D is the world's largest egocentric (first person) video ML dataset and benchmark suite, with 3,600 hrs (and counting) of densely narrated v

Meta Research 118 Jan 07, 2023