PyTorch-LIT

PyTorch-LIT is the Lite Inference Toolkit (LIT) for PyTorch, focused on easy and fast inference of large models on end devices.

With the rapid growth of deep learning research, models are becoming increasingly large and complex, making it difficult to run them on currently available end devices. For example, GPT-J with 6B parameters needs roughly 24 GB of RAM just to hold its weights in full precision, which is out of reach for most systems; even a GPU such as the RTX 2060 with 6 GB of memory cannot hold GPT-J in half-precision mode (roughly 12 GB), making direct inference impossible.

To address this issue when training large models, libraries such as DeepSpeed use offloading techniques (e.g., ZeRO) to partition the parameters across devices and make training possible. For inference, however, no comparable library or framework has been directly available.

PyTorch-LIT enables inference of large models by loading weights on demand from a specified secondary memory, which can be disk, CPU, or GPU. This makes it possible to run models that do not even fit in the system's main memory, simply by trading off inference time.

Quick Start

  1. Install the library:
pip install pytorch-lit
  2. Save the model's weights in a format the toolkit can use:
from pytorch_lit.export import prepare_params

weights = {}  # your model's parameters (state_dict)
# change the directory to save your model and specify the data type
prepare_params(weights, ".models/my-model", dtype="float32")
  3. After preparing the weights, you can run inference:
from pytorch_lit import LitModule

# pass your model construction as a closure,
# and specify the weights path and inference device
model = LitModule.from_params(".models/my-model",
                              lambda: MyModel(),
                              device="cuda")
result = model(*args, **kwargs)
  4. Have fun enjoying inference of the large model on a lower-memory device :) (a complete end-to-end sketch follows below).
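
Putting the steps together, the following is a minimal end-to-end sketch rather than a verbatim copy of the repo's examples. It assumes the Hugging Face transformers library, the EleutherAI/gpt-j-6B checkpoint, and an illustrative output directory; only prepare_params and LitModule.from_params come from PyTorch-LIT itself, and the export step still needs to load the full checkpoint once (e.g., on a machine with enough RAM or swap).

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from pytorch_lit.export import prepare_params
from pytorch_lit import LitModule

# 1) One-time export: load the full model once and dump its state_dict
#    into the toolkit's format (the directory name is illustrative).
full_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
prepare_params(full_model.state_dict(), ".models/gpt-j-6b", dtype="float32")
del full_model

# 2) Inference: construct the model from its config inside the closure,
#    so no weights are loaded until PyTorch-LIT streams them in on demand.
config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = LitModule.from_params(".models/gpt-j-6b",
                              lambda: AutoModelForCausalLM.from_config(config),
                              device="cuda")

tokens = tokenizer("PyTorch-LIT makes large-model inference", return_tensors="pt").to("cuda")
result = model(**tokens)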

Examples

The repo's examples directory currently contains two GPT-J examples: one for text generation and one for extracting hidden states as feature representations.

Development

This is a work in progress that will require further development before it can be considered a stable inference toolkit. Here is a list of potential future developments:

  • Caching and batch-loading as many weights as memory allows, replacing them in parallel with upcoming ones by following the order of the execution graph (a purely illustrative sketch of this idea follows the list)
  • A C++ extension for PyTorch JIT, so the solution applies to the majority of production end devices
  • Helper functions to make it easier to export large models to ONNX or trace them with JIT
  • A better and faster storage format than numpy memmap
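
To make the first item concrete, here is a purely illustrative sketch of a byte-budgeted, least-recently-used weight cache; none of these names exist in pytorch-lit today, and prefetching the next layers' weights in execution-graph order could be layered on top of it.

from collections import OrderedDict

class WeightCache:
    """Keeps loaded parameter tensors in memory up to a byte budget (illustrative only)."""

    def __init__(self, load_fn, budget_bytes):
        self.load_fn = load_fn        # e.g. a function reading one key from the memmap
        self.budget = budget_bytes
        self.used = 0
        self.cache = OrderedDict()    # key -> tensor, kept in LRU order

    def get(self, key):
        if key in self.cache:
            self.cache.move_to_end(key)                   # mark as recently used
            return self.cache[key]
        tensor = self.load_fn(key)
        size = tensor.numel() * tensor.element_size()
        while self.cache and self.used + size > self.budget:
            _, evicted = self.cache.popitem(last=False)   # evict the oldest entry
            self.used -= evicted.numel() * evicted.element_size()
        self.cache[key] = tensor
        self.used += size
        return tensor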

Contributions are welcome; to discuss an idea, open an issue with the discussion tag, and then submit a pull request from your fork.

How does it work?

This implementation was made possible primarily by two ideas:

  • The first issue was that PyTorch initializes a model's parameters while the model object is being constructed, so construction fails when the parameters do not fit into memory. To address this, we temporarily hijack the __new__ method of PyTorch's Parameter class during model construction, replacing each parameter's tensor with a view into a shared global tensor immediately after creation. As a result, all parameters use the same big shared tensor as their primary storage, and the model can be built and run on inputs to follow and trace the execution graph.
  • The second issue was the sheer size of the model parameters. In the preparation step, we build a numpy memmap (np.memmap) and save metadata recording the location of each key in the memmap, which lets us read parameters from it as needed. We then use PyTorch hooks (forward and forward-pre) to load a module's parameters right before its execution and unload them right after. A minimal sketch of both ideas follows below.
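
The following is a minimal sketch of the first idea, not the library's actual code: torch.nn.Parameter.__new__ is swapped out during construction so every new parameter becomes a view into one shared buffer (a single float32 buffer here, for simplicity). Parameter values overlap and overwrite each other, which is fine because the real weights are streamed in just in time later.

import torch
from torch.nn import Parameter

_shared = torch.empty(0)            # stand-in for the big shared tensor, grown lazily
_original_new = Parameter.__new__

def _hijacked_new(cls, data=None, requires_grad=True):
    global _shared
    if data is None:
        data = torch.empty(0)
    n = data.numel()
    if _shared.numel() < n:
        _shared = torch.empty(n)    # reuse one buffer; enlarge only when needed
    view = _shared[:n].view(data.shape)
    return _original_new(cls, view, requires_grad=False)   # inference only

Parameter.__new__ = _hijacked_new
try:
    model = MyModel()               # the user's model; construction no longer allocates real per-parameter storage
finally:
    Parameter.__new__ = _original_new   # always restore the original behavior

And a sketch of the second idea, with assumed file and helper names (model.bin, meta.json, load_param, attach_hooks) that are illustrative rather than the library's internals: parameters live in a numpy memmap on disk, and hooks copy them in right before a module's forward pass and drop them right after.

import json
import numpy as np
import torch

mm = np.memmap("model.bin", dtype=np.float32, mode="r")
with open("meta.json") as f:
    meta = json.load(f)             # e.g. {"h.0.attn.weight": {"offset": 0, "shape": [4096, 4096]}, ...}

def load_param(key):
    info = meta[key]
    n = int(np.prod(info["shape"]))
    chunk = mm[info["offset"]:info["offset"] + n].reshape(info["shape"])
    return torch.from_numpy(np.array(chunk))    # copy out of the memmap

def attach_hooks(module, prefix, device):
    names = [n for n, _ in module.named_parameters(recurse=False)]

    def pre_hook(mod, inputs):
        for name in names:          # load this module's weights just in time
            getattr(mod, name).data = load_param(prefix + name).to(device)

    def post_hook(mod, inputs, output):
        for name in names:          # release the memory again after the call
            getattr(mod, name).data = torch.empty(0, device=device)

    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(post_hook)

# attach hooks to every leaf module, since leaves own the actual parameters
for name, sub in model.named_modules():
    if not list(sub.children()):
        attach_hooks(sub, prefix=name + "." if name else "", device="cuda")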

Citation

Please cite PyTorch-LIT if it helps your research. You can use the following BibTeX entry:

@misc{pytorch_lit,
	title = {PyTorch-LIT},
	author = {Rezaei, Amin},
	howpublished = {\url{github.com/AminRezaei0x443/PyTorch-LIT}},
	year = {2021}
}

Comments
  • RuntimeError: OrderedDict mutated during iteration

    Hi, there is a new problem. When the model runs forward, it raises a RuntimeError: OrderedDict mutated during iteration. Details below:

    Traceback (most recent call last):
      File "nlp/rct-FPM-rhino/big_model/predict.py", line 24, in <module>
        result = model(**tokens)
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/inference.py", line 34, in __call__
        return self.forward(*args, **kwargs)
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/inference.py", line 31, in forward
        return self.module(*args, **kwargs)
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1057, in _call_impl
        for hook in itertools.chain(
    RuntimeError: OrderedDict mutated during iteration

    Environment:

    GPU:NVIDIA GeForce 3090 CUDA version 11.4 pip list: certifi 2021.10.8 charset-normalizer 2.0.8 click 8.0.3 filelock 3.4.0 huggingface-hub 0.2.0 idna 3.3 joblib 1.1.0 numpy 1.21.4 packaging 21.3 Pillow 8.4.0 pip 21.2.4 pyparsing 3.0.6 pytorch-lit 0.1.7 PyYAML 6.0 regex 2021.11.10 requests 2.26.0 sacremoses 0.0.46 setuptools 58.0.4 six 1.16.0 tokenizer 3.3.2 tokenizers 0.10.3 torch 1.9.1+cu111 torchaudio 0.8.1 torchvision 0.9.1+cu111 tqdm 4.62.3 transformers 4.12.5 typing_extensions 4.0.1 urllib3 1.26.7

    I think this problem is caused by the PyTorch hooks (forward and forward-pre) used to load and unload a module's parameters before and after execution: when the parameters are loaded and unloaded, the OrderedDict is mutated while it is being iterated.

    opened by changleilei 9
  • TypeError: <lambda>() missing 1 required positional argument: 'k'

    Hello, when I use pytorch-lit to prepare a model, I get a TypeError as in the title. Details below:

    Traceback (most recent call last):
      File "nlp/rct-FPM-rhino/big_model/prepare_model.py", line 16, in prepare_model
        prepare_params(model, args.save_path, dtype='float32')
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/export.py", line 19, in prepare_params
        _params_to_memmap(parameters, path.join(save_dir, "model.bin"),
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/export.py", line 52, in _params_to_memmap
        param = get_param(k)
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/export.py", line 50, in get_param = lambda key: params"get"
    TypeError: <lambda>() missing 1 required positional argument: 'k'

    package list:

    certifi 2021.10.8 numpy 1.21.4 pip 21.2.4 pytorch-lit 0.1.6 setuptools 58.0.4 torch 1.10.0 tqdm 4.62.3 typing_extensions 4.0.1 wheel 0.37.0

    model: gpt-j-6B

    Any suggestions? Thanks.

    opened by changleilei 1
  • gpt-j generation speed very low

    The output of gpt-j is very slow: generating 200 output tokens takes about 20 minutes, and 2048 tokens takes more than an hour, which significantly limits any experimentation with the model.

    I checked GPU utilization during inference and it is only about 1 to 4 percent, with GPU memory usage below 4 GB; my system has 8 GB of GPU memory. If the GPU were fully utilized, inference speed might increase significantly.

    Are there simple hacks to speed up inference?

    opened by usama-ahmedkhan 3
  • Weights file format is changed, function partial_loader fails

    Hi, thanks for your effort in making it easy to load and run inference with large models. I tried your code on a gpt-j model with a different weights file format: the model's weights are split across several .pt files rather than the single .bin file that your partial_loader() function expects. Does the code work with multiple weight files, and how can I change it to do so?

    opened by usama-ahmedkhan 4