Memory Efficient Attention Pytorch

Overview

Implementation of a memory-efficient multi-head attention as proposed in the paper, Self-attention Does Not Need O(n²) Memory. In addition, the module takes care of masking, causal masking, and cross attention.
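
At a high level, the algorithm processes the keys and values in chunks, carrying a running maximum and running sums so the full n² attention matrix is never materialized. Below is a minimal single-head sketch of that numerically stable accumulation (illustration only, not this module's actual implementation):

import torch

def chunked_attention(q, k, v, k_chunk_size = 2048):
    # q: (n, d), k and v: (m, d); single head, for illustration
    scale = q.shape[-1] ** -0.5
    q = q * scale

    acc = torch.zeros_like(q)                       # running sum of exp-weighted values
    denom = q.new_zeros(q.shape[0])                 # running softmax denominator
    running_max = q.new_full((q.shape[0],), float('-inf'))

    for kc, vc in zip(k.split(k_chunk_size), v.split(k_chunk_size)):
        attn = q @ kc.t()                           # (n, chunk) slice of the logits
        new_max = torch.maximum(running_max, attn.amax(dim = -1))
        exp_attn = (attn - new_max[:, None]).exp()
        correction = (running_max - new_max).exp()  # rescale earlier accumulators
        acc = acc * correction[:, None] + exp_attn @ vc
        denom = denom * correction + exp_attn.sum(dim = -1)
        running_max = new_max

    return acc / denom[:, None]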

Install

$ pip install memory-efficient-attention-pytorch

Usage

For an autoregressive language model

import torch
from memory_efficient_attention_pytorch import Attention

attn = Attention(
    dim = 512,
    dim_head = 64,                # dimension per head
    heads = 8,                    # number of attention heads
    causal = True,                # autoregressive or not
    memory_efficient = True,      # whether to use memory efficient attention (can be turned off to test against normal attention)
    q_bucket_size = 1024,         # bucket size along queries dimension
    k_bucket_size = 2048          # bucket size along key / values dimension
).cuda()

x = torch.randn(1, 65536, 512).cuda()
out = attn(x) # (1, 65536, 512)
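
Since the memory-efficient path can be toggled, a quick sanity check (a sketch; it assumes the memory_efficient flag changes only the computation, not the parameter set) is to copy the weights into both variants and compare outputs:

import torch
from memory_efficient_attention_pytorch import Attention

attn_full = Attention(dim = 512, dim_head = 64, heads = 8, causal = True, memory_efficient = False)
attn_eff = Attention(dim = 512, dim_head = 64, heads = 8, causal = True, memory_efficient = True, q_bucket_size = 256, k_bucket_size = 512)
attn_eff.load_state_dict(attn_full.state_dict())  # same weights in both variants

x = torch.randn(1, 2048, 512)
# tolerance may need loosening, since the chunked accumulation reorders floating point ops
assert torch.allclose(attn_full(x), attn_eff(x), atol = 1e-4)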

Cross attention

import torch
from memory_efficient_attention_pytorch import Attention

cross_attn = Attention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    memory_efficient = True,
    q_bucket_size = 1024,
    k_bucket_size = 2048
).cuda()

x = torch.randn(1, 65536, 512).cuda()
context = torch.randn(1, 65536, 512).cuda()
mask = torch.ones(1, 65536).bool().cuda()

out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)
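
The boolean mask in the example above keeps every context position; for variable-length batches, a padding mask can be built in the usual way (a sketch, assuming True marks positions that may be attended to):

# hypothetical true lengths per batch element
lengths = torch.tensor([50000], device = 'cuda')
mask = torch.arange(65536, device = 'cuda')[None, :] < lengths[:, None]
out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)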

Todo

  • benchmark and see how much torch jit helps
  • look at Triton and Keops and see if either can be a fit

Citations

@misc{rabe2021selfattention,
    title   = {Self-attention Does Not Need $O(n^2)$ Memory}, 
    author  = {Markus N. Rabe and Charles Staats},
    year    = {2021},
    eprint  = {2112.05682},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • [feature request] Combining with flash attention?

    There is a new algorithm that optimizes the qkv attention part, FlashAttention (https://github.com/HazyResearch/flash-attention, https://arxiv.org/abs/2205.14135). Maybe you can look into integrating it with this.

    opened by Vbansal21 15
  • i did this, we could build on top

    Hi there!

    It seems I already did some of the code: https://github.com/CHARM-Tx/linear_mem_attention_pytorch. Could we build on top of this? I talked to https://github.com/Chillee about an experimental functionality from functorch (https://github.com/pytorch/functorch) that would allow for increased speed (mainly I want to match jax performance, but it is just difficult with pytorch's imperative style).

    I would love to collaborate on this if you want!

    opened by hypnopump 5
  • Added dropout support to memory efficient variant

    Hey Phil,

    I have been using this repository for a project and I wanted to add dropout for completeness. I checked consistency with the perceiver-ar implementation. I hope this is helpful.

    -Matt

    opened by usryokousha 2
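
    A sketch of one plausible placement for such dropout: applied to the exponentiated weights inside each chunk, before they are accumulated (the helper below is hypothetical, and unlike post-softmax dropout this also perturbs the softmax denominator, so the merged implementation may differ):

    import torch
    import torch.nn.functional as F

    def summarize_chunk_with_dropout(q, kc, vc, dropout = 0.1, training = True):
        # one k/v chunk's partial results, with dropout on the exponentiated weights
        attn = q @ kc.transpose(-1, -2)
        chunk_max = attn.amax(dim = -1, keepdim = True)
        exp_attn = (attn - chunk_max).exp()
        exp_attn = F.dropout(exp_attn, p = dropout, training = training)
        # weighted values, denominator piece, and the chunk max for the running reduction
        return exp_attn @ vc, exp_attn.sum(dim = -1), chunk_max.squeeze(-1)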
  • Making this work with relative position bias from XTransformers

    Is there a way to make this work with RelativePositionBias? Currently this produces an attention bias of size $B \times H \times N^2$, where $B$ is the batch size, $H$ is the number of heads, and $N$ is the input size. Can this be chunked and computed per chunk?

    opened by pfeatherstone 5
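
    The bias can, in principle, be built per bucket from relative offsets, so the full $N \times N$ table is never materialized. A hypothetical sketch (bias_emb, max_dist, and the bucket coordinates are assumptions, not this repo's API):

    import torch
    import torch.nn as nn

    heads, max_dist = 8, 128                         # hypothetical settings
    bias_emb = nn.Embedding(2 * max_dist + 1, heads) # learned bias per clipped offset

    def relative_bias_block(q_start, q_size, k_start, k_size):
        # materialize only the (heads, q_size, k_size) block needed for one
        # q/k bucket pair, instead of the full (heads, N, N) bias
        q_pos = torch.arange(q_start, q_start + q_size)
        k_pos = torch.arange(k_start, k_start + k_size)
        rel = (k_pos[None, :] - q_pos[:, None]).clamp(-max_dist, max_dist) + max_dist
        return bias_emb(rel).permute(2, 0, 1)        # (heads, q_size, k_size)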
  • save_for_backward can only save variables, but argument 5 is of type bool

    Hi,

    Thank you for your incredible work. I was trying to test your method, specifically for cross-attention, but it seems I get the error "save_for_backward can only save variables, but argument 5 is of type bool". I am not sure what I am doing wrong. I tried your own examples too, but I get the same error.

    Can you please help me out?

    Code:

    import torch
    from memory_efficient_attention_pytorch import Attention

    cross_attn = Attention(
        dim = 512,
        dim_head = 64,
        heads = 8,
        memory_efficient = True,
        q_bucket_size = 1024,
        k_bucket_size = 2048
    ).cuda()

    x = torch.randn(1, 65536, 512).cuda()
    context = torch.randn(1, 65536, 512).cuda()
    mask = torch.ones(1, 65536).bool().cuda()

    out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)

    ERROR:

    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 194, in _run_module_as_main return _run_code(code, main_globals, None, File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 87, in _run_code exec(code, run_globals) File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/main.py", line 45, in cli.main() File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 444, in main run() File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 285, in run_file runpy.run_path(target_as_str, run_name=compat.force_str("main")) File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 265, in run_path return _run_module_code(code, init_globals, run_name, File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 97, in _run_module_code _run_code(code, mod_globals, init_globals, File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 87, in _run_code exec(code, run_globals) File "/data/stars/user/abali/Phd_work/ISBI2023/X3D-Multigrid/CrossAttn_X3d_v2.py", line 872, in out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512) print(out) File "/home/abali/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl result = self.forward(*input, **kwargs) File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/site-packages/memory_efficient_attention_pytorch/memory_efficient_attention.py", line 215, in forward out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size) File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/site-packages/memory_efficient_attention_pytorch/memory_efficient_attention.py", line 127, in memory_efficient_attention exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn( File "/home/abali/.local/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 163, in checkpoint return CheckpointFunction.apply(function, preserve, *args) TypeError: save_for_backward can only save variables, but argument 5 is of type bool

    opened by aliabid2243 1
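
    The error comes from passing a non-tensor (a boolean flag) through torch.utils.checkpoint on older PyTorch versions, which try to save_for_backward every argument. A common workaround (a sketch; summarize_qkv_chunk below is an illustrative stand-in, not the repo's actual function) is to bind non-tensor arguments with functools.partial so the checkpoint only sees tensors:

    import torch
    from functools import partial
    from torch.utils.checkpoint import checkpoint

    def summarize_qkv_chunk(q, k, v, causal = False):
        # illustrative stand-in for a chunk summarizer
        weight = q @ k.transpose(-1, -2)
        if causal:
            weight = weight.masked_fill(torch.ones_like(weight).triu(1).bool(), -1e10)
        return weight.softmax(dim = -1) @ v

    q = k = v = torch.randn(2, 128, 64, requires_grad = True)

    # bind the boolean with partial so only tensor arguments reach the checkpoint
    out = checkpoint(partial(summarize_qkv_chunk, causal = True), q, k, v)
    out.sum().backward()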
  • Checkpointing is not compatible with .grad() or when an `inputs` parameter is passed to .backward()

    https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/35559a05572f9d4eb982a8e2e399b40a2d61b85c/memory_efficient_attention_pytorch/memory_efficient_attention.py#L95

    Should this be:

    summarize_qkv_fn = summarize_qkv_chunk if needs_backwards else checkpointed_summarize_qkv_chunk

    instead of:

    summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk

    opened by vrobot 0