Torch-mutable-modules - Use in-place and assignment operations on PyTorch module parameters with support for autograd

Overview

Torch Mutable Modules

Use in-place and assignment operations on PyTorch module parameters with support for autograd.

Publish to PyPI Run tests PyPI version Number of downloads from PyPI per month Python version support Code Style: Black

Why does this exist?

PyTorch does not allow in-place operations on module parameters (usually desirable):

linear_layer = torch.nn.Linear(1, 1)
linear_layer.weight.data += 69
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Valid, but will NOT store grad_fn=<AddBackward0>
linear_layer.weight += 420
# ^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

In some cases, however, it is useful to be able to modify module parameters in-place. For example, if we have a neural network (net_1) that predicts the parameter values to another neural network (net_2), we need to be able to modify the weights of net_2 in-place and backpropagate the gradients to net_1.

# create a parameter predictor network (net_1)
net_1 = torch.nn.Linear(1, 2)

# predict the weights and biases of net_2 using net_1
p_weight_and_bias = net_1(input_0).unsqueeze(2)
p_weight, p_bias = p_weight_and_bias[:, 0], p_weight_and_bias[:, 1]

# create a mutable network (net_2)
net_2 = to_mutable_module(torch.nn.Linear(1, 1))

# hot-swap the weights and biases of net_2 with the predicted values
net_2.weight = p_weight
net_2.bias = p_bias

# compute the output and backpropagate the gradients to net_1
output = net_2(input_1)
loss = criterion(output, label)
loss.backward()
optimizer.step()

This library provides a way to easily convert PyTorch modules into mutable modules with the to_mutable_module function.

Installation

You can install torch-mutable-modules from PyPI.

pip install torch-mutable-modules

To upgrade an existing installation of torch-mutable-modules, use the following command:

pip install --upgrade --no-cache-dir torch-mutable-modules

Importing

You can use wildcard imports or import specific functions directly:

# import all functions
from torch_mutable_modules import *

# ... or import the function manually
from torch_mutable_modules import to_mutable_module

Usage

To convert an existing PyTorch module into a mutable module, use the to_mutable_module function:

converted_module = to_mutable_module(
    torch.nn.Linear(1, 1)
) # type of converted_module is still torch.nn.Linear

converted_module.weight *= 0
convreted_module.weight += 69
convreted_module.weight # tensor([[69.]], grad_fn=<AddBackward0>)

You can also declare your own PyTorch module classes as mutable, and all child modules will be recursively converted into mutable modules:

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)
    
    def forward(self, x):
        return self.linear(x)

my_module = to_mutable_module(MyModule())
my_module.linear.weight *= 0
my_module.linear.weight += 69
my_module.linear.weight # tensor([[69.]], grad_fn=<AddBackward0>)

Usage with CUDA

To create a module on the GPU, simply pass a PyTorch module that is already on the GPU to the to_mutable_module function:

converted_module = to_mutable_module(
    torch.nn.Linear(1, 1).cuda()
) # converted_module is now a mutable module on the GPU

Moving a module to the GPU with .to() and .cuda() after instanciation is NOT supported. Instead, hot-swap the module parameter tensors with their CUDA counterparts.

# both of these are valid
converted_module.weight = converted_module.weight.cuda()
converted_module.bias = converted_module.bias.to("cuda")

Detailed examples

Please check out example.py to see more detailed example usages of the to_mutable_module function.

Contributing

Please feel free to submit issues or pull requests!

You might also like...
A machine learning library for spiking neural networks. Supports training with both torch and jax pipelines, and deployment to neuromorphic hardware.
A machine learning library for spiking neural networks. Supports training with both torch and jax pipelines, and deployment to neuromorphic hardware.

Rockpool Rockpool is a Python package for developing signal processing applications with spiking neural networks. Rockpool allows you to build network

Implements Stacked-RNN in numpy and torch with manual forward and backward functions

Recurrent Neural Networks Implements simple recurrent network and a stacked recurrent network in numpy and torch respectively. Both flavours implement

A torch.Tensor-like DataFrame library supporting multiple execution runtimes and Arrow as a common memory format

TorchArrow (Warning: Unstable Prototype) This is a prototype library currently under heavy development. It does not currently have stable releases, an

A complete end-to-end demonstration in which we collect training data in Unity and use that data to train a deep neural network to predict the pose of a cube. This model is then deployed in a simulated robotic pick-and-place task.
A complete end-to-end demonstration in which we collect training data in Unity and use that data to train a deep neural network to predict the pose of a cube. This model is then deployed in a simulated robotic pick-and-place task.

Object Pose Estimation Demo This tutorial will go through the steps necessary to perform pose estimation with a UR3 robotic arm in Unity. You’ll gain

Python implementation of MULTIseq barcode alignment using fuzzy string matching and GMM barcode assignment

Python implementation of MULTIseq barcode alignment using fuzzy string matching and GMM barcode assignment.

 MM1 and MMC Queue Simulation using python - Results and parameters in excel and csv files
MM1 and MMC Queue Simulation using python - Results and parameters in excel and csv files

implementation of MM1 and MMC Queue on randomly generated data and evaluate simulation results then compare with analytical results and draw a plot curve for them, simulate some integrals and compare results and run monte carlo algorithm with them

Torch-based tool for quantizing high-dimensional vectors using additive codebooks

Trainable multi-codebook quantization This repository implements a utility for use with PyTorch, and ideally GPUs, for training an efficient quantizer

Torch implementation of
Torch implementation of "Enhanced Deep Residual Networks for Single Image Super-Resolution"

NTIRE2017 Super-resolution Challenge: SNU_CVLab Introduction This is our project repository for CVPR 2017 Workshop (2nd NTIRE). We, Team SNU_CVLab, (B

Automatic number plate recognition using tech:  Yolo, OCR, Scene text detection, scene text recognation, flask, torch
Automatic number plate recognition using tech: Yolo, OCR, Scene text detection, scene text recognation, flask, torch

Automatic Number Plate Recognition Automatic Number Plate Recognition (ANPR) is the process of reading the characters on the plate with various optica

Releases(v1.1.2)
Owner
Kento Nishi
17-year-old programmer at Lynbrook High School, with strong interests in AI/Machine Learning. Open source developer and researcher at the Four Eyes Lab.
Kento Nishi
A blender add-on that automatically re-aligns wrong axis objects.

Auto Align A blender add-on that automatically re-aligns wrong axis objects. Usage There are three options available in the 3D Viewport Sidebar It

29 Nov 25, 2022
A Python library for unevenly-spaced time series analysis

traces A Python library for unevenly-spaced time series analysis. Why? Taking measurements at irregular intervals is common, but most tools are primar

Datascope Analytics 516 Dec 29, 2022
The final project of "Applying AI to 2D Medical Imaging Data" of "AI for Healthcare" nanodegree - Udacity.

Pneumonia Detection from X-Rays Project Overview In this project, you will apply the skills that you have acquired in this 2D medical imaging course t

Omar Laham 1 Jan 14, 2022
Blind Video Temporal Consistency via Deep Video Prior

deep-video-prior (DVP) Code for NeurIPS 2020 paper: Blind Video Temporal Consistency via Deep Video Prior PyTorch implementation | paper | project web

Chenyang LEI 272 Dec 21, 2022
Geneva is an artificial intelligence tool that defeats censorship by exploiting bugs in censors

Geneva is an artificial intelligence tool that defeats censorship by exploiting bugs in censors

Kevin Bock 1.5k Jan 06, 2023
PECOS - Prediction for Enormous and Correlated Spaces

PECOS - Predictions for Enormous and Correlated Output Spaces PECOS is a versatile and modular machine learning (ML) framework for fast learning and i

Amazon 387 Jan 04, 2023
Annotate datasets with a semi-trained or fully trained YOLOv5 model

YOLOv5 Auto Annotator Annotate datasets with a semi-trained or fully trained YOLOv5 model Prerequisites Ubuntu =20.04 Python =3.7 System dependencie

Akash James 3 May 14, 2022
Hydra Lightning Template for Structured Configs

Hydra Lightning Template for Structured Configs Template for creating projects with pytorch-lightning and hydra. How to use this template? Create your

Model-driven Machine Learning 4 Jul 19, 2022
Open source simulator for autonomous vehicles built on Unreal Engine / Unity, from Microsoft AI & Research

Welcome to AirSim AirSim is a simulator for drones, cars and more, built on Unreal Engine (we now also have an experimental Unity release). It is open

Microsoft 13.8k Jan 05, 2023
Code and data for "Broaden the Vision: Geo-Diverse Visual Commonsense Reasoning" (EMNLP 2021).

GD-VCR Code for Broaden the Vision: Geo-Diverse Visual Commonsense Reasoning (EMNLP 2021). Research Questions and Aims: How well can a model perform o

Da Yin 24 Oct 13, 2022
DeepOBS: A Deep Learning Optimizer Benchmark Suite

DeepOBS - A Deep Learning Optimizer Benchmark Suite DeepOBS is a benchmarking suite that drastically simplifies, automates and improves the evaluation

Aaron Bahde 7 May 12, 2020
Self-driving car env with PPO algorithm from stable baseline3

Self-driving car with RL stable baseline3 Most of the project develop from https://github.com/GerardMaggiolino/Gym-Medium-Post Please check it out! Th

Sornsiri.P 7 Dec 22, 2022
Learning to Disambiguate Strongly Interacting Hands via Probabilistic Per-Pixel Part Segmentation [3DV 2021 Oral]

Learning to Disambiguate Strongly Interacting Hands via Probabilistic Per-Pixel Part Segmentation [3DV 2021 Oral] Learning to Disambiguate Strongly In

Zicong Fan 40 Dec 22, 2022
A project to build an AI voice assistant using Python . The Voice assistant interacts with the humans to perform basic tasks.

AI_Personal_Voice_Assistant_Using_Python A project to build an AI voice assistant using Python . The Voice assistant interacts with the humans to perf

Chumui Tripura 1 Oct 30, 2021
Activating More Pixels in Image Super-Resolution Transformer

HAT [Paper Link] Activating More Pixels in Image Super-Resolution Transformer Xiangyu Chen, Xintao Wang, Jiantao Zhou and Chao Dong BibTeX @article{ch

XyChen 270 Dec 27, 2022
Codebase for Amodal Segmentation through Out-of-Task andOut-of-Distribution Generalization with a Bayesian Model

Codebase for Amodal Segmentation through Out-of-Task andOut-of-Distribution Generalization with a Bayesian Model

Yihong Sun 12 Nov 15, 2022
Newt - a Gaussian process library in JAX.

Newt __ \/_ (' \`\ _\, \ \\/ /`\/\ \\ \ \\

AaltoML 0 Nov 02, 2021
Code for "Unsupervised Source Separation via Bayesian inference in the latent domain"

LQVAE-separation Code for "Unsupervised Source Separation via Bayesian inference in the latent domain" Paper Samples GT Compressed Separated Drums GT

Michele Mancusi 30 Oct 25, 2022
DLWP: Deep Learning Weather Prediction

DLWP: Deep Learning Weather Prediction DLWP is a Python project containing data-

Kushal Shingote 3 Aug 14, 2022
A model to classify a piece of news as REAL or FAKE

Fake_news_classification A model to classify a piece of news as REAL or FAKE. This python project of detecting fake news deals with fake and real news

Gokul Stark 1 Jan 29, 2022