RETRO-pytorch - Implementation of RETRO, DeepMind's retrieval-based attention net, in PyTorch

Overview

RETRO - Pytorch (wip)

Implementation of RETRO, DeepMind's retrieval-based attention network, in PyTorch. This will deviate from the paper slightly, using rotary embeddings for relative positional encoding, as well as the Faiss library instead of ScaNN.

If you are interested, please join this Discord for discussions

Install

$ pip install retro-pytorch

Usage

import torch
from retro_pytorch import RETRO

retro = RETRO(
    num_tokens = 20000,                      # number of tokens
    chunk_size = 64,                         # the chunk size that is indexed and retrieved (needed for proper relative positions as well as causal chunked cross attention)
    max_seq_len = 2048,                      # max sequence length
    enc_dim = 896,                           # encoder model dim
    enc_depth = 2,                           # encoder depth
    dec_dim = 796,                           # decoder model dim
    dec_depth = 12,                          # decoder depth
    dec_cross_attn_layers = (3, 6, 9, 12),   # decoder cross attention layers (with causal chunk cross attention)
    heads = 8,                               # attention heads
    dim_head = 64,                           # dimension per head
    dec_attn_dropout = 0.25,                 # decoder attention dropout
    dec_ff_dropout = 0.25                    # decoder feedforward dropout
)

seq = torch.randint(0, 20000, (2, 2048 + 1))      # plus one since it is split into input and labels for training
retrieved = torch.randint(0, 20000, (2, 32, 2, 128)) # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation) - 32 chunks = 2048 / 64, 128 tokens = 64 x 2 (chunk plus continuation)

loss = retro(seq, retrieved, return_loss = True)
loss.backward()

# do above for many steps
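
A minimal sketch of such a training loop (the optimizer choice and the get_batch helper are assumptions for illustration, not part of the library):

from torch.optim import Adam

optim = Adam(retro.parameters(), lr = 3e-4)

num_training_steps = 10_000

for _ in range(num_training_steps):
    seq, retrieved = get_batch()                     # hypothetical loader yielding tensors shaped as above
    loss = retro(seq, retrieved, return_loss = True)
    loss.backward()
    optim.step()
    optim.zero_grad()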

Todo

  • handle indexing of corpus of text with faiss (see the sketch after this list)
  • handle reindexing of all nearest neighbors
  • function for getting frozen BERT embeddings for batch of chunks
  • handle partially filled chunks with mask and null tokens as a safeguard
  • inference code, autoretrieving at chunk boundaries
  • autohandle retrieved chunks for last chunk in sequence, whether it is given or not
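
A minimal sketch of what the faiss indexing step might look like (the embedding source and sizes are assumptions, not the repo's actual retrieval code):

import faiss
import numpy as np

dim = 768                                            # e.g. frozen BERT embedding dim
chunk_embeddings = np.random.randn(10000, dim).astype('float32')  # stand-in for embedded chunks
faiss.normalize_L2(chunk_embeddings)                 # normalize so inner product = cosine similarity

index = faiss.IndexFlatIP(dim)                       # exact inner-product index
index.add(chunk_embeddings)

dists, neighbor_ids = index.search(chunk_embeddings[:2], 2)  # 2 nearest neighbors for the first 2 chunks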

Citations

@misc{borgeaud2022improving,
    title   = {Improving language models by retrieving from trillions of tokens},
    author  = {Sebastian Borgeaud and Arthur Mensch and Jordan Hoffmann and Trevor Cai and Eliza Rutherford and Katie Millican and George van den Driessche and Jean-Baptiste Lespiau and Bogdan Damoc and Aidan Clark and Diego de Las Casas and Aurelia Guy and Jacob Menick and Roman Ring and Tom Hennigan and Saffron Huang and Loren Maggiore and Chris Jones and Albin Cassirer and Andy Brock and Michela Paganini and Geoffrey Irving and Oriol Vinyals and Simon Osindero and Karen Simonyan and Jack W. Rae and Erich Elsen and Laurent Sifre},
    year    = {2022},
    eprint  = {2112.04426},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

I consider always the adult life to be the continuous retrieval of childhood. - Umberto Eco

Comments
  • Error Reconstructing FAISS Index

    Hiya! Thanks for making this library out in the open!

    I've been trying to get your training wrapper working, but when it tries to generate the index, I get the following error:

    RuntimeError: Error in virtual void faiss::Index::reconstruct(faiss::Index::idx_t, float*) const at /project/faiss/faiss/Index.cpp:48: reconstruct not implemented for this type of index
    

    To reproduce, you can use this google colab: https://colab.research.google.com/drive/1BcOtBpWBGmXX_tOC7WKcHOa9SukWEKpf?usp=sharing

    Any help with this would be greatly appreciated!

    opened by ncoop57 18
  • Why are there so many position embeddings?

    Hi! Thanks for your great work, it's very helpful for my project! I was just curious why there are so many position embeddings. Essentially, the sequence first gets a (1 to n) absolute positional embedding in the RETRO class, and then rotary embeddings are applied again in each attention module. I thought just the two in the Attention and CCA modules would be quite enough. Thanks in advance!

    opened by jasperhyp 5
  • `doc_ids_memmap` shape

    https://github.com/lucidrains/RETRO-pytorch/blob/7d305379b72232c54262742d3f80326ed5fdac9e/retro_pytorch/retrieval.py#L138

    Is there a reason doc_ids_memmap is shape (max_docs, )? Shouldn't it be (max_chunks, ) since it's mapping chunks to doc ids?
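
    A toy illustration of the shape in question (a sketch with made-up sizes, not the repo's code):

    import numpy as np

    max_chunks = 10_000
    # one doc id per chunk: a (max_chunks,) array maps each chunk to its document
    doc_ids = np.memmap('./train.doc_ids.dat', dtype = np.int32, mode = 'w+', shape = (max_chunks,))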

    opened by josephcappadona 5
  • rotary embedding question

    I have two questions about the rotary embedding implementation.

    1. Dividing the d-dimension space into d/2 sub-spaces

    In rotary embedding, head_dim is divided by 2 to utilize the conjugate space with sin and cos.

    from rotary_embedding_torch import RotaryEmbedding
    
    head_dim = 64
    rotary_emb = RotaryEmbedding(dim=head_dim)
    
    class RotaryEmbedding(nn.Module):
        def __init__(
            self,
            dim,
            custom_freqs = None,
            freqs_for = 'lang',
            theta = 10000,
            max_freq = 10,
            num_freqs = 1,
            learned_freq = False
        ):
            super().__init__()
            if exists(custom_freqs):
                freqs = custom_freqs
            elif freqs_for == 'lang':
                # freqs.shape == (head_dim // 2)
                freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
            ...
    

    But the rotary freqs in RETRO are a bit odd. The rotary embedding in RETRO's Encoder and Decoder divides head_dim by 2 in advance and passes that as the input dim.

    https://github.com/lucidrains/RETRO-pytorch/blob/4f99e316458fb13a5e4f881586f8436458cf4ead/retro_pytorch/retro_pytorch.py#L380-L381

    And the initializer divides the frequency dim by 2 once again, as shown below.

    https://github.com/lucidrains/RETRO-pytorch/blob/4f99e316458fb13a5e4f881586f8436458cf4ead/retro_pytorch/retro_pytorch.py#L92-L96

    In this way, when head_dim = 48, the dim passed in is 24 and inv_freq ends up with only 12 elements, giving rot_dim = 24.

    Because the apply_rotary_emb function concats back the part of the tensor that exceeds rot_dim, the shape of the resulting tensor is the same, but the rotary position embedding does not seem to be fully applied: only the first 24 of the 48 head dimensions are rotated.
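
    A quick check of those numbers (a runnable sketch reusing the same inverse-frequency formula):

    import torch

    head_dim = 48
    dim = head_dim // 2                              # 24, what RETRO-pytorch passes in
    inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    print(inv_freq.shape)                            # torch.Size([12]), so rot_dim = 24 and only 24 of the 48 dims get rotated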

    Hence, I think you need to modify the two lines of code as below.

    • https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py#L76
      • The resulting tensor has the same shape.
    >>> ASIS
                freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    <<< TOBE
                freqs = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
    
    • https://github.com/lucidrains/RETRO-pytorch/blob/main/retro_pytorch/retro_pytorch.py#L95
      • As shown in the confirmation code below, the above modification is the same as the existing rotary embedding implementation.
      import torch

      hid_dim = 512   # any example values work
      n_heads = 8
      dim1 = hid_dim // n_heads
      dim2 = (hid_dim // n_heads) // 2
      freqs1 = 1. / (10000 ** (torch.arange(0, dim1, 2).float() / dim1))
      freqs2 = 1. / (10000 ** (torch.arange(0, dim2, 1).float() / dim2))
      assert torch.equal(freqs1, freqs2)
      
    >>> ASIS
            inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    <<< TOBE
            inv_freq = 1. / (10000 ** (torch.arange(0, dim, 1).float() / dim))
    

    2. rotate_half function

    The rotate_half implementations of RETRO-pytorch and rotary-embedding-torch are slightly different.

    # In rotary-embedding-torch
    # https://github.com/lucidrains/rotary-embedding-torch/blob/517ee2cfeb10602032ef9d282c19851e19dd8943/rotary_embedding_torch/rotary_embedding_torch.py#L34
    def rotate_half(x):
        x = rearrange(x, '... (d r) -> ... d r', r = 2)
        x1, x2 = x.unbind(dim = -1)
        x = torch.stack((-x2, x1), dim = -1)
        return rearrange(x, '... d r -> ... (d r)')
    
    # In RETRO-pytorch
    # https://github.com/lucidrains/RETRO-pytorch/blob/4f99e316458fb13a5e4f881586f8436458cf4ead/retro_pytorch/retro_pytorch.py#L104
    def rotate_half(x):
        x = rearrange(x, '... (j d) -> ... j d', j = 2)
        x1, x2 = x.unbind(dim = -2)
        return torch.cat((-x2, x1), dim = -1)
    

    In rotary-embedding-torch the rotated pairs are interleaved as [0 1 0 1 0 1 0 1], while RETRO-pytorch splits the dimensions into halves as [0 0 0 0 1 1 1 1].

    • [0 0 0 0] is the first half
    • [1 1 1 1] is the second half
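
    The difference is easy to see by applying the two implementations quoted above to a small tensor (a runnable sketch; both orderings are valid rotary variants as long as the frequency layout matches the pairing):

    import torch
    from einops import rearrange

    def rotate_half_interleaved(x):  # rotary-embedding-torch style
        x = rearrange(x, '... (d r) -> ... d r', r = 2)
        x1, x2 = x.unbind(dim = -1)
        x = torch.stack((-x2, x1), dim = -1)
        return rearrange(x, '... d r -> ... (d r)')

    def rotate_half_split(x):  # RETRO-pytorch style
        x = rearrange(x, '... (j d) -> ... j d', j = 2)
        x1, x2 = x.unbind(dim = -2)
        return torch.cat((-x2, x1), dim = -1)

    x = torch.arange(8.)
    print(rotate_half_interleaved(x))  # tensor([-1., 0., -3., 2., -5., 4., -7., 6.]) - adjacent pairs rotated
    print(rotate_half_split(x))        # tensor([-4., -5., -6., -7., 0., 1., 2., 3.]) - halves swapped, first negated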

    I wonder why it was implemented with this change! (just curious)

    Looking at your implementation, I am studying and matching the thesis. Thank you always :)

    opened by jinmang2 3
  • Autoregressivity

    I had a question about Figure 2 and Equation 3 from the paper. How does letting the last token of each chunk C_u attend to the retrieved content E_u not break autoregressivity?

    opened by sdpmas 3
  • Extra layer encoder_output_to_decoder_dim cause issue with distributed training

    Hiya, hope Ice Cream is doing well, as well as you of course!

    I've been trying to get distributed training working with your library and I discovered this additional Linear layer, encoder_output_to_decoder_dim, not being used anywhere:

    https://github.com/lucidrains/RETRO-pytorch/blob/main/retro_pytorch/retro_pytorch.py#L491

    It seems to be a copy of the layer defined right above it, to_decoder_model_dim, which does get used. Having this extra layer that is not part of the loss calculation causes the following error with data parallelism:

    RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. (see https://github.com/pytorch/pytorch/issues/43259)
    

    Not sure if this layer is supposed to be there and it just didn't get used or if it is there by accident, so wanted to ask 🤓

    opened by ncoop57 2
  • Question about the right position to encode `retrieved`

    Hi, I am currently reading through the code and got confused when I reached this line:

    https://github.com/lucidrains/RETRO-pytorch/blob/92ff28755df53352547b1868fb03feae9931c1dd/retro_pytorch/retro_pytorch.py#L598

    According to Algorithm 1 in the paper, doesn't this line need to go inside the decoder, under this line? https://github.com/lucidrains/RETRO-pytorch/blob/92ff28755df53352547b1868fb03feae9931c1dd/retro_pytorch/retro_pytorch.py#L406

    This is an example of how I think the code of decoder.forward should look:

    def forward(self, x, *, context_mask = None, retrieved = None):
      encoded = False  # flag to know if p = min(P) (in the algorithm)
      ...
        if exists(cross_attn) and exists(retrieved):
          if not encoded:
            ...
            # use x (H at layer p where p = min(P)), not embed (Emb(X))
            x_as_context = repeat(x[:, :seq_index], 'b (k n) d -> (b k r) n d', n = self.chunk_size, r = num_neighbors)
            retrieved = self.encoder(retrieved, mask = encoder_retrieved_mask, chunked_seq = x_as_context)
            encoded = True
    
    opened by soheeyang 2
  • Confusions about cross attentions in encoder

    In your code https://github.com/lucidrains/RETRO-pytorch/blob/5260d70fae085ed0cc5cbe3e2d1b35947fee475d/retro_pytorch/retro_pytorch.py#L115-L119

    When this class is called by the Encoder, x is the retrieved chunks. In the attention mechanism it produces the q matrix, but I think it should produce the k and v matrices instead: in the encoder, the input sequence should just serve to guide the attention over the retrieved chunk tokens.

    https://github.com/lucidrains/RETRO-pytorch/blob/5260d70fae085ed0cc5cbe3e2d1b35947fee475d/retro_pytorch/retro_pytorch.py#L288-L294

    opened by Hi-archers 2
  • 'NoneType' object is not callable

    When I run the "RETRO Datasets" example, I get the following TypeError:

    Traceback (most recent call last):
      File "/home/fgq/all/RETRO/fuxian_2.py", line 58, in <module>
        retro = RETRO(
      File "/home/fgq/all/RETRO/retro_pytorch/retro_pytorch.py", line 507, in __init__
        self.encoder = Encoder(
      File "/home/fgq/all/RETRO/retro_pytorch/retro_pytorch.py", line 337, in __init__
        wrapper(Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal)),
      File "/home/fgq/all/RETRO/retro_pytorch/retro_pytorch.py", line 73, in __init__
        self.norm = norm_klass(dim)
    TypeError: 'NoneType' object is not callable

    Code:

    save_memmap(
        './train.chunks.dat',
        np.int32(np.random.randint(0, 8192, size = (NUM_CHUNKS, CHUNK_SIZE + 1)))
    )

    # generate nearest neighbors for each chunk
    save_memmap(
        './train.chunks.knn.dat',
        np.int32(np.random.randint(0, 1000, size = (NUM_CHUNKS, NUM_NEIGHBORS)))
    )

    # generate seq data
    save_memmap(
        './train.seq.dat',
        np.int32(np.random.randint(0, 128, size = (NUM_SEQS,)))
    )

    # instantiate dataset class
    train_ds = RETRODataset(
        num_sequences = NUM_SEQS,
        num_chunks = NUM_CHUNKS,
        num_neighbors = NUM_NEIGHBORS,
        chunk_size = CHUNK_SIZE,
        seq_len = 2048,
        chunk_memmap_path = './train.chunks.dat',
        chunk_nn_memmap_path = './train.chunks.knn.dat',
        seq_memmap_path = './train.seq.dat'
    )
    opened by f-guoqiang 1
  • Fix reconstruction error discussed in #15

    This PR fixes the issue with reconstruction of the faiss index. One caveat is that we can no longer do memmapping to reduce RAM overhead. Maybe this will be fixed in faiss soon, but for now memory will be an issue for extremely large indices.

    opened by ncoop57 1
  • Update retrieval.py

    The build_index command:

    In the autofaiss documentation, the description for "--embeddings" reads: "Source path of the directory containing your .npy embedding files. If there are several files, they are read in the lexicographical order. This can be a local path or a path in another Filesystem e.g. hdfs://root/… or s3://…"

    The build_index function reads the embedding folder in lexicographical order, but the embedding files are currently saved as "0.npy, 1.npy, 2.npy, ..., n.npy", so build_index reads them as "0.npy, 1.npy, 10.npy, ..., 2.npy, ..., n.npy". I zero-pad the embedding file names so that build_index reads them in the correct numeric order.
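
    A minimal sketch of the zero-padding idea (the file count and width here are illustrative, not the repo's actual values):

    num_files = 1000
    width = len(str(num_files - 1))                  # 3 digits for 0..999
    names = [f'{i:0{width}d}.npy' for i in range(num_files)]
    assert names == sorted(names)                    # lexicographic order now matches numeric order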

    opened by Hi-archers 1
  • Causal mask in Chunked Cross Attention

    When computing the chunked cross-attention (line 259 of https://github.com/lucidrains/RETRO-pytorch/blob/main/retro_pytorch/retro_pytorch.py), a causal mask is used. If I'm not mistaken, the code applies the causal mask to the last dimension of x (the last words). However, my understanding was that the mask should be applied to the first dimensions, as in the figure from the repo.

    Is it normal?

    opened by Jonor127-OP 0
  • How to give Prompt to trained RETRO Model?

    I am following the instructions in the RETRO-pytorch GitHub repo. After training my model, how do I go about using it to generate responses?

    retro = RETRO(
        chunk_size = 64,                         # the chunk size that is indexed and retrieved (needed for proper relative positions as well as causal chunked cross attention)
        max_seq_len = 2048,                      # max sequence length
        enc_dim = 896,                           # encoder model dim
        enc_depth = 2,                           # encoder depth
        dec_dim = 796,                           # decoder model dim
        dec_depth = 12,                          # decoder depth
        dec_cross_attn_layers = (3, 6, 9, 12),   # decoder cross attention layers (with causal chunk cross attention)
        heads = 8,                               # attention heads
        dim_head = 64,                           # dimension per head
        dec_attn_dropout = 0.25,                 # decoder attention dropout
        dec_ff_dropout = 0.25,                   # decoder feedforward dropout
        use_deepnet = True                       # turn on post-normalization with DeepNet residual scaling and initialization, for scaling to 1000 layers
    )
    
    seq = torch.randint(0, 20000, (2, 2048 + 1))      # plus one since it is split into input and labels for training
    retrieved = torch.randint(0, 20000, (2, 32, 2, 128)) # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation)
    
    loss = retro(seq, retrieved, return_loss = True)
    loss.backward()
    
    wrapper = TrainingWrapper(
        retro = retro,                                 # path to retro instance
        knn = 2,                                       # knn (2 in paper was sufficient)
        chunk_size = 64,                               # chunk size (64 in paper)
        documents_path = './retro_training_set/',              # path to folder of text
        glob = '**/*.txt',                             # text glob
        chunks_memmap_path = './train.chunks.dat',     # path to chunks
        seqs_memmap_path = './train.seq.dat',          # path to sequence data
        doc_ids_memmap_path = './train.doc_ids.dat',   # path to document ids per chunk (used for filtering neighbors belonging to same document)
        max_chunks = 1_000_000,                        # maximum cap to chunks
        max_seqs = 100_000,                            # maximum seqs
        knn_extra_neighbors = 100,                     # num extra neighbors to fetch
        max_index_memory_usage = '100m',
        current_memory_available = '1G'    
    )
    

    Now when I want to give this model a text input (any prompt), how would I go about doing that? Which method or function would I use? Which model/tokenizer should I use to encode the input prompt and then decode the model output tensor? Is there a method for that?

    Example Prompt: "The movie Dune was released in"
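
    For reference, the project README samples from a trained model through the TrainingWrapper's generate method; a minimal sketch (the argument values are illustrative):

    sampled = wrapper.generate(filter_thres = 0.9, temperature = 1.0)  # returns token ids, to be decoded with your tokenizer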

    opened by shahmeer99 1
  • Scann vs faiss

    Could you elaborate on the decision to use faiss instead of scann? In theory scann is open source too, but I'm wondering if you found it easier to get the performance you needed from faiss.

    opened by afcruzs 5
  • Clarification on Architecture

    Reading the original paper, I took it that RETRO was a standard transformer (i.e., a 12-layer encoder and 12-layer decoder) augmented with a DB retrieval system that included a second, smaller (2-layer) encoder for the frozen BERT-encoded neighbors, where the 2-layer encoder acts as a sort of translator between the BERT model and the main transformer.

    Looking at the model here, it looks like there is only the 2-layer retrieval encoder and not a full-size main encoder. Is that correct?

    Going back and re-reading the paper, it doesn't seem to say explicitly one way or the other. It seems odd to me that the model would have only the 2-layer retrieval encoder. Not only would this mean that the encoder is only 2 layers, it also means that most decoder layers have no standard cross attention to the encoder, only layers 6, 9, and 12 via the new CCA setup.

    Has anyone trained the model from this repo and demonstrated that it can produce the results from the original paper?

    opened by bjascob 0
  • Retro-fitting a pretrained model

    Hey,

    Thank you for your implementation! Is it possible to use your library to "retro-fit" a pretrained model?

    I guess it would mean freezing the model during training and fine-tuning only the retrieval components and cross-attention? How would you recommend doing that?
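
    A hedged sketch of that kind of freezing (the substring filters on parameter names are assumptions about the module naming, not a confirmed recipe):

    for name, param in retro.named_parameters():
        # keep only the retrieval encoder and cross-attention trainable
        param.requires_grad = ('encoder' in name) or ('cross_attn' in name)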

    Thanks!

    opened by dean-sh 6