Pytorch implementation of Compressive Transformers, from Deepmind

Overview

Compressive Transformer in Pytorch

Pytorch implementation of Compressive Transformers, a variant of Transformer-XL with compressed memory for long-range language modelling. I will also combine this with an idea from another paper that adds gating at the residual intersection. The memory and the gating may be synergistic, and lead to further improvements in both language modeling as well as reinforcement learning.

PyPI version

Install

$ pip install compressive_transformer_pytorch

Usage

import torch
from compressive_transformer_pytorch import CompressiveTransformer

model = CompressiveTransformer(
    num_tokens = 20000,
    emb_dim = 128,                 # embedding dimensions, embedding factorization from Albert paper
    dim = 512,
    depth = 12,
    seq_len = 1024,
    mem_len = 1024,                # memory length
    cmem_len = 1024 // 4,          # compressed memory buffer length
    cmem_ratio = 4,                # compressed memory ratio, 4 was recommended in paper
    reconstruction_loss_weight = 1,# weight to place on compressed memory reconstruction loss
    attn_dropout = 0.1,            # dropout post-attention
    ff_dropout = 0.1,              # dropout in feedforward
    attn_layer_dropout = 0.1,      # dropout for attention layer output
    gru_gated_residual = True,     # whether to gate the residual intersection, from 'Stabilizing Transformer for RL' paper
    mogrify_gru = False,           # experimental feature that adds a mogrifier for the update and residual before gating by the GRU
    memory_layers = range(6, 13),  # specify which layers to use long-range memory, from 'Do Transformers Need LR Memory' paper
    ff_glu = True                  # use GLU variant for feedforward
)

inputs = torch.randint(0, 256, (1, 2048))
masks = torch.ones_like(inputs).bool()

segments = inputs.reshape(1, -1, 1024).transpose(0, 1)
masks = masks.reshape(1, -1, 1024).transpose(0, 1)

logits, memories, aux_loss = model(segments[0], mask = masks[0])
logits,        _, aux_loss = model(segments[1], mask = masks[1], memories = memories)

# memories is a named tuple that contains the memory (mem) and the compressed memory (cmem)

When training, you can use the AutoregressiveWrapper to have memory management across segments taken care of for you. As easy as it gets.

import torch
from compressive_transformer_pytorch import CompressiveTransformer
from compressive_transformer_pytorch import AutoregressiveWrapper

model = CompressiveTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    mem_len = 1024,
    cmem_len = 256,
    cmem_ratio = 4,
    memory_layers = [5,6]
).cuda()

model = AutoregressiveWrapper(model)

inputs = torch.randint(0, 20000, (1, 2048 + 1)).cuda()

for loss, aux_loss, _ in model(inputs, return_loss = True):
    (loss + aux_loss).backward()
    # optimizer step and zero grad

# ... after much training ...

# generation is also greatly simplified and automated away
# just pass in the prime, which can be 1 start token or any length
# all is taken care of for you

prime = torch.ones(1, 1).cuda()  # assume 1 is start token
sample = model.generate(prime, 4096)

Citations

@misc{rae2019compressive,
    title   = {Compressive Transformers for Long-Range Sequence Modelling},
    author  = {Jack W. Rae and Anna Potapenko and Siddhant M. Jayakumar and Timothy P. Lillicrap},
    year    = {2019},
    eprint  = {1911.05507},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{parisotto2019stabilizing,
    title   = {Stabilizing Transformers for Reinforcement Learning},
    author  = {Emilio Parisotto and H. Francis Song and Jack W. Rae and Razvan Pascanu and Caglar Gulcehre and Siddhant M. Jayakumar and Max Jaderberg and Raphael Lopez Kaufman and Aidan Clark and Seb Noury and Matthew M. Botvinick and Nicolas Heess and Raia Hadsell},
    year    = {2019},
    eprint  = {1910.06764},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{rae-razavi-2020-transformers,
    title   = "Do Transformers Need Deep Long-Range Memory?",
    author  = "Rae, Jack  and
      Razavi, Ali",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month   = jul,
    year    = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url     = "https://www.aclweb.org/anthology/2020.acl-main.672"
}
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{lan2019albert,
    title       = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author      = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    year        = {2019},
    url         = {https://arxiv.org/abs/1909.11942}
}
@misc{ding2021erniedoc,
    title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
    author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
    year    = {2021},
    eprint  = {2012.15688},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
Comments
  • aux_loss does not update any weigth

    aux_loss does not update any weigth

    Hi lucidrains, thanks for your implementation, it is very elegant and helped me a lot with my disertation. Anyway I can't understand a particular: it seems like aux_loss is not related to any weight because of the detaching in the last part of the SelfAttention layer. With the following code, for example, I get that there is no layer optimized by aux_loss:

    import torch
    from compressive_transformer_pytorch import CompressiveTransformer
    from compressive_transformer_pytorch import AutoregressiveWrapper
    
    model = CompressiveTransformer(
        num_tokens = 20000,
        dim = 512,
        depth = 6,
        seq_len = 1024,
        mem_len = 1024,
        cmem_len = 256,
        cmem_ratio = 4,
        memory_layers = [5,6]
    ).cuda()
    
    model = AutoregressiveWrapper(model)
    
    inputs = torch.randint(0, 20000, (1, 1024)).cuda()
    
    optimizer = torch.optim.Adam(model.parameters())
    
    for loss, aux_loss, _ in model(inputs, return_loss = True):
        optimizer.zero_grad(set_to_none=True)
        loss.backward(retain_graph=True)
        print("OPTIMIZED BY LOSS ************************************************************")
        for module_name, parameter in model.named_parameters():
            if parameter.grad is not None:
                print(module_name)
        optimizer.zero_grad(set_to_none=True)
        aux_loss.backward(retain_graph=True)
        print("OPTIMIZED BY AUX_LOSS ************************************************************")
        for module_name, parameter in model.named_parameters():
            if parameter.grad is not None:
                print(module_name)
    

    I am not expert about the PyTorch mechanisms, so maybe I am getting something wrong. Again thank you

    opened by StefanoBerti 3
  • How to use this for speech/audio generation?

    How to use this for speech/audio generation?

    Great work Phil! In their paper, the authors applied this model to speech modeling, how would you advise on what should I change to use for speech. Because in speech, the data are signals, we do not have num_tokens, nor do we have emb_dim. Our data input is simply, [batch, channel, time]. Any advice?

    opened by jinglescode 3
  • [Error] NameError: name 'math' is not defined in compressive_transformer_pytorch.py

    [Error] NameError: name 'math' is not defined in compressive_transformer_pytorch.py

    hello, I run code "examples/enwik8_simple" now, and I got error as follows:

    train.py:65: DeprecationWarning: The binary mode of fromstring is deprecated, as it behaves surprisingly on unicode inputs. Use frombuffer instead X = np.fromstring(file.read(int(95e6)), dtype=np.uint8) training: 0%| | 0/100000 [00:00<?, ?it/s] Traceback (most recent call last): File "train.py", line 101, in <module> for mlm_loss, aux_loss, is_last in model(next(train_loader), max_batch_size = MAX_BATCH_SIZE, return_loss = True): File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/autoregressive_wrapper.py", line 151, in forward logits, new_mem, aux_loss = self.net(xi_seg_b, mask = mask_seg_b, memories = mem, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__ result = self.forward(*input, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 338, in f orward x, = ff(x) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__ result = self.forward(*input, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 84, in fo rward out = self.fn(x, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__ result = self.forward(*input, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 106, in f orward return self.fn(x, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__ result = self.forward(*input, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 140, in f orward x = self.act(x) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__ result = self.forward(*input, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 122, in f orward return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) NameError: name 'math' is not defined

    so I inserted "import math" into compressive_transformer_pytorch.py file and it work well. I hope you modify compressive_transformer_pytorch.py code.

    opened by dinoSpeech 3
  • Training enwik8 but loss fail to converge

    Training enwik8 but loss fail to converge

    Hi lucidrains, I appreciate your implementation very much, and it helps me a lot with understanding compressive transformer. However when I tried running your code (enwik8 and exactly the same code in github), and the loss failed to converge after 100 epochs. Is this in expectation ? Or should I do other additional effort to improve, for example tokenizing the raw data in enwik8 and remove all the xml tags ? The figure below is the training and validation loss while I train enwik8 with the same code as in github.

    截圖 2021-03-26 下午5 40 16 截圖 2021-03-26 下午5 41 22

    Thanks and look forward to your reply!

    opened by KaiPoChang 2
  • Details about text generation

    Details about text generation

    Hi lucidrains, Thank you for your excellent code. I am curious about the generation scripts. Could you tell me how to generate text with the compressive transformer? Because it has the compressive memory, maybe we cannot use the current predicted word as the input for the next generation (input length ==1). In addition, if the prompt has 100 words and we use tokens [0:100], tokens[1:101], tokens[2:102]... as the input for the following timesteps, the tokens[1:100] may overlap with the memory, because the memory already contains hidden states for tokens[1:100].

    I would be very appeciated if you can provide the generation scripts!

    Thank you

    opened by theseventhflow 3
  • Links to original tf code - fyi

    Links to original tf code - fyi

    After reading deepmind blog post I was looking forward to downloading model but no luck. Looking forward to your implementation.

    You may be aware of this post and link but if not this is the coder's original tf implementation. Hope it helps.

    Copy of comment to original model request:

    https://github.com/huggingface/transformers/issues/4688

    Interested in model weights too but currently not available. Author does mention releasing tf code here:

    https://news.ycombinator.com/item?id=22290227

    Requires tf 1.15+ and deepmind/sonnet ver 1.36. Link to python script here:

    https://github.com/deepmind/sonnet/blob/cd5b5fa48e15e4d020f744968f5209949ebe750f/sonnet/python/modules/nets/transformer.py#L915

    Have tried running as-is but doesn't appear to have options for training on custom data as per the paper and available data sets.

    opened by GenTxt 8
Releases(0.4.0)
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
Implementation of the paper ''Implicit Feature Refinement for Instance Segmentation''.

Implicit Feature Refinement for Instance Segmentation This repository is an official implementation of the ACM Multimedia 2021 paper Implicit Feature

Lufan Ma 17 Dec 28, 2022
根据midi文件演奏“风物之诗琴”的脚本 "Windsong Lyre" auto play

Genshin-lyre-auto-play 简体中文 | English 简介 根据midi文件演奏“风物之诗琴”的脚本。由Python驱动,在此承诺, ⚠️ 项目内绝不含任何能够引起安全问题的代码。 前排提示:所有键盘在动但是原神没反应的都是因为没有管理员权限,双击run.bat或者以管理员模式

御坂17032号 386 Jan 01, 2023
Official Implementation of SWAD (NeurIPS 2021)

SWAD: Domain Generalization by Seeking Flat Minima (NeurIPS'21) Official PyTorch implementation of SWAD: Domain Generalization by Seeking Flat Minima.

Junbum Cha 97 Dec 20, 2022
Official Implementation of "Tracking Grow-Finish Pigs Across Large Pens Using Multiple Cameras"

Multi Camera Pig Tracking Official Implementation of Tracking Grow-Finish Pigs Across Large Pens Using Multiple Cameras CVPR2021 CV4Animals Workshop P

44 Jan 06, 2023
Blender Add-on that sets a Material's Base Color to one of Pantone's Colors of the Year

Blender PCOY (Pantone Color of the Year) MCMC (Mid-Century Modern Colors) HG71 (House & Garden Colors 1971) Blender Add-ons That Assign a Custom Color

Don Schnitzius 15 Nov 20, 2022
PyTorch implementation of probabilistic deep forecast applied to air quality.

Probabilistic Deep Forecast PyTorch implementation of a paper, titled: Probabilistic Deep Learning to Quantify Uncertainty in Air Quality Forecasting

Abdulmajid Murad 13 Nov 16, 2022
Deep Surface Reconstruction from Point Clouds with Visibility Information

Data, code and pretrained models for the paper Deep Surface Reconstruction from Point Clouds with Visibility Information.

Raphael Sulzer 23 Jan 04, 2023
We simulate traveling back in time with a modern camera to rephotograph famous historical subjects.

[SIGGRAPH Asia 2021] Time-Travel Rephotography [Project Website] Many historical people were only ever captured by old, faded, black and white photos,

298 Jan 02, 2023
African language Speech Recognition - Speech-to-Text

Swahili-Speech-To-Text Table of Contents Swahili-Speech-To-Text Overview Scenario Approach Project Structure data: models: notebooks: scripts tests: l

2 Jan 05, 2023
Official PyTorch implementation of "Evolving Search Space for Neural Architecture Search"

Evolving Search Space for Neural Architecture Search Usage Install all required dependencies in requirements.txt and replace all ..path/..to in the co

Yuanzheng Ci 10 Oct 24, 2022
Nb workflows - A workflow platform which allows you to run parameterized notebooks programmatically

NB Workflows Description If SQL is a lingua franca for querying data, Jupyter sh

Xavier Petit 6 Aug 18, 2022
A Convolutional Transformer for Keyword Spotting

☢️ Audiomer ☢️ Audiomer: A Convolutional Transformer for Keyword Spotting [ arXiv ] [ Previous SOTA ] [ Model Architecture ] Results on SpeechCommands

49 Jan 27, 2022
Fine-Tune EleutherAI GPT-Neo to Generate Netflix Movie Descriptions in Only 47 Lines of Code Using Hugginface And DeepSpeed

GPT-Neo-2.7B Fine-Tuning Example Using HuggingFace & DeepSpeed Installation cd venv/bin ./pip install -r ../../requirements.txt ./pip install deepspe

Nikita 180 Jan 05, 2023
Team nan solution repository for FPT data-centric competition. Data augmentation, Albumentation, Mosaic, Visualization, KNN application

FPT_data_centric_competition - Team nan solution repository for FPT data-centric competition. Data augmentation, Albumentation, Mosaic, Visualization, KNN application

Pham Viet Hoang (Harry) 2 Oct 30, 2022
PyTorch implementation of normalizing flow models

PyTorch implementation of normalizing flow models

Vincent Stimper 242 Jan 02, 2023
Sample code from the Neural Networks from Scratch book.

Neural Networks from Scratch (NNFS) book code Code from the NNFS book (https://nnfs.io) separated by chapter.

Harrison 172 Dec 31, 2022
Official implementation of "A Shared Representation for Photorealistic Driving Simulators" in PyTorch.

A Shared Representation for Photorealistic Driving Simulators The official code for the paper: "A Shared Representation for Photorealistic Driving Sim

VITA lab at EPFL 7 Oct 13, 2022
Experimental solutions to selected exercises from the book [Advances in Financial Machine Learning by Marcos Lopez De Prado]

Advances in Financial Machine Learning Exercises Experimental solutions to selected exercises from the book Advances in Financial Machine Learning by

Brian 1.4k Jan 04, 2023
Minimisation of a negative log likelihood fit to extract the lifetime of the D^0 meson (MNLL2ELDM)

Minimisation of a negative log likelihood fit to extract the lifetime of the D^0 meson (MNLL2ELDM) Introduction The average lifetime of the $D^{0}$ me

Son Gyo Jung 1 Dec 17, 2021
Permeability Prediction Via Multi Scale 3D CNN

Permeability-Prediction-Via-Multi-Scale-3D-CNN Data: The raw CT rock cores are obtained from the Imperial Colloge portal. The CT rock cores are sub-sa

Mohamed Elmorsy 2 Jul 06, 2022