Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"

Memory Efficient Attention Pytorch

Implementation of memory-efficient multi-head attention, as proposed in the paper Self-attention Does Not Need O(n²) Memory. The module also handles masking, causal masking, and cross attention.
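
The core idea can be illustrated with a minimal, single-head NumPy sketch. It is not this library's code, and `naive_attention`, `summarize_chunk` and `chunked_attention` are hypothetical names. Each key/value chunk is reduced to a small summary: the sum of its exponentiated scores, the exponentiated scores times the values, and the chunk's maximum score. The summaries are then rescaled to a common maximum and combined, so the full n_q by n_k score matrix is never materialized.

```python
import numpy as np

def naive_attention(q, k, v):
    # Reference: materializes the full (n_q, n_k) score matrix.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v

def summarize_chunk(q_chunk, k_chunk, v_chunk):
    # Summary of one key/value chunk, stabilized by the chunk's own max score.
    s = q_chunk @ k_chunk.T
    chunk_max = s.max(axis=-1, keepdims=True)
    exp_s = np.exp(s - chunk_max)
    return exp_s.sum(axis=-1, keepdims=True), exp_s @ v_chunk, chunk_max

def chunked_attention(q, k, v, q_bucket_size=4, k_bucket_size=8):
    scale = q.shape[-1] ** -0.5
    out = []
    for qs in range(0, q.shape[0], q_bucket_size):
        q_chunk = q[qs:qs + q_bucket_size] * scale
        # Stack per-chunk summaries: shapes (n_chunks, rows, 1 or d_v).
        sums, values, maxes = map(np.stack, zip(*(
            summarize_chunk(q_chunk, k[ks:ks + k_bucket_size], v[ks:ks + k_bucket_size])
            for ks in range(0, k.shape[0], k_bucket_size)
        )))
        # Rescale every chunk summary to the global max before combining.
        global_max = maxes.max(axis=0)
        factors = np.exp(maxes - global_max)
        out.append((values * factors).sum(axis=0) / (sums * factors).sum(axis=0))
    return np.concatenate(out)
```

Up to floating-point rounding, the result does not depend on the bucket sizes; smaller buckets lower the peak memory of the score blocks at the cost of more loop iterations.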

Install

$ pip install memory-efficient-attention-pytorch

Usage

For an autoregressive language model

import torch
from memory_efficient_attention_pytorch import Attention

attn = Attention(
    dim = 512,
    dim_head = 64,                # dimension per head
    heads = 8,                    # number of attention heads
    causal = True,                # autoregressive or not
    memory_efficient = True,      # whether to use memory efficient attention (can be turned off to test against normal attention)
    q_bucket_size = 1024,         # bucket size along queries dimension
    k_bucket_size = 2048          # bucket size along key / values dimension
).cuda()

x = torch.randn(1, 65536, 512).cuda()
out = attn(x) # (1, 65536, 512)
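
For a sense of scale, the back-of-the-envelope arithmetic below compares the score tensor that full attention would materialize for the 65536-token example with the tensor for a single 1024 by 2048 bucket, both across 8 heads. It assumes float32 scores and counts only the pre-softmax scores for one batch element. Activations, gradients and the library's actual allocations are not included, and `attention_score_bytes` is a hypothetical helper.

```python
def attention_score_bytes(q_len, k_len, heads, bytes_per_elem=4):
    # Bytes for one (heads, q_len, k_len) score tensor; float32 assumed.
    return q_len * k_len * heads * bytes_per_elem

full = attention_score_bytes(65536, 65536, heads=8)   # all queries x all keys
bucket = attention_score_bytes(1024, 2048, heads=8)   # one q bucket x one k bucket

print(full / 2**30, "GiB")    # 128.0 GiB
print(bucket / 2**20, "MiB")  # 64.0 MiB
```

Under these assumptions, a single bucket's score block is 2048 times smaller than the full matrix.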

Cross attention

import torch
from memory_efficient_attention_pytorch import Attention

cross_attn = Attention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    memory_efficient = True,
    q_bucket_size = 1024,
    k_bucket_size = 2048
).cuda()

x = torch.randn(1, 65536, 512).cuda()
context = torch.randn(1, 65536, 512).cuda()
mask = torch.ones(1, 65536).bool().cuda()

out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)
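
The boolean `mask` in the cross-attention example has shape (batch, context length) and applies to every head and every query position. The minimal NumPy sketch below shows that broadcasting. It assumes that `True` marks context positions that may be attended to, which is consistent with the all-True mask above but not confirmed here, and it fills masked scores with the most negative finite float before the softmax. `masked_attention` is a hypothetical name. In a chunked implementation, the same mask would be sliced per key chunk.

```python
import numpy as np

def masked_attention(q, k, v, mask):
    # q: (b, h, i, d); k, v: (b, h, j, d); mask: (b, j) bool, True = attend (assumed).
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    # (b, j) -> (b, 1, 1, j): the same key mask for every head and query position.
    scores = np.where(mask[:, None, None, :], scores, -np.finfo(scores.dtype).max)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights
```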

Todo

  • benchmark and see how much torch jit helps
  • look at Triton and KeOps and see if either can be a fit

Citations

@misc{rabe2021selfattention,
    title   = {Self-attention Does Not Need $O(n^2)$ Memory}, 
    author  = {Markus N. Rabe and Charles Staats},
    year    = {2021},
    eprint  = {2112.05682},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • [feature request] Combining with flash attention?

    There is a new algorithm, FlashAttention, that optimizes the QKV attention computation: https://github.com/HazyResearch/flash-attention (paper: https://arxiv.org/abs/2205.14135). Maybe you could look into integrating it with this library.

    opened by Vbansal21 15
  • i did this, we could build on top

    Hi there!

    It seems I have already written some of the code (https://github.com/CHARM-Tx/linear_mem_attention_pytorch). Could we build on top of this? I talked to https://github.com/Chillee about an experimental feature in functorch (https://github.com/pytorch/functorch) that would allow for increased speed. Mainly, I want to match JAX performance, but that is difficult with PyTorch's imperative style.

    I would love to collaborate on this if you want!

    opened by hypnopump 5
  • Added dropout support to memory efficient variant

    Hey Phil,

    I have been using this repository for a project and wanted to add dropout for completeness. I checked consistency with the perceiver-ar implementation. I hope this is helpful.

    -Matt

    opened by usryokousha 2
  • Making this work with relative position bias from XTransformers

    Is there a way to make this work with RelativePositionBias? Currently, it produces an attention bias of size $B \times H \times N^2$, where B is the batch size, H is the number of heads, and N is the input length. Can this be chunked and computed per chunk?

    opened by pfeatherstone 5
  •  save_for_backward can only save variables, but argument 5 is of type bool

    Hi,

    Thank you for your indescribable work. I was trying to test your method, specifically for cross-attention, but I get the error "save_for_backward can only save variables, but argument 5 is of type bool". I am not sure what I am doing wrong. I also tried your own examples and get the same error.

    Can you please help me out?

    Code:

    import torch
    from memory_efficient_attention_pytorch import Attention

    cross_attn = Attention(
        dim = 512,
        dim_head = 64,
        heads = 8,
        memory_efficient = True,
        q_bucket_size = 1024,
        k_bucket_size = 2048
    ).cuda()

    # out = sm_mod(inp1)
    x = torch.randn(1, 65536, 512).cuda()
    context = torch.randn(1, 65536, 512).cuda()
    # mask = torch.ones(1, 65536).bool().cuda()
    out = cross_attn(x

    ERROR:

    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/__main__.py", line 45, in <module>
      cli.main()
    File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 444, in main
      run()
    File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 285, in run_file
      runpy.run_path(target_as_str, run_name=compat.force_str("__main__"))
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 265, in run_path
      return _run_module_code(code, init_globals, run_name,
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 97, in _run_module_code
      _run_code(code, mod_globals, init_globals,
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/data/stars/user/abali/Phd_work/ISBI2023/X3D-Multigrid/CrossAttn_X3d_v2.py", line 872, in <module>
      out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512) print(out)
    File "/home/abali/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
      result = self.forward(*input, **kwargs)
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/site-packages/memory_efficient_attention_pytorch/memory_efficient_attention.py", line 215, in forward
      out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/site-packages/memory_efficient_attention_pytorch/memory_efficient_attention.py", line 127, in memory_efficient_attention
      exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
    File "/home/abali/.local/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 163, in checkpoint
      return CheckpointFunction.apply(function, preserve, *args)
    TypeError: save_for_backward can only save variables, but argument 5 is of type bool

    opened by aliabid2243 1
  • Checkpointing is not compatible with .grad() or when an `inputs` parameter is passed to .backward()

    https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/35559a05572f9d4eb982a8e2e399b40a2d61b85c/memory_efficient_attention_pytorch/memory_efficient_attention.py#L95

    Should this be:

        summarize_qkv_fn = summarize_qkv_chunk if needs_backwards else checkpointed_summarize_qkv_chunk

    instead of:

        summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk

    opened by vrobot 0
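
On the relative position bias question in the comments: a relative position bias depends only on the offset j - i, so in principle it can be computed for each (query chunk, key chunk) block from absolute indices, and the full N by N bias never has to be stored. The minimal NumPy sketch below uses a simple clipped-offset lookup table per head. This is a simplification, since x-transformers' RelativePositionBias uses log-spaced buckets. `relative_bias_block` is a hypothetical helper, not part of either library.

```python
import numpy as np

def relative_bias_block(table, q_start, q_end, k_start, k_end, max_distance):
    # table: (heads, 2 * max_distance + 1) per-head bias for each clipped offset.
    q_idx = np.arange(q_start, q_end)[:, None]
    k_idx = np.arange(k_start, k_end)[None, :]
    # Offsets beyond +/- max_distance share the bias of the outermost offset.
    offsets = np.clip(k_idx - q_idx, -max_distance, max_distance) + max_distance
    return table[:, offsets]  # (heads, q_end - q_start, k_end - k_start)

rng = np.random.default_rng(0)
heads, n, max_distance = 2, 10, 3
table = rng.normal(size=(heads, 2 * max_distance + 1))
full_bias = relative_bias_block(table, 0, n, 0, n, max_distance)   # (2, 10, 10)
block_bias = relative_bias_block(table, 4, 8, 2, 9, max_distance)  # (2, 4, 7)
```

In a chunked attention loop, each block's bias would be added to that block's scores before the softmax statistics are computed.
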
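
On the dropout pull request in the comments: the sketch below shows one general way that dropout on the attention weights can coexist with chunking. It is not the contributed implementation. The softmax denominator is accumulated from the undropped weights, and the inverted-dropout mask is applied only to the weighted-value numerator. This gives exactly dropout(softmax(scores)) @ v. The sketch uses a running max (online softmax) over key chunks. The keep mask is passed in precomputed so the result can be checked; in training, it would be sampled per chunk instead. All names are hypothetical.

```python
import numpy as np

def chunked_attention_dropout(q, k, v, keep, p, k_bucket_size=4):
    # keep: (n_q, n_k) boolean keep mask, precomputed here for reproducibility.
    q = q * q.shape[-1] ** -0.5
    row_max = np.full((q.shape[0], 1), -np.inf)
    denom = np.zeros((q.shape[0], 1))
    acc = np.zeros((q.shape[0], v.shape[-1]))
    for ks in range(0, k.shape[0], k_bucket_size):
        s = q @ k[ks:ks + k_bucket_size].T
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)  # rescales earlier chunks
        e = np.exp(s - new_max)
        # Denominator: undropped weights, so normalization matches plain softmax.
        denom = denom * correction + e.sum(axis=-1, keepdims=True)
        # Numerator: inverted dropout, scaling kept weights by 1 / (1 - p).
        e_dropped = e * keep[:, ks:ks + k_bucket_size] / (1 - p)
        acc = acc * correction + e_dropped @ v[ks:ks + k_bucket_size]
        row_max = new_max
    return acc / denom

def reference_dropout_attention(q, k, v, keep, p):
    # dropout(softmax(scores)) @ v with the same keep mask.
    s = q @ k.T * q.shape[-1] ** -0.5
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return (w * keep / (1 - p)) @ v
```
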
Releases (0.1.1)
Owner
Phil Wang
Working with Attention. It's all we need