An implementation of Performer, a linear attention-based transformer, in PyTorch

Overview

Performer - Pytorch

An implementation of Performer, a linear attention-based transformer variant that uses the Fast Attention Via positive Orthogonal Random features (FAVOR+) mechanism.
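
For intuition, FAVOR+ replaces the softmax attention kernel with an expectation over positive random features, so attention can be computed as phi(Q) (phi(K)^T V) in time linear in the sequence length. Below is a minimal non-causal sketch of the idea (illustrative only: it uses a plain Gaussian projection, whereas the library orthogonalizes the random features and also handles causal attention, numerical stabilization, and feature redrawing).

import torch

def softmax_kernel_features(x, projection, eps = 1e-6):
    # positive random features for the softmax kernel:
    # phi(x) = exp(w @ x - |x|^2 / 2) / sqrt(m)
    m = projection.shape[0]
    proj = x @ projection.t()                                # (n, m)
    sq_norm = (x ** 2).sum(dim = -1, keepdim = True) / 2     # (n, 1)
    return torch.exp(proj - sq_norm) / (m ** 0.5) + eps

n, d, m = 512, 64, 256                                       # sequence length, head dim, number of random features
w = torch.randn(m, d)                                        # random projection matrix
q, k, v = torch.randn(3, n, d).unbind(0)

q_prime = softmax_kernel_features(q * d ** -0.25, w)         # (n, m)
k_prime = softmax_kernel_features(k * d ** -0.25, w)         # (n, m)

out = q_prime @ (k_prime.t() @ v)                            # O(n * m * d) instead of O(n^2 * d)
out = out / (q_prime @ k_prime.sum(dim = 0, keepdim = True).t())   # row-wise normalization, approximating softmax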

Install

$ pip install performer-pytorch

Usage

Performer Language Model

import torch
from torch import nn
from performer_pytorch import PerformerLM

model = PerformerLM(
    num_tokens = 20000,
    max_seq_len = 2048,             # max sequence length
    dim = 512,                      # dimension
    depth = 12,                     # layers
    heads = 8,                      # heads
    causal = False,                 # auto-regressive or not
    nb_features = 256,              # number of random features, if not set, will default to (d * log(d)), where d is the dimension of each head
    feature_redraw_interval = 1000, # how frequently to redraw the projection matrix, the more frequent, the slower the training
    generalized_attention = False,  # defaults to softmax approximation, but can be set to True for generalized attention
    kernel_fn = nn.ReLU(),          # the kernel function to be used, if generalized attention is turned on, defaults to ReLU
    reversible = True,              # reversible layers, from Reformer paper
    ff_chunks = 10,                 # chunk feedforward layer, from Reformer paper
    use_scalenorm = False,          # use scale norm, from 'Transformers without Tears' paper
    use_rezero = False,             # use rezero, from 'Rezero is all you need' paper
    tie_embedding = False,          # multiply final embeddings with token weights for logits, like gpt decoder
    ff_glu = True,                  # use GLU variant for feedforward
    emb_dropout = 0.1,              # embedding dropout
    ff_dropout = 0.1,               # feedforward dropout
    attn_dropout = 0.1,             # post-attn dropout
    local_attn_heads = 4,           # 4 heads are local attention, 4 others are global performers
    local_window_size = 256,        # window size of local attention
    rotary_position_emb = True      # use rotary positional embedding, which endows linear attention with relative positional encoding with no learned parameters. should always be turned on unless you want to go back to the old absolute positional encoding
)

x = torch.randint(0, 20000, (1, 2048))
mask = torch.ones_like(x).bool()

model(x, mask = mask) # (1, 2048, 20000)
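
For training, the returned logits can go straight into a cross-entropy loss. A minimal sketch of a next-token objective follows (illustrative: the optimizer and target shifting are not part of the library, and for an autoregressive language model you would also set causal = True above).

import torch.nn.functional as F

opt = torch.optim.Adam(model.parameters(), lr = 1e-4)

logits = model(x, mask = mask)            # (1, 2048, 20000)
loss = F.cross_entropy(                   # predict token t + 1 from tokens up to t
    logits[:, :-1].transpose(1, 2),       # (1, 20000, 2047)
    x[:, 1:]                              # (1, 2047)
)
loss.backward()
opt.step()
opt.zero_grad()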

Plain Performer, if you are working with, say, images or other modalities

import torch
from performer_pytorch import Performer

model = Performer(
    dim = 512,
    depth = 1,
    heads = 8,
    causal = True
)

x = torch.randn(1, 2048, 512)
model(x) # (1, 2048, 512)
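
For images, one simple recipe (purely illustrative, not something the library provides) is to cut the image into flattened patches, project them to dim with a linear layer, and feed the resulting token sequence through the plain Performer.

import torch
from torch import nn
from performer_pytorch import Performer

patch_size = 16
to_tokens = nn.Linear(3 * patch_size ** 2, 512)   # hypothetical patch embedding
performer = Performer(dim = 512, depth = 6, heads = 8, causal = False)

img = torch.randn(1, 3, 256, 256)
patches = nn.functional.unfold(img, kernel_size = patch_size, stride = patch_size)  # (1, 768, 256)
tokens = to_tokens(patches.transpose(1, 2))       # (1, 256, 512) - one token per 16x16 patch
out = performer(tokens)                           # (1, 256, 512)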

Encoder / Decoder - Made possible by Thomas Melistas

import torch
from performer_pytorch import PerformerEncDec

SRC_SEQ_LEN = 4096
TGT_SEQ_LEN = 4096
GENERATE_LEN = 512

enc_dec = PerformerEncDec(
    dim = 512,
    tie_token_embed = True,
    enc_num_tokens = 20000,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = SRC_SEQ_LEN,
    dec_num_tokens = 20000,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = TGT_SEQ_LEN,
)

src = torch.randint(0, 20000, (1, SRC_SEQ_LEN))
tgt = torch.randint(0, 20000, (1, TGT_SEQ_LEN))
src_mask = torch.ones_like(src).bool()
tgt_mask = torch.ones_like(tgt).bool()

# train
enc_dec.train()
loss = enc_dec(src, tgt, enc_mask = src_mask, dec_mask = tgt_mask)
loss.backward()

# generate
generate_in = torch.randint(0, 20000, (1, SRC_SEQ_LEN)).long()
generate_out_prime = torch.tensor([[0.]]).long() # prime with <bos> token
samples = enc_dec.generate(generate_in, generate_out_prime, seq_len = GENERATE_LEN, eos_token = 1) # assume 1 is id of stop token
print(samples.shape) # (1, <= GENERATE_LEN) decode the tokens

Standalone self-attention layer with linear complexity with respect to sequence length, for replacing trained full-attention transformer self-attention layers.

import torch
from performer_pytorch import SelfAttention

attn = SelfAttention(
    dim = 512,
    heads = 8,
    causal = False,
).cuda()

x = torch.randn(1, 1024, 512).cuda()
attn(x) # (1, 1024, 512)
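
As a rough sketch of what such a replacement can look like, the layer can sit inside an ordinary pre-norm residual block (the surrounding module below is illustrative and not part of the library; the mask keyword is assumed to behave as in the PerformerLM example above).

import torch
from torch import nn
from performer_pytorch import SelfAttention

class PerformerBlock(nn.Module):
    # hypothetical drop-in block: pre-norm residual around linear-complexity self-attention
    def __init__(self, dim = 512, heads = 8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = SelfAttention(dim = dim, heads = heads, causal = False)

    def forward(self, x, mask = None):
        return x + self.attn(self.norm(x), mask = mask)

block = PerformerBlock().cuda()
x = torch.randn(1, 1024, 512).cuda()
block(x) # (1, 1024, 512)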

To minimize model surgery, you could also simply rewrite the code so that the attention step is done by the FastAttention module, as follows.

import torch
from performer_pytorch import FastAttention

# queries / keys / values with heads already split and transposed to first dimension
# 8 heads, dimension of head is 64, sequence length of 512
q = torch.randn(1, 8, 512, 64)
k = torch.randn(1, 8, 512, 64)
v = torch.randn(1, 8, 512, 64)

attn_fn = FastAttention(
    dim_heads = 64,
    nb_features = 256,
    causal = False
)

out = attn_fn(q, k, v) # (1, 8, 512, 64)
# now merge heads and combine outputs with Wo
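
The head merge and output projection that would normally follow can be sketched as below (the W_o linear layer here is an illustrative assumption, not something FastAttention provides).

from torch import nn

b, h, n, d = out.shape                              # (1, 8, 512, 64)
to_out = nn.Linear(h * d, h * d)                    # hypothetical W_o output projection

merged = out.transpose(1, 2).reshape(b, n, h * d)   # (1, 512, 512)
final = to_out(merged)                              # (1, 512, 512)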

Advanced

At the end of training, if you wish to fix the projection matrices so that the model outputs deterministically, you can invoke the following:

model.fix_projection_matrices_()

Now your model will have fixed projection matrices across all layers.
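
A quick sanity check (sketch, reusing the PerformerLM example from above; eval() is needed so dropout does not mask the comparison) is to run the same input through the model twice and confirm the outputs match.

model.eval()
model.fix_projection_matrices_()

with torch.no_grad():
    out1 = model(x, mask = mask)
    out2 = model(x, mask = mask)

assert torch.allclose(out1, out2)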

Citations

@misc{choromanski2020rethinking,
    title   = {Rethinking Attention with Performers},
    author  = {Krzysztof Choromanski and Valerii Likhosherstov and David Dohan and Xingyou Song and Andreea Gane and Tamas Sarlos and Peter Hawkins and Jared Davis and Afroz Mohiuddin and Lukasz Kaiser and David Belanger and Lucy Colwell and Adrian Weller},
    year    = {2020},
    eprint  = {2009.14794},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@inproceedings{katharopoulos_et_al_2020,
    author  = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
    title   = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
    booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
    year    = {2020}
}
@misc{bachlechner2020rezero,
    title   = {ReZero is All You Need: Fast Convergence at Large Depth},
    author  = {Thomas Bachlechner and Bodhisattwa Prasad Majumder and Huanru Henry Mao and Garrison W. Cottrell and Julian McAuley},
    year    = {2020},
    url     = {https://arxiv.org/abs/2003.04887}
}
@misc{nguyen2019transformers,
    title   = {Transformers without Tears: Improving the Normalization of Self-Attention},
    author  = {Toan Q. Nguyen and Julian Salazar},
    year    = {2019},
    eprint  = {1910.05895},
    archivePrefix = {arXiv},
    doi     = {10.5281/zenodo.3525484}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{roy2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy and Mohammad Taghi Saffar and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://arxiv.org/pdf/2003.05997.pdf}
}
@techreport{zhuiyiroformer,
    title   = {RoFormer: Transformer with Rotary Position Embeddings - ZhuiyiAI},
    author  = {Jianlin Su},
    year    = {2021},
    url     = {https://github.com/ZhuiyiTechnology/roformer}
}
Comments
  • [Feature] EncoderDecoder framework, similar to ReformerEncDec

    Hello Phil,

    Nice job on this great architecture. I want to use it as an Encoder Decoder within Deepspeed, so I am thinking of writing a wrapper similar to the one you did for Reformer. Do you have any tips on what to pay attention to (no pun intended) and whether I need to use padding as in Autopadder?

    Thanks

    opened by gulnazaki 22
  • Causal linear attention benchmark

    First, thanks for this awesome repo!!

    Based on the T5 model classes from Huggingface's transformers, I was trying to use performer attention instead of the original T5 attention. We finetuned t5-large as a summarization model and tried to profile both time and memory usage, comparing the performer attention with the original attention. I have only benchmarked with an input size of 1024.

    The result clearly showed that performer attention uses a lot less memory than the original transformer. I know from the paper that the performer outperforms the original transformer when the input size is bigger than 1024. However, finetuning and generation with the performer actually took longer, so I profiled the forward call of both the original T5 attention and the performer attention. The forward of the T5 performer took twice as long, and the main bottleneck was causal_dot_product_kernel from fast-transformers.

    Is this normal performance for the performer, or for the causal attention calculation? Or will the performer attention be faster with a bigger input size?

    opened by ice-americano 13
  • Regarding DDP and reversible networks

    Hi, I'm trying to figure out how to combine DDP with setting the network to be reversible.

    My code basically looks like this:

    import pytorch_lightning as pl
    from performer_pytorch import Performer
    ...
    model = nn.Sequential([...,Performer(...,reversible=True)])
    trainer = pl.Trainer(...
                        distributed_backend='ddp',
                        ...)
    trainer.fit(model,train_loader,val_loader)
    

    Now all combinations work for me (ddp/not reversible, not ddp/reversible, not ddp/not reversible) except for ddp and reversible.

    The error I get is:

    RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons:

    1. Use of a module parameter outside the forward function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes
    2. Reused parameters in multiple reentrant backward passes. For example, if you use multiple checkpoint functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases yet.

    I've seen multiple people have similar issues: https://github.com/huggingface/transformers/issues/7160, https://github.com/pytorch/pytorch/issues/46166, https://github.com/tatp22/linformer-pytorch/issues/23

    Do you have any suggestion for how to deal with this issue? I'm not really familiar with the inner workings of DDP and the autograd engine, so I'm not sure how to fix this myself.

    opened by Parskatt 11
  • Triangular matrices ?

    Does the current implementation provide triangular matrices (to constrain the attention to always be on the "left" of the sequence, both for input and encoded values) as described in the last section of the original paper?

    opened by jeremycochoy 10
  • wrong implementation for autoregressive self-attention

    Hi, I found that you used fast_transformers's CUDA kernel, but it does not contain the normalization part, which needs a cumsum outside the CausalDotProduct (in causal_linear_attention). If I didn't miss something, the result of your code should be wrong... But I am not 100% sure.

    opened by Sleepychord 10
  • There are no tests in this project, use_rezero=True is non-functional

    Tests are needed to validate that models can train in various configurations. I built and ran simple tests (trying to get authorization to contribute as a PR) and found that use_rezero=True kills the gradient and results in a performer model that cannot learn. The fix consists of initializing the rezero parameter with a small value, but not zero (e.g., 1e-3 works in my tests). Zero prevents any signal from passing through the module, so the parameter will never change from zero.

    opened by fcampagne 10
  • Show what the performance on enwiki8 is across your projects

    Hello @lucidrains, I'm a very big fan of your work. It is of such high quality that every new project you release leaves me sleepless to try it.

    You have many different transformer versions, such as reformer, memory-xl, performer... and apparently you already test them on enwiki8.

    Would it be possible to post in the README a table with the enwiki8 runtime, memory and some performance metric? That would be awesome for comparing the different implementations.

    Thanks again for your work!!

    opened by bratao 10
  • Issue with biased estimates from QR decomposition

    Hi again :)

    See the issue https://github.com/google-research/google-research/issues/436 that I posted on the main repository. The QR is currently used incorrectly, which produces results with significantly higher variance. There is quite an easy fix by simply doing

     q, r = torch.qr(flattened)
     # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
     d = torch.diag(r, 0)
     ph = d.sign()
     q *= ph
    
    opened by Parskatt 9
  • Collaborate on Implementation?

    I was planning on implementing this in PyTorch as well and started a repo: https://github.com/calclavia/Performer-Pytorch. I have implemented the kernel so far. If the author(s) of this repo want to collaborate, I would be happy to contribute.

    opened by calclavia 9
  • Extra FF when using cross attention

    Hello Phil,

    I have noticed that when using cross attention, a new block (with attention and a FeedForward layer) is added, while only an attention layer should be added between the self-attention and the FF layer.

    Is there any reason for this?

    opened by gulnazaki 8
  • Add feature_redraw_interval option

    This fork allows the user to select a number of forward passes after which the random features will be redrawn. This allows us to avoid doing QR decomposition on the CPU every forward pass. By default it is set to redraw every 1000 passes.

    opened by norabelrose 8
  • Performer Pytorch Slower than Expected and Please Help with Understanding Parameter Count

    Hi,

    First of all, this is a great package from lucidrains and I find it very helpful in my research.

    A quick question is that I noticed the ViT-performer is slower than the regular ViT from lucidrains. For example, running on MNIST from PyTorch takes 15 sec/epoch for the regular ViT with the configuration below, while the ViT-performer takes 23 sec/epoch.

    Checking the parameter count also shows the ViT-performer has double the size of the regular ViT.

    (screenshots of the configurations and parameter counts omitted)

    I am hoping that someone has intuition about the speed of ViT performer vs regular ViT and their parameter counts.

    Thank you very much in advance!

    opened by weihaosong 1
  • Replicating nn.MultiHeadAttention with multiple performer SelfAttention modules

    As the title says, has anyone tried replacing the multi-head attention in a typical transformer with the self-attention described in this library?

    My thought was that I can essentially concat multiple self-attention elements together to replicate this, per the attached image from the torch website.

    I'm relatively new to transformers as a whole so hopefully this question makes some sense.

    For reference, considering the interest in a previous post, I've been attempting to explore performer effectiveness with DETR (https://github.com/facebookresearch/detr).

    thanks!

    opened by JGittles 0
  • Question about masking

    Hi, thanks for the wonderful repo. I am new to BERT, so I'd like to make sure about your example:

    model = PerformerLM()
    x = torch.randint(0, 20000, (1, 2048))
    mask = torch.ones_like(x).bool()
    model(x, mask = mask) # (1, 2048, 20000)

    Is this 'mask' the attention mask, i.e., TRUE (1) for normal tokens and FALSE (0) for padding tokens? Or is 1 set to indicate a padding token? Thanks a lot!

    opened by Microbiods 1
  • Question: Is Performer order equivariant? (can it transform an unordered set of tensors)

    Hi,

    Thanks for the amazing implementation. I'm wondering if Performer can be used like a set operator (i.e. whether it is order equivariant).

    For instance, say I have a point cloud and I want to apply self-attention across all the point features. Can Performer be used here (note equivariance: points can be arbitrarily shuffled, but we expect the corresponding transformed features to be identical regardless of the shuffling)?

    Thanks!

    opened by nmakes 0
  • Using Performer with GNNs

    My understanding of "Rethinking Attention with Performers" is that FAVOR+ is used to approximate the attention matrix and avoids the use of the softmax function. In the README.md file, you note that the plain Performer can be used for images or other modalities, just as the authors allude to Performer's use in other areas.

    I am interested in using Performer to approximate attention between nodes in a graph neural network. The graph neural network contains vectors characterizing each node's features and boolean edge indices indicating a connection between two nodes.

    Do you have any recommendations on how this is feasible with the current Performer model? I see that Attention.forward() contains input for a mask.

    opened by jah377 0