Implementation of Memory-Compressed Attention, from the paper "Generating Wikipedia By Summarizing Long Sequences"

Overview

Memory Compressed Attention

Implementation of the Self-Attention layer of the proposed Memory-Compressed Attention, in Pytorch. This repository offers both the causal and non-causal variant, and will take care of the padding if the sequence length is not divisible by the compression ratio.

The code also resolves an edge-case where the very first query have no keys to attend to in the auto-regressive scenario. The solution is to use null key/values, appended to the final compressed set, so that there is always at least 1 key for all queries to attend to.

Install

$ pip install memory_compressed_attention

Usage

import torch
from memory_compressed_attention import MemoryCompressedAttention

attn = MemoryCompressedAttention(
    dim = 512,
    heads = 8,                 # number of heads
    causal = False,            # auto-regressive or not
    compression_factor = 3,    # compression ratio
    dropout = 0.1              # dropout post-attention
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

attn(x, input_mask = mask) # (1, 1024, 512)

Citations

@misc{liu2018generating,
    title={Generating Wikipedia by Summarizing Long Sequences},
    author={Peter J. Liu and Mohammad Saleh and Etienne Pot and Ben Goodrich and Ryan Sepassi and Lukasz Kaiser and Noam Shazeer},
    year={2018},
    eprint={1801.10198},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
You might also like...
Memory Efficient Attention (O(sqrt(n)) for Jax and PyTorch

Memory Efficient Attention This is unofficial implementation of Self-attention Does Not Need O(n^2) Memory for Jax and PyTorch. Implementation is almo

 Attention for PyTorch with Linear Memory Footprint
Attention for PyTorch with Linear Memory Footprint

Attention for PyTorch with Linear Memory Footprint Unofficially implements https://arxiv.org/abs/2112.05682 to get Linear Memory Cost on Attention (+

PyTorch code for our paper "Attention in Attention Network for Image Super-Resolution"

Under construction... Attention in Attention Network for Image Super-Resolution (A2N) This repository is an PyTorch implementation of the paper "Atten

Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch
Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch

Transformer in Transformer Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image c

Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification
Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification

STAM - Pytorch Implementation of STAM (Space Time Attention Model), yet another pure and simple SOTA attention model that bests all previous models in

Official Pytorch Implementation of Relational Self-Attention: What's Missing in Attention for Video Understanding
Official Pytorch Implementation of Relational Self-Attention: What's Missing in Attention for Video Understanding

Relational Self-Attention: What's Missing in Attention for Video Understanding This repository is the official implementation of "Relational Self-Atte

Official implementation of cosformer-attention in cosFormer: Rethinking Softmax in Attention

cosFormer Official implementation of cosformer-attention in cosFormer: Rethinking Softmax in Attention Update log 2022/2/28 Add core code License This

This is a pytorch implementation of the NeurIPS paper GAN Memory with No Forgetting.

GAN Memory for Lifelong learning This is a pytorch implementation of the NeurIPS paper GAN Memory with No Forgetting. Please consider citing our paper

Official and maintained implementation of the paper
Official and maintained implementation of the paper "OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data" [BMVC 2021].

OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data Christoph Reich, Tim Prangemeier, Özdemir Cetin & Heinz Koeppl | Pr

Comments
  • The order of masking and softmax operation

    The order of masking and softmax operation

    Hi,

    In memory_compressed_attention.py, I'm wondering if we need to do softmax operation after masking? Btw, if the entry in the mask should be float('-inf') instead of -float('-inf')? If I make something wrong, please correct me.

    image

    Thanks!

    opened by cfeng16 3
  • mask error in attention

    mask error in attention

    Very grateful for your pioneering work! I want to use it in Standard Transformer released in http://nlp.seas.harvard.edu/2018/04/03/attention.html. but it mat a mask error in training. more detail information shown as follow, the code i use: image class ConvCompress(nn.Module): def init(self, dim, ratio = 2, groups = 1): super(ConvCompress, self).init() self.conv = nn.Conv1d(dim, dim, ratio, stride = ratio, groups = groups) #self.linear = nn.Linear(dim, dim)

    def forward(self, mem):
        mem = mem.transpose(1, 2)
        compressed_mem = self.conv(mem)
        return compressed_mem.transpose(1, 2)
    

    class MemoryCompressedAttention(nn.Module): def init(self, h, d_model, compression_factor = 2, dropout = 0.1): super(MemoryCompressedAttention, self).init() assert (d_model % h) == 0, 'dimension must be divisible by number of heads' self.h = h self.d_model = d_model self.d_k = d_model // h

        self.compression_factor = compression_factor
        self.compress_fn = ConvCompress(d_model, compression_factor, groups = h)
    
        #self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.wq = nn.Linear(d_model, d_model, bias = False)
        self.wk = nn.Linear(d_model, d_model, bias = False)
        self.wv = nn.Linear(d_model, d_model, bias = False)
    
        self.wo = nn.Linear(d_model, d_model)
    
        self.dropout = nn.Dropout(dropout)
    
        #self.null_k = nn.Parameter(torch.zeros(1, 1, d_model))
        #self.null_v = nn.Parameter(torch.zeros(1, 1, d_model))
    
    def forward(self, query, key, value, mask = None):
        
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        t = query.size(1)
        cf = self.compression_factor
    
        query = self.wq(query)
        key = self.wk(key)
        value = self.wv(value)
    
        # make sure keys and values sequence lengths
        # are divisible by the compression factor
        padding = cf - (t % cf)
        if padding != 0:
            key, value = map(lambda t: F.pad(t, (0, 0, padding, 0)), (key, value))
    
    
        # compress keys and values
        key, value = map(self.compress_fn, (key, value))
    
        # attach a null key and value, in the case that the first query has no keys to pay attention to
        null_k = nn.Parameter(torch.zeros(key.size(0), 1, self.d_model)).cuda()
        null_v = nn.Parameter(torch.zeros(value.size(0), 1, self.d_model)).cuda()
    
        key = torch.cat((null_k, key), dim=1)
        value = torch.cat((null_v, value), dim=1)
        
        # merge heads
        #query, key, value = map(lambda t: t.reshape(*t.shape[:2], h, -1).transpose(1, 2), (query, key, value))
        # 1) Do all the linear projections in batch from d_model => h x d_k
        query = query.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        key = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        value = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
    
      
        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)
    
        # 3) "Concat" using a view and apply a final linear.   # split heads and combine
        x = x.contiguous().view(nbatches, -1, self.d_model)
        out = self.wo(x)
    
        return out
    

    The error was show that image

    I want to know how to fix it, and how to do mask for N*M matrix??

    opened by HN123-123 0
Releases(0.0.5)
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
Final project for Intro to CS class.

Financial Analysis Web App https://share.streamlit.io/mayurk1/fin-web-app-final-project/webApp.py 1. Project Description This project is a technical a

Mayur Khanna 1 Dec 10, 2021
HALO: A Skeleton-Driven Neural Occupancy Representation for Articulated Hands

HALO: A Skeleton-Driven Neural Occupancy Representation for Articulated Hands Oral Presentation, 3DV 2021 Korrawe Karunratanakul, Adrian Spurr, Zicong

Korrawe Karunratanakul 43 Oct 07, 2022
PyTorch implementation of SIFT descriptor

This is an differentiable pytorch implementation of SIFT patch descriptor. It is very slow for describing one patch, but quite fast for batch. It can

Dmytro Mishkin 150 Dec 24, 2022
Pytorch implementation of four neural network based domain adaptation techniques: DeepCORAL, DDC, CDAN and CDAN+E. Evaluated on benchmark dataset Office31.

Deep-Unsupervised-Domain-Adaptation Pytorch implementation of four neural network based domain adaptation techniques: DeepCORAL, DDC, CDAN and CDAN+E.

Alan Grijalva 49 Dec 20, 2022
Official repository for "Restormer: Efficient Transformer for High-Resolution Image Restoration". SOTA for motion deblurring, image deraining, denoising (Gaussian/real data), and defocus deblurring.

Restormer: Efficient Transformer for High-Resolution Image Restoration Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan,

Syed Waqas Zamir 906 Dec 30, 2022
Paddle implementation for "Highly Efficient Knowledge Graph Embedding Learning with Closed-Form Orthogonal Procrustes Analysis" (NAACL 2021)

ProcrustEs-KGE Paddle implementation for Highly Efficient Knowledge Graph Embedding Learning with Orthogonal Procrustes Analysis 🙈 A more detailed re

Lincedo Lab 4 Jun 09, 2021
DeRF: Decomposed Radiance Fields

DeRF: Decomposed Radiance Fields Daniel Rebain, Wei Jiang, Soroosh Yazdani, Ke Li, Kwang Moo Yi, Andrea Tagliasacchi Links Paper Project Page Abstract

UBC Computer Vision Group 24 Dec 02, 2022
"Projelerle Yapay Zeka Ve Bilgisayarlı Görü" Kitabımın projeleri

"Projelerle Yapay Zeka Ve Bilgisayarlı Görü" Kitabımın projeleri Bu Github Reposundaki tüm projeler; kaleme almış olduğum "Projelerle Yapay Zekâ ve Bi

Ümit Aksoylu 4 Aug 03, 2022
ICCV2021 Expert-Goal Trajectory Prediction

ICCV 2021: Where are you heading? Dynamic Trajectory Prediction with Expert Goal Examples This repository contains the code for the paper Where are yo

hz 21 Dec 12, 2022
Official Implementation of PCT

Official Implementation of PCT Prerequisites python == 3.8.5 Please make sure you have the following libraries installed: numpy torch=1.4.0 torchvisi

32 Nov 21, 2022
First-Order Probabilistic Programming Language

FOPPL: A First-Order Probabilistic Programming Language This is an implementation of FOPPL, an S-expression based probabilistic programming language d

Renato Costa 23 Dec 20, 2022
Implementation for the "Surface Reconstruction from 3D Line Segments" paper.

Surface Reconstruction from 3D Line Segments Surface reconstruction from 3d line segments. Langlois, P. A., Boulch, A., & Marlet, R. In 2019 Internati

85 Jan 04, 2023
[AAAI 2021] MVFNet: Multi-View Fusion Network for Efficient Video Recognition

MVFNet: Multi-View Fusion Network for Efficient Video Recognition (AAAI 2021) Overview We release the code of the MVFNet (Multi-View Fusion Network).

Wenhao Wu 114 Nov 27, 2022
EqGAN - Improving GAN Equilibrium by Raising Spatial Awareness

EqGAN - Improving GAN Equilibrium by Raising Spatial Awareness Improving GAN Equilibrium by Raising Spatial Awareness Jianyuan Wang, Ceyuan Yang, Ying

GenForce: May Generative Force Be with You 149 Dec 19, 2022
Madanalysis5 - A package for event file analysis and recasting of LHC results

Welcome to MadAnalysis 5 Outline What is MadAnalysis 5? Requirements Downloading

MadAnalysis 15 Jan 01, 2023
Remote sensing change detection using PaddlePaddle

Change Detection Laboratory Developing and benchmarking deep learning-based remo

Lin Manhui 15 Sep 23, 2022
Mscp jamf - Build compliance in jamf

mscp_jamf Build compliance in Jamf. This will build the following xml pieces to

Bob Gendler 3 Jul 25, 2022
f-BRS: Rethinking Backpropagating Refinement for Interactive Segmentation

f-BRS: Rethinking Backpropagating Refinement for Interactive Segmentation [Paper] [PyTorch] [MXNet] [Video] This repository provides code for training

Visual Understanding Lab @ Samsung AI Center Moscow 516 Dec 21, 2022
Generic image compressor for machine learning. Pytorch code for our paper "Lossy compression for lossless prediction".

Lossy Compression for Lossless Prediction Using: Training: This repostiory contains our implementation of the paper: Lossy Compression for Lossless Pr

Yann Dubois 84 Jan 02, 2023
Code and dataset for ACL2018 paper "Exploiting Document Knowledge for Aspect-level Sentiment Classification"

Aspect-level Sentiment Classification Code and dataset for ACL2018 [paper] ‘‘Exploiting Document Knowledge for Aspect-level Sentiment Classification’’

Ruidan He 146 Nov 29, 2022