RETRO-pytorch - Implementation of RETRO, Deepmind's Retrieval based Attention net, in Pytorch


RETRO - Pytorch (wip)

Implementation of RETRO, Deepmind's Retrieval based Attention net, in Pytorch. This will deviate from the paper slightly, using rotary embeddings for relative positional encoding and the Faiss library instead of ScaNN.

If you are interested, please join this Discord for discussions


$ pip install retro-pytorch


import torch
from retro_pytorch import RETRO

retro = RETRO(
    num_tokens = 20000,                      # number of tokens
    chunk_size = 64,                         # the chunk size that is indexed and retrieved (needed for proper relative positions as well as causal chunked cross attention)
    max_seq_len = 2048,                      # max sequence length
    enc_dim = 896,                           # encoder model dim
    enc_depth = 2,                           # encoder depth
    dec_dim = 796,                           # decoder model dim
    dec_depth = 12,                          # decoder depth
    dec_cross_attn_layers = (3, 6, 9, 12),   # decoder cross attention layers (with causal chunk cross attention)
    heads = 8,                               # attention heads
    dim_head = 64,                           # dimension per head
    dec_attn_dropout = 0.25,                 # decoder attention dropout
    dec_ff_dropout = 0.25                    # decoder feedforward dropout
)

seq = torch.randint(0, 20000, (2, 2048 + 1))      # plus one since it is split into input and labels for training
retrieved = torch.randint(0, 20000, (2, 32, 2, 128)) # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation)

loss = retro(seq, retrieved, return_loss = True)

# do above for many steps
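
The tensor shapes in the example follow from the chunking arithmetic. A minimal sketch, assuming each retrieved neighbor is a chunk plus an equally long continuation; the helper name `retro_input_shapes` is hypothetical and not part of `retro_pytorch`:

```python
def retro_input_shapes(batch, max_seq_len, chunk_size, num_neighbors):
    """Return the (seq, retrieved) shapes used in the training example."""
    assert max_seq_len % chunk_size == 0, "sequence must split into whole chunks"
    num_chunks = max_seq_len // chunk_size
    # +1 token so the sequence can be split into inputs and shifted labels
    seq_shape = (batch, max_seq_len + 1)
    # assumption: each neighbor holds a chunk plus a continuation of the same length
    retrieved_shape = (batch, num_chunks, num_neighbors, 2 * chunk_size)
    return seq_shape, retrieved_shape

seq_shape, retrieved_shape = retro_input_shapes(2, 2048, 64, 2)
print(seq_shape)        # (2, 2049)
print(retrieved_shape)  # (2, 32, 2, 128)
```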


  • handle indexing of corpus of text with faiss
  • handle reindexing of all nearest neighbors
  • function for getting frozen BERT embeddings for batch of chunks
  • handle partially filled chunks with mask and null tokens as a safeguard
  • inference code, autoretrieving at chunk boundaries
  • autohandle retrieved chunks for last chunk in sequence, whether it is given or not
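
As an illustration of the "partially filled chunks" item, here is a minimal sketch of padding a token list to whole chunks and recording a validity mask. It uses plain Python lists; the function name `chunk_with_padding` and the pad id `0` are assumptions made for this example, not the library's API:

```python
def chunk_with_padding(tokens, chunk_size, pad_id=0):
    """Split tokens into fixed-size chunks, padding the last one and masking the padding."""
    chunks, masks = [], []
    for start in range(0, len(tokens), chunk_size):
        chunk = tokens[start:start + chunk_size]
        n_pad = chunk_size - len(chunk)
        chunks.append(chunk + [pad_id] * n_pad)               # pad_id is a hypothetical null token
        masks.append([True] * len(chunk) + [False] * n_pad)   # True marks real tokens
    return chunks, masks

chunks, masks = chunk_with_padding(list(range(1, 11)), chunk_size=4)
print(chunks[-1])  # [9, 10, 0, 0]
print(masks[-1])   # [True, True, False, False]
```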


    title         = {Improving language models by retrieving from trillions of tokens},
    author        = {Sebastian Borgeaud and Arthur Mensch and Jordan Hoffmann and Trevor Cai and Eliza Rutherford and Katie Millican and George van den Driessche and Jean-Baptiste Lespiau and Bogdan Damoc and Aidan Clark and Diego de Las Casas and Aurelia Guy and Jacob Menick and Roman Ring and Tom Hennigan and Saffron Huang and Loren Maggiore and Chris Jones and Albin Cassirer and Andy Brock and Michela Paganini and Geoffrey Irving and Oriol Vinyals and Simon Osindero and Karen Simonyan and Jack W. Rae and Erich Elsen and Laurent Sifre},
    year          = {2022},
    eprint        = {2112.04426},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CL}

    title         = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author        = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year          = {2021},
    eprint        = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CL}

I consider always the adult life to be the continuous retrieval of childhood. - Umberto Eco

  • Error Reconstructing FAISS Index

    Hiya! Thanks for making this library out in the open!

    I've been trying to get your training wrapper working, but when it tries to generate the index, I get the following error:

    RuntimeError: Error in virtual void faiss::Index::reconstruct(faiss::Index::idx_t, float*) const at /project/faiss/faiss/Index.cpp:48: reconstruct not implemented for this type of index

    To reproduce, you can use this google colab:

    Any help with this would be greatly appreciated!

    opened by ncoop57 18
  • Why are there so many position embeddings?

    Hi! Thanks for your great work, it's very helpful for my project! I was just curious why there are so many position embeddings. Essentially, it looks like a (1 to n) positional embedding is added to the sequence initially in the RETRO class, and then rotary embeddings are applied again in each attention module. I thought the two in Attention and CCA would be quite enough. Thanks in advance!

    opened by jasperhyp 5
  • `doc_ids_memmap` shape

    Is there a reason doc_ids_memmap is shape (max_docs, )? Shouldn't it be (max_chunks, ) since it's mapping chunks to doc ids?
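
For context on the shape question: a per-chunk document-id array is what lets neighbor search drop candidates that come from the query chunk's own document. A minimal sketch in plain Python; `filter_same_doc` and the toy data are hypothetical and not taken from the library:

```python
def filter_same_doc(query_chunk, candidate_chunks, chunk_doc_ids):
    """Keep candidates whose document differs from the query chunk's document."""
    query_doc = chunk_doc_ids[query_chunk]
    return [c for c in candidate_chunks if chunk_doc_ids[c] != query_doc]

# one entry per chunk: chunks 0-2 come from document 0, chunks 3-4 from document 1
chunk_doc_ids = [0, 0, 0, 1, 1]
print(filter_same_doc(1, [0, 2, 3, 4], chunk_doc_ids))  # [3, 4]
```

Because the lookup is indexed by chunk, such an array needs one entry per chunk, which is the point the question raises.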

    opened by josephcappadona 5
  • rotary embedding question

    I have two questions about the rotary embedding implementation.

    1. Divide the d-dimension space in to d/2 sub-spaces

    In rotary embedding, head_dim is divided by 2 to utilize the conjugate space with sin and cos.

    from rotary_embedding_torch import RotaryEmbedding
    head_dim = 64
    rotary_emb = RotaryEmbedding(dim=head_dim)

    # excerpt of the class definition in rotary-embedding-torch
    class RotaryEmbedding(nn.Module):
        def __init__(
            self,
            dim,
            custom_freqs = None,
            freqs_for = 'lang',
            theta = 10000,
            max_freq = 10,
            num_freqs = 1,
            learned_freq = False
        ):
            super().__init__()
            if exists(custom_freqs):
                freqs = custom_freqs
            elif freqs_for == 'lang':
                # freqs.shape == (head_dim // 2)
                freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))

    But the freqs of the rotary in RETRO is kind of weird. Rotary embedding in RETRO's Encoder and Decoder divides head_dim by 2 in advance and puts it as an input.

    And divide freq by 2 once again in the initializer as shown below.

    In this way, when head_dim=48, the shape of freqs is obtained as follows.

    Because the apply_rotary_emb function concats the tensor that exceeds rot_dim, the shape of the resulting tensor is the same, but the rotary pos does not seem to be fully applied.
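
The counting behind this claim can be sketched in plain Python. This is only an illustration of the arithmetic described in this post, under two assumptions that the thread does not confirm: the dimension is halved before being passed in, and each frequency rotates one pair of features. The helper `num_rotary_freqs` is hypothetical:

```python
def num_rotary_freqs(dim, truncate=True):
    """Count the entries of arange(0, dim, 2), optionally sliced to the first dim // 2."""
    idx = list(range(0, dim, 2))
    return len(idx[: dim // 2]) if truncate else len(idx)

head_dim = 48
passed_dim = head_dim // 2          # halved before being passed in, per the description
n_freqs = num_rotary_freqs(passed_dim)
rotated = 2 * n_freqs               # assumption: each frequency rotates one feature pair
print(n_freqs, rotated, head_dim)   # 12 24 48
```

Under these assumptions only 24 of the 48 head dimensions are rotated. Whether that partial rotation is intentional is not stated here.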

    Hence, I think you need to modify the two lines of code as below.

      • The resulting tensor has the same shape.
    >>> ASIS
                freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    <<< TOBE
                freqs = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
      • As shown in the confirmation code below, the above modification is the same as the existing rotary embedding implementation.
      import torch
      hid_dim, n_heads = 768, 16  # illustrative values (head_dim = 48), not from the original post
      dim1 = hid_dim // n_heads
      dim2 = (hid_dim // n_heads) // 2
      freqs1 = 1. / (10000 ** (torch.arange(0, dim1, 2).float() / dim1))
      freqs2 = 1. / (10000 ** (torch.arange(0, dim2, 1).float() / dim2))
      assert torch.equal(freqs1, freqs2)
    >>> ASIS
            inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    <<< TOBE
            inv_freq = 1. / (10000 ** (torch.arange(0, dim, 1).float() / dim))

    2. rotate_half function

    The rotary_half implementations of RETRO-pytorch and rotary-embedding-torch are slightly different.

    # In rotary-embedding-torch
    def rotate_half(x):
        x = rearrange(x, '... (d r) -> ... d r', r = 2)
        x1, x2 = x.unbind(dim = -1)
        x = torch.stack((-x2, x1), dim = -1)
        return rearrange(x, '... d r -> ... (d r)')
    # In RETRO-pytorch
    def rotate_half(x):
        x = rearrange(x, '... (j d) -> ... j d', j = 2)
        x1, x2 = x.unbind(dim = -2)
        return torch.cat((-x2, x1), dim = -1)

    In rotary-embedding-torch, the two components of each pair are interleaved as [0 1 0 1 0 1 0 1], while in RETRO-pytorch they are laid out as [0 0 0 0 1 1 1 1].

    • [0 0 0 0] is pre-half
    • [1 1 1 1] is post-half
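
The two layouts can be compared on a small vector. A minimal sketch using plain Python lists instead of tensors; the function names are made up for this example, and the split variant assumes the RETRO-pytorch version returns the negated second half followed by the first half:

```python
def rotate_half_interleaved(x):
    """Pair adjacent features (x0, x1), (x2, x3), ... and map each pair (a, b) -> (-b, a)."""
    out = []
    for a, b in zip(x[0::2], x[1::2]):
        out += [-b, a]
    return out

def rotate_half_split(x):
    """Pair feature i with feature i + d/2 and return [-second_half, first_half]."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]

x = [1, 2, 3, 4, 5, 6, 7, 8]
print(rotate_half_interleaved(x))  # [-2, 1, -4, 3, -6, 5, -8, 7]
print(rotate_half_split(x))        # [-5, -6, -7, -8, 1, 2, 3, 4]
```

Both variants apply the same 2D rotation to each pair of features. They differ only in which features are paired, so each has to be used with sin/cos tables laid out in the matching order.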

    I wonder why it was implemented with this change! (just curious)

    I am studying your implementation and matching it against the paper. Thank you always :)

    opened by jinmang2 3
  • Autoregressivity


    I had a question about Figure 2 and equation 3 from the paper. How does the last token of each chunk C_u being able to attend to the retrieved content E_u not break autoregressivity?
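
One way to check the question is to write out the index arithmetic. This sketch assumes the reading of the paper in which E_u is retrieved using only the tokens of chunk C_u, and in which the last token of C_u plus the first m-1 tokens of C_{u+1} attend to E_u. It also assumes the retrieved neighbors do not contain the sequence's own future text. Both function names are hypothetical:

```python
def neighbors_seen_by(i, chunk_size):
    """Index u of the retrieved set E_u that token i attends to, or None (0-indexed)."""
    u = (i + 1) // chunk_size - 1
    return u if u >= 0 else None

def last_token_used_for_retrieval(u, chunk_size):
    """E_u is retrieved from chunk C_u, whose last token has this index."""
    return (u + 1) * chunk_size - 1

m = 64
for i in (0, 62, 63, 64, 126, 127):
    u = neighbors_seen_by(i, m)
    # causality holds if the retrieval query for E_u never looked past token i
    assert u is None or last_token_used_for_retrieval(u, m) <= i
    print(i, u)
```

Under these assumptions, token 63 (the last token of chunk 0) is the first to see E_0, and the query that produced E_0 used only tokens 0 to 63, so no later input token influences its prediction.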

    opened by sdpmas 3
  • Extra layer encoder_output_to_decoder_dim cause issue with distributed training

    Hiya, hope Ice Cream is doing well, as well as you of course!

    I've been trying to get distributed training working with your library and I discovered this additional Linear layer encoder_output_to_decoder_dim not being used anywhere:

    It seems to be a copy of the layer defined right above it to_decoder_model_dim, which does get used. Having this extra layer that is not part of the loss calculation causes the following error with data parallelism:

    RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one.

    Not sure if this layer is supposed to be there and it just didn't get used or if it is there by accident, so wanted to ask 🤓

    opened by ncoop57 2
  • Question about the right position to encode `retrieved`

    Hi, I am currently reading through the code and got confused when I reached this line:

    According to Algorithm 1 in the paper, doesn't this line need to go inside the decoder, under this line?

    This is an example of how I think the code of decoder.forward should be.

    def forward(self, x, *, context_mask = None, retrieved = None):
        encoded = False  # flag to know if p = min(P) (in the algorithm)
        # ... (per-layer loop elided in the original post; the block below runs inside it)
            if exists(cross_attn) and exists(retrieved):
                if not encoded:
                    # use x (H at layer p where p = min(P)), not embed (Emb(X))
                    x_as_context = repeat(x[:, :seq_index], 'b (k n) d -> (b k r) n d', n = self.chunk_size, r = num_neighbors)
                    retrieved = self.encoder(retrieved, mask = encoder_retrieved_mask, chunked_seq = x_as_context)
                    encoded = True
    opened by soheeyang 2
  • Confusions about cross attentions in encoder

    In your code

    When this class is called by the Encoder, x is the retrieved chunks. In the attention mechanism it produces the q matrix, but I think it should produce the k and v matrices. In the encoder, the input sequence should just guide attention over the words of the retrieved chunks.

    opened by Hi-archers 2
  • 'NoneType' object is not callable

    When I run the "RETRO Datasets" example, I get a TypeError:

    Traceback (most recent call last):
      File "/home/fgq/all/RETRO/", line 58, in <module>
        retro = RETRO(
      File "/home/fgq/all/RETRO/retro_pytorch/", line 507, in __init__
        self.encoder = Encoder(
      File "/home/fgq/all/RETRO/retro_pytorch/", line 337, in __init__
        wrapper(Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal)),
      File "/home/fgq/all/RETRO/retro_pytorch/", line 73, in __init__
        self.norm = norm_klass(dim)
    TypeError: 'NoneType' object is not callable


    save_memmap('./train.chunks.dat', np.int32(np.random.randint(0, 8192, size = (NUM_CHUNKS, CHUNK_SIZE + 1))))

    # generate nearest neighbors for each chunk

    save_memmap('./train.chunks.knn.dat', np.int32(np.random.randint(0, 1000, size = (NUM_CHUNKS, NUM_NEIGHBORS))))

    # generate seq data

    save_memmap('./train.seq.dat', np.int32(np.random.randint(0, 128, size = (NUM_SEQS,))))

    # instantiate dataset class

    train_ds = RETRODataset(
        num_sequences = NUM_SEQS,
        num_chunks = NUM_CHUNKS,
        num_neighbors = NUM_NEIGHBORS,
        chunk_size = CHUNK_SIZE,
        seq_len = 2048,
        chunk_memmap_path = './train.chunks.dat',
        chunk_nn_memmap_path = './train.chunks.knn.dat',
        seq_memmap_path = './train.seq.dat'
    )
    opened by f-guoqiang 1
  • Fix reconstruction error discussed in #15

    This PR fixes the issue with reconstruction of the faiss index. One caveat is that we can no longer do memmapping to reduce RAM overhead. Maybe this will be fixed in faiss soon, but for now memory will be an issue for extremely large indices.

    opened by ncoop57 1
  • Update


    The build_index command

    In the autofaiss documentation, the description of "--embeddings" reads: "Source path of the directory containing your .npy embedding files. If there are several files, they are read in the lexicographical order. This can be a local path or a path in another Filesystem e.g. hdfs://root/… or s3://…"

    The build_index function reads the embedding folder in lexicographical order, but the embedding files are currently saved as "0.npy, 1.npy, 2.npy, ..., n.npy", so build_index reads them in the order "0.npy, 1.npy, 10.npy, ..., 2.npy, ..., n.npy". I therefore pad the embedding file names with leading zeros so that build_index works correctly.
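
The ordering problem and the zero-padding fix can be reproduced with plain string sorting. A minimal sketch; the helper `padded_name` and the file count are made up for this example and are not the code from the PR:

```python
def padded_name(index, total, ext=".npy"):
    """Zero-pad the index so lexicographic order matches numeric order."""
    width = len(str(total - 1))
    return f"{index:0{width}d}{ext}"

total = 12
plain = sorted(f"{i}.npy" for i in range(total))
padded = sorted(padded_name(i, total) for i in range(total))
print(plain[:4])   # ['0.npy', '1.npy', '10.npy', '11.npy']
print(padded[:4])  # ['00.npy', '01.npy', '02.npy', '03.npy']
```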

    opened by Hi-archers 1
  • Causal mask in Chunked Cross Attention

    When computing the chunked cross-attention (line 259 here), a causal mask is used. If I'm not mistaken, the code applies the causal mask to the last dimension of x (the last words). However, my understanding was that the mask should be applied to the first dimensions, as in the figure from the repo.

    Is it normal?

    opened by Jonor127-OP 0
  • How to give Prompt to trained RETRO Model?

    I am following the instructions on the RETRO-pytorch GitHub repo. After training my model, how do I go about using it to generate responses?

    retro = RETRO(
        chunk_size = 64,                         # the chunk size that is indexed and retrieved (needed for proper relative positions as well as causal chunked cross attention)
        max_seq_len = 2048,                      # max sequence length
        enc_dim = 896,                           # encoder model dim
        enc_depth = 2,                           # encoder depth
        dec_dim = 796,                           # decoder model dim
        dec_depth = 12,                          # decoder depth
        dec_cross_attn_layers = (3, 6, 9, 12),   # decoder cross attention layers (with causal chunk cross attention)
        heads = 8,                               # attention heads
        dim_head = 64,                           # dimension per head
        dec_attn_dropout = 0.25,                 # decoder attention dropout
        dec_ff_dropout = 0.25,                   # decoder feedforward dropout
        use_deepnet = True                       # turn on post-normalization with DeepNet residual scaling and initialization, for scaling to 1000 layers
    )
    seq = torch.randint(0, 20000, (2, 2048 + 1))      # plus one since it is split into input and labels for training
    retrieved = torch.randint(0, 20000, (2, 32, 2, 128)) # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation)
    loss = retro(seq, retrieved, return_loss = True)
    wrapper = TrainingWrapper(
        retro = retro,                                 # path to retro instance
        knn = 2,                                       # knn (2 in paper was sufficient)
        chunk_size = 64,                               # chunk size (64 in paper)
        documents_path = './retro_training_set/',              # path to folder of text
        glob = '**/*.txt',                             # text glob
        chunks_memmap_path = './train.chunks.dat',     # path to chunks
        seqs_memmap_path = './train.seq.dat',          # path to sequence data
        doc_ids_memmap_path = './train.doc_ids.dat',   # path to document ids per chunk (used for filtering neighbors belonging to same document)
        max_chunks = 1_000_000,                        # maximum cap to chunks
        max_seqs = 100_000,                            # maximum seqs
        knn_extra_neighbors = 100,                     # num extra neighbors to fetch
        max_index_memory_usage = '100m',
        current_memory_available = '1G'
    )

    Now when I want to give this model a text input (any prompt), how would I go about doing that? Which method or function would I use? Which model/tokenizer should I use to encode the input prompt and then decode the model output tensor? Is there a method for that?

    Example Prompt: "The movie Dune was released in"

    opened by shahmeer99 1
  • Scann vs faiss

    Could you elaborate on the decision to use faiss instead of scann? In theory scann is open source too, but I'm wondering if you found easier to get the performance needed from faiss instead.

    opened by afcruzs 5
  • Clarification on Architecture

    Reading the original paper, I took it that RETRO was a standard transformer (i.e. a 12-layer encoder and a 12-layer decoder) augmented with a DB retrieval system that included a second, smaller (2-layer) encoder for the frozen-BERT-encoded neighbors, where the 2-layer encoder was a sort of translator between the BERT model and the main transformer.

    Looking at the model here, it looks like there is only the 2 layer retrieval encoder and not a full-size main encoder. Is that correct?

    Going back and re-reading the paper it doesn't seem to explicitly say one way or the other. It seems odd to me that the model would only have the 2 layer retrieval encoder. Not only would this mean that the encoder is only 2 layers but it also means that most decoder layers have no standard cross attention to the encoder, only layers 6, 9, 12 with the new CCA setup.

    Has anyone trained the model from this repo and demonstrated that it can produce the results from the original paper?

    opened by bjascob 0
  • Retro-fitting a pretrained model

    Thank you for your implementation! Is it possible to use your library to "retro-fit" a pretrained model?

    I guess it would mean freezing the model during training, only fine-tuning the retrieval and cross-attention? How would you recommend doing that?


    opened by dean-sh 6
Phil Wang
Working with Attention. It's all we need