PyTorch implementation of Compressive Transformers, from DeepMind

Overview

Compressive Transformer in PyTorch

PyTorch implementation of Compressive Transformers, a variant of Transformer-XL that adds a compressed memory for long-range language modeling. I will also combine this with an idea from another paper ('Stabilizing Transformers for Reinforcement Learning') that adds gating at the residual intersection. The memory and the gating may be synergistic, leading to further improvements in both language modeling and reinforcement learning.
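
The compressed memory described above can be pictured as two FIFO buffers. Each new segment of hidden states is appended to the regular memory; whatever overflows `mem_len` is compressed by `cmem_ratio` and appended to the compressed memory, which keeps only its newest `cmem_len` slots. The following is a minimal pure-Python sketch, assuming mean pooling as the compression function (one of the options discussed in the paper, not necessarily the one this library uses). `mean_pool` and `update_memories` are hypothetical helpers, not part of the package API.

```python
from typing import List, Tuple

Vector = List[float]


def mean_pool(states: List[Vector], ratio: int) -> List[Vector]:
    """Compress states ratio-to-1 by averaging each run of `ratio` consecutive vectors."""
    assert len(states) % ratio == 0, "evicted states must divide evenly by the ratio"
    return [
        [sum(values) / ratio for values in zip(*states[i:i + ratio])]
        for i in range(0, len(states), ratio)
    ]


def update_memories(mem: List[Vector], cmem: List[Vector], new_states: List[Vector],
                    mem_len: int, cmem_len: int, ratio: int) -> Tuple[List[Vector], List[Vector]]:
    """Hypothetical helper: push a segment into memory, compressing what gets evicted."""
    mem = mem + new_states
    # The oldest states beyond mem_len fall out of the regular memory...
    num_evicted = max(0, len(mem) - mem_len)
    evicted, mem = mem[:num_evicted], mem[num_evicted:]
    # ...and are compressed into the compressed memory, itself a FIFO of cmem_len slots.
    if evicted:
        cmem = cmem + mean_pool(evicted, ratio)
    cmem = cmem[-cmem_len:] if cmem_len > 0 else []
    return mem, cmem


# Toy run: 1-d "hidden states", segments of 4, mem_len = 4, cmem_len = 2, ratio = 2.
mem, cmem = [], []
for step in range(3):
    segment = [[float(step)] for _ in range(4)]
    mem, cmem = update_memories(mem, cmem, segment, mem_len=4, cmem_len=2, ratio=2)

print(mem)   # [[2.0], [2.0], [2.0], [2.0]]
print(cmem)  # [[1.0], [1.0]]
```

With settings like those in the usage example below (`seq_len = mem_len = 1024`, `cmem_ratio = 4`), each evicted 1024-state chunk would compress to 256 slots, which matches `cmem_len = 1024 // 4`.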


Install

$ pip install compressive_transformer_pytorch

Usage

import torch
from compressive_transformer_pytorch import CompressiveTransformer

model = CompressiveTransformer(
    num_tokens = 20000,
    emb_dim = 128,                 # embedding dimension, embedding factorization from the ALBERT paper
    dim = 512,
    depth = 12,
    seq_len = 1024,
    mem_len = 1024,                # memory length
    cmem_len = 1024 // 4,          # compressed memory buffer length
    cmem_ratio = 4,                # compressed memory ratio, 4 was recommended in the paper
    reconstruction_loss_weight = 1, # weight to place on compressed memory reconstruction loss
    attn_dropout = 0.1,            # dropout post-attention
    ff_dropout = 0.1,              # dropout in feedforward
    attn_layer_dropout = 0.1,      # dropout for attention layer output
    gru_gated_residual = True,     # whether to gate the residual intersection, from the 'Stabilizing Transformers for RL' paper
    mogrify_gru = False,           # experimental feature that adds a mogrifier for the update and residual before gating by the GRU
    memory_layers = range(6, 13),  # specify which layers to use long-range memory, from 'Do Transformers Need LR Memory' paper
    ff_glu = True                  # use GLU variant for feedforward
)

inputs = torch.randint(0, 256, (1, 2048))
masks = torch.ones_like(inputs).bool()

segments = inputs.reshape(1, -1, 1024).transpose(0, 1)
masks = masks.reshape(1, -1, 1024).transpose(0, 1)

logits, memories, aux_loss = model(segments[0], mask = masks[0])
logits,        _, aux_loss = model(segments[1], mask = masks[1], memories = memories)

# memories is a named tuple that contains the memory (mem) and the compressed memory (cmem)
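
Going by the `reconstruction_loss_weight` comment, `aux_loss` appears to be the compressed-memory reconstruction loss scaled by that weight. In the paper this is an attention-reconstruction loss: attending over the old memories and attending over their compressed version should give similar outputs. Below is a minimal pure-Python sketch of that idea, assuming a single attention head with no learned projections and mean pooling as the compression (both simplifications). `attend`, `mean_pool` and `attention_reconstruction_loss` are hypothetical helpers, not the library's code.

```python
import math
from typing import List

Vector = List[float]


def attend(queries: List[Vector], keys_values: List[Vector]) -> List[Vector]:
    """Single-head scaled dot-product attention with keys == values (no projections)."""
    dim = len(queries[0])
    outputs = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(dim) for k in keys_values]
        peak = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        outputs.append([
            sum(w * v[d] for w, v in zip(weights, keys_values)) / total
            for d in range(dim)
        ])
    return outputs


def mean_pool(states: List[Vector], ratio: int) -> List[Vector]:
    """Compress states ratio-to-1 by averaging consecutive groups."""
    return [[sum(vals) / ratio for vals in zip(*states[i:i + ratio])]
            for i in range(0, len(states), ratio)]


def attention_reconstruction_loss(queries: List[Vector], old_mem: List[Vector], ratio: int) -> float:
    """Mean squared error between attention over old_mem and over its compression."""
    target = attend(queries, old_mem)
    reconstructed = attend(queries, mean_pool(old_mem, ratio))
    diffs = [(t - r) ** 2 for tv, rv in zip(target, reconstructed) for t, r in zip(tv, rv)]
    return sum(diffs) / len(diffs)


queries = [[1.0, 0.0], [0.0, 1.0]]
redundant_mem = [[1.0, 1.0]] * 4                              # identical states compress losslessly
varied_mem = [[2.0, 0.0], [0.0, 2.0], [2.0, 0.0], [0.0, 2.0]]  # pooling blurs these together

print(attention_reconstruction_loss(queries, redundant_mem, ratio=2))    # 0.0
print(attention_reconstruction_loss(queries, varied_mem, ratio=2) > 0)   # True
```

A compression function that preserves what attention actually reads from memory drives this loss toward zero.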

When training, you can use the AutoregressiveWrapper, which takes care of memory management across segments for you. As easy as it gets.

import torch
from compressive_transformer_pytorch import CompressiveTransformer
from compressive_transformer_pytorch import AutoregressiveWrapper

model = CompressiveTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    mem_len = 1024,
    cmem_len = 256,
    cmem_ratio = 4,
    memory_layers = [5,6]
).cuda()

model = AutoregressiveWrapper(model)

inputs = torch.randint(0, 20000, (1, 2048 + 1)).cuda()

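optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)  # learning rate is an illustrative choice

for loss, aux_loss, _ in model(inputs, return_loss = True):
    (loss + aux_loss).backward()  # gradients accumulate across segments

optimizer.step()
optimizer.zero_grad()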

# ... after much training ...

# generation is also greatly simplified and automated away
# just pass in the prime, which can be 1 start token or any length
# all is taken care of for you

prime = torch.ones(1, 1, dtype = torch.long).cuda()  # assume 1 is the start token; token ids must be integers
sample = model.generate(prime, 4096)
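
The training input above has `2048 + 1` tokens because next-token prediction pairs each position with the token after it: the inputs are `tokens[:-1]` and the targets are `tokens[1:]`, which yields two segments of `seq_len = 1024`. A minimal pure-Python sketch of that bookkeeping follows, assuming this is roughly what the wrapper does internally; `make_segments` is hypothetical and not part of the package.

```python
from typing import List, Tuple


def make_segments(tokens: List[int], seq_len: int) -> List[Tuple[List[int], List[int]]]:
    """Split a token stream into (input, target) segments for next-token prediction."""
    inputs, targets = tokens[:-1], tokens[1:]
    return [
        (inputs[i:i + seq_len], targets[i:i + seq_len])
        for i in range(0, len(inputs), seq_len)
    ]


tokens = list(range(2 * 4 + 1))   # 9 tokens -> 8 input/target pairs
segments = make_segments(tokens, seq_len=4)
print(len(segments))   # 2
print(segments[0])     # ([0, 1, 2, 3], [1, 2, 3, 4])
print(segments[1])     # ([4, 5, 6, 7], [5, 6, 7, 8])
```

Each segment's forward pass would then receive the memories produced by the previous segment, as in the manual two-segment example earlier.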

Citations

@misc{rae2019compressive,
    title   = {Compressive Transformers for Long-Range Sequence Modelling},
    author  = {Jack W. Rae and Anna Potapenko and Siddhant M. Jayakumar and Timothy P. Lillicrap},
    year    = {2019},
    eprint  = {1911.05507},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{parisotto2019stabilizing,
    title   = {Stabilizing Transformers for Reinforcement Learning},
    author  = {Emilio Parisotto and H. Francis Song and Jack W. Rae and Razvan Pascanu and Caglar Gulcehre and Siddhant M. Jayakumar and Max Jaderberg and Raphael Lopez Kaufman and Aidan Clark and Seb Noury and Matthew M. Botvinick and Nicolas Heess and Raia Hadsell},
    year    = {2019},
    eprint  = {1910.06764},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{rae-razavi-2020-transformers,
    title   = {Do Transformers Need Deep Long-Range Memory?},
    author  = {Rae, Jack and Razavi, Ali},
    booktitle = {Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics},
    month   = jul,
    year    = {2020},
    address = {Online},
    publisher = {Association for Computational Linguistics},
    url     = {https://www.aclweb.org/anthology/2020.acl-main.672}
}
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{lan2019albert,
    title       = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author      = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    year        = {2019},
    url         = {https://arxiv.org/abs/1909.11942}
}
@misc{ding2021erniedoc,
    title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
    author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
    year    = {2021},
    eprint  = {2012.15688},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
Comments
  • aux_loss does not update any weight

    Hi lucidrains, thanks for your implementation; it is very elegant and helped me a lot with my dissertation. However, there is one detail I can't understand: it seems that aux_loss is not connected to any weight, because of the detaching in the last part of the SelfAttention layer. With the following code, for example, I find that no layer is optimized by aux_loss:

    import torch
    from compressive_transformer_pytorch import CompressiveTransformer
    from compressive_transformer_pytorch import AutoregressiveWrapper
    
    model = CompressiveTransformer(
        num_tokens = 20000,
        dim = 512,
        depth = 6,
        seq_len = 1024,
        mem_len = 1024,
        cmem_len = 256,
        cmem_ratio = 4,
        memory_layers = [5,6]
    ).cuda()
    
    model = AutoregressiveWrapper(model)
    
    inputs = torch.randint(0, 20000, (1, 1024)).cuda()
    
    optimizer = torch.optim.Adam(model.parameters())
    
    for loss, aux_loss, _ in model(inputs, return_loss = True):
        optimizer.zero_grad(set_to_none=True)
        loss.backward(retain_graph=True)
        print("OPTIMIZED BY LOSS ************************************************************")
        for module_name, parameter in model.named_parameters():
            if parameter.grad is not None:
                print(module_name)
        optimizer.zero_grad(set_to_none=True)
        aux_loss.backward(retain_graph=True)
        print("OPTIMIZED BY AUX_LOSS ************************************************************")
        for module_name, parameter in model.named_parameters():
            if parameter.grad is not None:
                print(module_name)
    

    I am not an expert in PyTorch's internals, so maybe I am getting something wrong. Thank you again.

    opened by StefanoBerti 3
  • How to use this for speech/audio generation?

    Great work, Phil! In their paper, the authors applied this model to speech modeling. What would you advise me to change to use it for speech? In speech the data are signals, so we have neither num_tokens nor emb_dim; our input is simply [batch, channel, time]. Any advice?

    opened by jinglescode 3
  • [Error] NameError: name 'math' is not defined in compressive_transformer_pytorch.py

    Hello, I am running the code in "examples/enwik8_simple" and got the following error:

    train.py:65: DeprecationWarning: The binary mode of fromstring is deprecated, as it behaves surprisingly on unicode inputs. Use frombuffer instead
      X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    training:   0%|          | 0/100000 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "train.py", line 101, in <module>
        for mlm_loss, aux_loss, is_last in model(next(train_loader), max_batch_size = MAX_BATCH_SIZE, return_loss = True):
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/autoregressive_wrapper.py", line 151, in forward
        logits, new_mem, aux_loss = self.net(xi_seg_b, mask = mask_seg_b, memories = mem, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 338, in forward
        x, = ff(x)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 84, in forward
        out = self.fn(x, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 106, in forward
        return self.fn(x, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 140, in forward
        x = self.act(x)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 122, in forward
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    NameError: name 'math' is not defined

    So I added "import math" to compressive_transformer_pytorch.py and it worked well. I hope you can update compressive_transformer_pytorch.py accordingly.
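
The function in the traceback is the tanh approximation of GELU, and it fails only because the module never imports `math`. Here is a standalone sketch of the corrected function. It uses plain floats instead of tensors so it runs without PyTorch, and it compares the approximation against the exact erf-based GELU. `gelu_tanh` and `gelu_exact` are illustrative names, not the library's.

```python
import math  # the missing import that caused the NameError


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU: the same formula as in the traceback, on floats."""
    return 0.5 * x * (1 + math.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))


def gelu_exact(x: float) -> float:
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1 + math.erf(x / math.sqrt(2)))


for x in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(f"{x:+.1f}  approx={gelu_tanh(x):.4f}  exact={gelu_exact(x):.4f}")
```

At these points the two agree to within 1e-3. On recent PyTorch versions, `torch.nn.functional.gelu(x, approximate='tanh')` computes the same approximation on tensors.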

    opened by dinoSpeech 3
  • Training enwik8 but loss fails to converge

    Hi lucidrains, I appreciate your implementation very much, and it has helped me a lot in understanding the compressive transformer. However, when I ran your code (enwik8, with exactly the same code as on GitHub), the loss failed to converge after 100 epochs. Is this expected? Or should I make additional efforts to improve it, for example tokenizing the raw enwik8 data and removing all the XML tags? The figures below show the training and validation loss while training on enwik8 with the same code as on GitHub.

    [Screenshots dated 2021-03-26: training and validation loss curves]

    Thanks, and I look forward to your reply!

    opened by KaiPoChang 2
  • Details about text generation

    Hi lucidrains, thank you for your excellent code. I am curious about the generation scripts. Could you tell me how to generate text with the compressive transformer? Because it has the compressive memory, maybe we cannot use the current predicted word as the input for the next generation step (input length == 1). In addition, if the prompt has 100 words and we use tokens[0:100], tokens[1:101], tokens[2:102], ... as the inputs for the following timesteps, tokens[1:100] may overlap with the memory, because the memory already contains hidden states for tokens[1:100].

    I would really appreciate it if you could provide the generation scripts!

    Thank you

    opened by theseventhflow 3
  • Links to original tf code - fyi

    After reading the DeepMind blog post, I was looking forward to downloading the model, but no luck. Looking forward to your implementation.

    You may already be aware of this post and link, but if not, this is the author's original TF implementation. Hope it helps.

    Copy of comment to original model request:

    https://github.com/huggingface/transformers/issues/4688

    I'm interested in the model weights too, but they are currently not available. The author does mention releasing the TF code here:

    https://news.ycombinator.com/item?id=22290227

    It requires TF 1.15+ and deepmind/sonnet version 1.36. Link to the Python script here:

    https://github.com/deepmind/sonnet/blob/cd5b5fa48e15e4d020f744968f5209949ebe750f/sonnet/python/modules/nets/transformer.py#L915

    I have tried running it as-is, but it doesn't appear to have options for training on custom data as described in the paper, or on the available datasets.

    opened by GenTxt 8
Releases (0.4.0)
Owner
Phil Wang
Working with Attention. It's all we need