This is the official PyTorch implementation for "Mesa: A Memory-saving Training Framework for Transformers".

Related tags

Deep LearningMesa
Overview

Mesa: A Memory-saving Training Framework for Transformers

This is the official PyTorch implementation for Mesa: A Memory-saving Training Framework for Transformers.

By Zizheng Pan, Peng Chen, Haoyu He, Jing Liu, Jianfei Cai and Bohan Zhuang.

image-20211116105242785

Installation

  1. Create a virtual environment with anaconda.

    conda create -n mesa python=3.7 -y
    conda activate mesa
    
    # Install PyTorch, we use PyTorch 1.7.1 with CUDA 10.1 
    pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
    
    # Install ninja
    pip install ninja
  2. Build and install Mesa.

    # cloen this repo
    git clone https://github.com/zhuang-group/Mesa
    # build
    cd Mesa/
    # You need to have an NVIDIA GPU
    python setup.py develop

Usage

  1. Prepare your policy and save as a text file, e.g. policy.txt.

    on gelu: # layer tag, choices: fc, conv, gelu, bn, relu, softmax, matmul, layernorm
        by_index: all # layer index
        enable: True # enable for compressing
        level: 256 # we adopt 8-bit quantization by default
        ema_decay: 0.9 # the decay rate for running estimates
        
        by_index: 1 2 # e.g. exluding GELU layers that indexed by 1 and 2.
        enable: False
  2. Next, you can wrap your model with Mesa by:

    import mesa as ms
    ms.policy.convert_by_num_groups(model, 3)
    # or convert by group size with ms.policy.convert_by_group_size(model, 64)
    
    # setup compression policy
    ms.policy.deploy_on_init(model, '[path to policy.txt]', verbose=print, override_verbose=False)

    That's all you need to use Mesa for memory saving.

    Note that convert_by_num_groups and convert_by_group_size only recognize nn.XXX, if your code has functional operations, such as [email protected] and F.Softmax, you may need to manually setup these layers. For example:

    # matrix multipcation (before)
    out = Q@K.transpose(-2, -1)
    # with Mesa
    self.mm = ms.MatMul(quant_groups=3)
    out = self.mm(q, k.transpose(-2, -1))
    
    # sofmtax (before)
    attn = attn.softmax(dim=-1)
    # with Mesa
    self.softmax = ms.Softmax(dim=-1, quant_groups=3)
    attn = self.softmax(attn)
  3. You can also target one layer by:

    import mesa as ms
    # previous 
    self.act = nn.GELU()
    # with Mesa
    self.act = ms.GELU(quant_groups=[num of quantization groups])

Demo projects for DeiT and Swin

We provide demo projects to replicate our results of training DeiT and Swin with Mesa, please refer to DeiT-Mesa and Swin-Mesa.

Results on ImageNet

Model Param (M) FLOPs (G) Train Memory Top-1 (%)
DeiT-Ti 5 1.3 4,171 71.9
DeiT-Ti w/ Mesa 5 1.3 1,858 72.1
DeiT-S 22 4.6 8,459 79.8
DeiT-S w/ Mesa 22 4.6 3,840 80.0
DeiT-B 86 17.5 17,691 81.8
DeiT-B w/ Mesa 86 17.5 8,616 81.8
Swin-Ti 29 4.5 11,812 81.3
Swin-Ti w/ Mesa 29 4.5 5,371 81.3
PVT-Ti 13 1.9 7,800 75.1
PVT-Ti w/ Mesa 13 1.9 3,782 74.9

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Acknowledgments

This repository has adopted part of the quantization codes from ActNN, we thank the authors for their open-sourced code.

Owner
Zhuang AI Group
Zhuang AI Group
A coin flip game in which you can put the amount of money below or equal to 1000 and then choose heads or tail

COIN_FLIPPY ##This is a simple example package. You can use Github-flavored Markdown to write your content. Coinflippy A coin flip game in which you c

2 Dec 26, 2021
The code for replicating the experiments from the LFI in SSMs with Unknown Dynamics paper.

Likelihood-Free Inference in State-Space Models with Unknown Dynamics This package contains the codes required to run the experiments in the paper. Th

Alex Aushev 0 Dec 27, 2021
A framework to train language models to learn invariant representations.

Invariant Language Modeling Implementation of the training for invariant language models. Motivation Modern pretrained language models are critical co

6 Nov 16, 2022
Official repository for: Continuous Control With Ensemble DeepDeterministic Policy Gradients

Continuous Control With Ensemble Deep Deterministic Policy Gradients This repository is the official implementation of Continuous Control With Ensembl

4 Dec 06, 2021
SuRE Evaluation: A Supplementary Material

SuRE Evaluation: A Supplementary Material This repository contains supplementary material regarding the evaluations presented in the paper Visual Expl

NYU Visualization Lab 0 Dec 14, 2021
Definition of a business problem according to Wilson Lower Bound Score and Time Based Average Rating

Wilson Lower Bound Score, Time Based Rating Average In this study I tried to calculate the product rating and sorting reviews more accurately. I have

3 Sep 30, 2021
A task Provided by A respective Artenal Ai and Ml based Company to complete it

A task Provided by A respective Alternal Ai and Ml based Company to complete it .

Parth Madan 1 Jan 25, 2022
Official repository for "Exploiting Session Information in BERT-based Session-aware Sequential Recommendation", SIGIR 2022 short.

Session-aware BERT4Rec Official repository for "Exploiting Session Information in BERT-based Session-aware Sequential Recommendation", SIGIR 2022 shor

Jamie J. Seol 22 Dec 13, 2022
Implementation of [Time in a Box: Advancing Knowledge Graph Completion with Temporal Scopes].

Time2box Implementation of [Time in a Box: Advancing Knowledge Graph Completion with Temporal Scopes].

LingCai 4 Aug 23, 2022
This repository contains all source code, pre-trained models related to the paper "An Empirical Study on GANs with Margin Cosine Loss and Relativistic Discriminator"

An Empirical Study on GANs with Margin Cosine Loss and Relativistic Discriminator This is a Pytorch implementation for the paper "An Empirical Study o

Cuong Nguyen 3 Nov 15, 2021
Code for layerwise detection of linguistic anomaly paper (ACL 2021)

Layerwise Anomaly This repository contains the source code and data for our ACL 2021 paper: "How is BERT surprised? Layerwise detection of linguistic

6 Dec 07, 2022
ESGD-M - A stochastic non-convex second order optimizer, suitable for training deep learning models, for PyTorch

ESGD-M - A stochastic non-convex second order optimizer, suitable for training deep learning models, for PyTorch

Katherine Crowson 53 Dec 29, 2022
PyBrain - Another Python Machine Learning Library.

PyBrain -- the Python Machine Learning Library =============================================== INSTALLATION ------------ Quick answer: make sure you

2.8k Dec 31, 2022
Intro-to-dl - Resources for "Introduction to Deep Learning" course.

Introduction to Deep Learning course resources https://www.coursera.org/learn/intro-to-deep-learning Running on Google Colab (tested for all weeks) Go

Advanced Machine Learning specialisation by HSE 761 Dec 24, 2022
Code base for "On-the-Fly Test-time Adaptation for Medical Image Segmentation"

On-the-Fly Adaptation Official Pytorch Code base for On-the-Fly Test-time Adaptation for Medical Image Segmentation Paper Introduction One major probl

Jeya Maria Jose 17 Nov 10, 2022
Codebase for "Revisiting spatio-temporal layouts for compositional action recognition" (Oral at BMVC 2021).

Revisiting spatio-temporal layouts for compositional action recognition Codebase for "Revisiting spatio-temporal layouts for compositional action reco

Gorjan 20 Dec 15, 2022
The 2nd place solution of 2021 google landmark retrieval on kaggle.

Google_Landmark_Retrieval_2021_2nd_Place_Solution The 2nd place solution of 2021 google landmark retrieval on kaggle. Environment We use cuda 11.1/pyt

229 Dec 13, 2022
I-SECRET: Importance-guided fundus image enhancement via semi-supervised contrastive constraining

I-SECRET This is the implementation of the MICCAI 2021 Paper "I-SECRET: Importance-guided fundus image enhancement via semi-supervised contrastive con

13 Dec 02, 2022
Speech Separation Using an Asynchronous Fully Recurrent Convolutional Neural Network

Speech Separation Using an Asynchronous Fully Recurrent Convolutional Neural Network This repository is the official implementation of Speech Separati

Kai Li (李凯) 116 Nov 09, 2022
git《USD-Seg:Learning Universal Shape Dictionary for Realtime Instance Segmentation》(2020) GitHub: [fig2]

USD-Seg This project is an implement of paper USD-Seg:Learning Universal Shape Dictionary for Realtime Instance Segmentation, based on FCOS detector f

Ruolin Ye 80 Nov 28, 2022