Implementation of Rotary Embeddings, from the Roformer paper, in Pytorch

Overview

Rotary Embeddings - Pytorch

A standalone library for adding rotary embeddings to transformers in Pytorch, following its success as relative positional encoding. Specifically it will make rotating information into any axis of a tensor easy and efficient, whether they be fixed positional or learned. This library will give you state of the art results for positional embedding, at little costs.

My gut also tells me there is something more to rotations that can be exploited in artificial neural networks.

Install

$ pip install rotary-embedding-torch

Usage

import torch
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding

# instantiate the positional embedding in your transformer and pass to all your attention layers

pos_emb = RotaryEmbedding(dim = 32)

# generate the rotations

freqs = pos_emb(torch.arange(1024), cache_key = 1024) # cache with a key that is the sequence length, so that it does not need to recompute

# mock queries and keys

q = torch.randn(1, 1024, 64) # queries - (batch, seq len, dimension of head)
k = torch.randn(1, 1024, 64) # keys

# apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention)

freqs = freqs[None, ...] # unsqueeze for batch dimension
q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)

# then do your attention with your queries (q) and keys (k)

If you do all the steps above correctly, you should see a dramatic improvement during training

Axial Rotary Embeddings

For easy use of 2d axial relative positional embedding, ie. vision transformers

import torch
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding, broadcat

pos_emb = RotaryEmbedding(
    dim = 32,
    freqs_for = 'pixel'
)

# queries and keys for frequencies to be rotated into

q = torch.randn(1, 256, 256, 64)
k = torch.randn(1, 256, 256, 64)

# get frequencies for each axial
# -1 to 1 has been shown to be a good choice for images and audio

freqs_h = pos_emb(torch.linspace(-1, 1, steps = 256), cache_key = 256)
freqs_w = pos_emb(torch.linspace(-1, 1, steps = 256), cache_key = 256)

# concat the frequencies along each axial
# broadcat function makes this easy without a bunch of expands

freqs = broadcat((freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim = -1)

# rotate in frequencies

q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)

Learned Rotations

For injecting learned rotations into a network. Experiments pending

Update: doesn't seem to do anything -_-, will keep trying...

import torch
from torch import nn
from rotary_embedding_torch import apply_learned_rotations

x = torch.randn(1, 1024, 512)

# you can only rotate in (dim // 2) values
# ex. for 512, you can only rotate in 256 values

# say you have two sets of learned rotations of 128 values each

rots1 = nn.Linear(512, 128)(x)
rots2 = nn.Linear(512, 128)(x)

# you rotate in 256 (128 x 2) at first

x = apply_learned_rotations(rots1, x, start_index = 0)

# then you start at index 256 and rotate in the last (128 x 2)

x = apply_learned_rotations(rots2, x, start_index = 256)

# you could also concat the rotations together and pass it in all at once

rots = torch.cat((rots1, rots2), dim = -1)

x = apply_learned_rotations(rots, x)

Citations

@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
You might also like...
Joint detection and tracking model named DEFT, or ``Detection Embeddings for Tracking.
Joint detection and tracking model named DEFT, or ``Detection Embeddings for Tracking.

DEFT: Detection Embeddings for Tracking DEFT: Detection Embeddings for Tracking, Mohamed Chaabane, Peter Zhang, J. Ross Beveridge, Stephen O'Hara

Learning embeddings for classification, retrieval and ranking.
Learning embeddings for classification, retrieval and ranking.

StarSpace StarSpace is a general-purpose neural model for efficient learning of entity embeddings for solving a wide variety of problems: Learning wor

Learning RGB-D Feature Embeddings for Unseen Object Instance Segmentation
Learning RGB-D Feature Embeddings for Unseen Object Instance Segmentation

Unseen Object Clustering: Learning RGB-D Feature Embeddings for Unseen Object Instance Segmentation Introduction In this work, we propose a new method

Improving XGBoost survival analysis with embeddings and debiased estimators
Improving XGBoost survival analysis with embeddings and debiased estimators

xgbse: XGBoost Survival Embeddings "There are two cultures in the use of statistical modeling to reach conclusions from data

State of the art Semantic Sentence Embeddings

Contrastive Tension State of the art Semantic Sentence Embeddings Published Paper · Huggingface Models · Report Bug Overview This is the official code

Reliable probability face embeddings
Reliable probability face embeddings

ProbFace, arxiv This is a demo code of training and testing [ProbFace] using Tensorflow. ProbFace is a reliable Probabilistic Face Embeddging (PFE) me

 UmlsBERT: Clinical Domain Knowledge Augmentation of Contextual Embeddings Using the Unified Medical Language System Metathesaurus
UmlsBERT: Clinical Domain Knowledge Augmentation of Contextual Embeddings Using the Unified Medical Language System Metathesaurus

UmlsBERT: Clinical Domain Knowledge Augmentation of Contextual Embeddings Using the Unified Medical Language System Metathesaurus General info This is

🤖 A Python library for learning and evaluating knowledge graph embeddings
🤖 A Python library for learning and evaluating knowledge graph embeddings

PyKEEN PyKEEN (Python KnowlEdge EmbeddiNgs) is a Python package designed to train and evaluate knowledge graph embedding models (incorporating multi-m

Large scale embeddings on a single machine.

Marius Marius is a system under active development for training embeddings for large-scale graphs on a single machine. Training on large scale graphs

Comments
  • Custom position offset when rotating queries or keys

    Custom position offset when rotating queries or keys

    This library seems to assume that queries and keys are left-aligned position-wise e.g.

    q = [p_0, p_1, p_2]
    k = [p_0, p_1, p_2, p_3, p_4]
    

    where p_i are corresponding positions. This is enforced by starting the sequence of positions always from 0 with torch.arange(seq_len) here. Applications like Perceiver AR, however, require a position-wise right-alignment e.g.

    q =           [p_2, p_3, p_4]
    k = [p_0, p_1, p_2, p_3, p_4]
    

    This pull requests allows to specify a start position for queries and or keys to enable alignments other than left-alignments. For example

    import torch
    from rotary_embedding_torch.rotary_embedding_torch import RotaryEmbedding
    
    rot = RotaryEmbedding(dim=32)
    
    q = torch.ones(1, 8, 4, 32)
    k = torch.ones(1, 8, 6, 32)
    
    q = q / torch.norm(q, dim=-1, keepdim=True)
    k = k / torch.norm(k, dim=-1, keepdim=True)
    
    q_rot = rot.rotate_queries_or_keys(q, start_pos=k.shape[2] - q.shape[2])
    k_rot = rot.rotate_queries_or_keys(k)
    
    attn = torch.einsum("b h i c, b h j c -> b h i j", q_rot, k_rot)
    print(attn[0, 0])
    

    prints the following relative position embedding

    tensor([[0.8581, 0.9571, 1.0000, 0.9571, 0.8581, 0.7670],
            [0.7670, 0.8581, 0.9571, 1.0000, 0.9571, 0.8581],
            [0.7288, 0.7670, 0.8581, 0.9571, 1.0000, 0.9571],
            [0.7361, 0.7288, 0.7670, 0.8581, 0.9571, 1.0000]])
    

    (diagonal of 1s right-aligned) whereas the default behavior

    ...
    
    q_rot = rot.rotate_queries_or_keys(q)
    k_rot = rot.rotate_queries_or_keys(k)
    
    attn = torch.einsum("b h i c, b h j c -> b h i j", q_rot, k_rot)
    print(attn[0, 0])
    

    would print

    tensor([[1.0000, 0.9571, 0.8581, 0.7670, 0.7288, 0.7361],
            [0.9571, 1.0000, 0.9571, 0.8581, 0.7670, 0.7288],
            [0.8581, 0.9571, 1.0000, 0.9571, 0.8581, 0.7670],
            [0.7670, 0.8581, 0.9571, 1.0000, 0.9571, 0.8581]])
    

    (diagonal of 1s left-aligned).

    opened by krasserm 1
  • about axial rotary embeddings

    about axial rotary embeddings

    Hi, Thank you for sharing this code with us. However, I was confused with the axial rotary embeddings in rotary_embedding_torch.py file. " elif freqs_for == 'pixel': freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi " Where does this formula come from?What parameter is max_freqs?Why the freqs is not " 1/(10000^(2i/d))"?

    Thank you again.

    opened by raindrop313 0
Owner
Phil Wang
Working with Attention
Phil Wang
Code for our ICCV 2021 Paper "OadTR: Online Action Detection with Transformers".

Code for our ICCV 2021 Paper "OadTR: Online Action Detection with Transformers".

66 Dec 15, 2022
Official code for our EMNLP2021 Outstanding Paper MindCraft: Theory of Mind Modeling for Situated Dialogue in Collaborative Tasks

MindCraft Authors: Cristian-Paul Bara*, Sky CH-Wang*, Joyce Chai This is the official code repository for the paper (arXiv link): Cristian-Paul Bara,

Situated Language and Embodied Dialogue (SLED) Research Group 14 Dec 29, 2022
CRISCE: Automatically Generating Critical Driving Scenarios From Car Accident Sketches

CRISCE: Automatically Generating Critical Driving Scenarios From Car Accident Sketches This document describes how to install and use CRISCE (CRItical

Chair of Software Engineering II, Uni Passau 2 Feb 09, 2022
Learning to trade under the reinforcement learning framework

Trading Using Q-Learning In this project, I will present an adaptive learning model to trade a single stock under the reinforcement learning framework

Uirá Caiado 470 Nov 28, 2022
This is a simple plugin for Vim that allows you to use OpenAI Codex.

🤖 Vim Codex An AI plugin that does the work for you. This is a simple plugin for Vim that will allow you to use OpenAI Codex. To use this plugin you

Tom Dörr 195 Dec 28, 2022
Predicting Auction Sale Price using the kaggle bulldozer auction sales data: Modeling with Ensembles vs Neural Network

Predicting Auction Sale Price using the kaggle bulldozer auction sales data: Modeling with Ensembles vs Neural Network The performances of tree ensemb

Mustapha Unubi Momoh 2 Sep 13, 2022
Individual Treatment Effect Estimation

CAPE Individual Treatment Effect Estimation Run CAPE python train_causal.py --loop 10 -m cape_cau -d NI --i_t 1 Run a baseline model python train_cau

S. Deng 4 Sep 02, 2022
SOTA easy to use PyTorch-based DL training library

Easily train or fine-tune SOTA computer vision models from one training repository. SuperGradients Introduction Welcome to SuperGradients, a free open

619 Jan 03, 2023
PyTorch code for 'Efficient Single Image Super-Resolution Using Dual Path Connections with Multiple Scale Learning'

Efficient Single Image Super-Resolution Using Dual Path Connections with Multiple Scale Learning This repository is for EMSRDPN introduced in the foll

7 Feb 10, 2022
[CVPR 2022 Oral] Rethinking Minimal Sufficient Representation in Contrastive Learning

Rethinking Minimal Sufficient Representation in Contrastive Learning PyTorch implementation of Rethinking Minimal Sufficient Representation in Contras

36 Nov 23, 2022
McGill Physics Hackathon 2021: Reaction-Diffusion Models for the Generation of Biological Patterns

DiffuseAnimals: Reaction-Diffusion Models for the Generation of Biological Patterns Introduction Reaction-diffusion equations can be utilized in order

Austin Szuminsky 2 Mar 07, 2022
Code from Daniel Lemire, A Better Alternative to Piecewise Linear Time Series Segmentation

PiecewiseLinearTimeSeriesApproximation code from Daniel Lemire, A Better Alternative to Piecewise Linear Time Series Segmentation, SIAM Data Mining 20

Daniel Lemire 21 Oct 27, 2022
Reviatalizing Optimization for 3D Human Pose and Shape Estimation: A Sparse Constrained Formulation

Reviatalizing Optimization for 3D Human Pose and Shape Estimation: A Sparse Constrained Formulation This is the implementation of the approach describ

Taosha Fan 47 Nov 15, 2022
sktime companion package for deep learning based on TensorFlow

NOTE: sktime-dl is currently being updated to work correctly with sktime 0.6, and wwill be fully relaunched over the summer. The plan is Refactor and

sktime 573 Jan 05, 2023
PyTorch code for the NAACL 2021 paper "Improving Generation and Evaluation of Visual Stories via Semantic Consistency"

Improving Generation and Evaluation of Visual Stories via Semantic Consistency PyTorch code for the NAACL 2021 paper "Improving Generation and Evaluat

Adyasha Maharana 28 Dec 08, 2022
Pytorch implementation of "Forward Thinking: Building and Training Neural Networks One Layer at a Time"

forward-thinking-pytorch Pytorch implementation of Forward Thinking: Building and Training Neural Networks One Layer at a Time Requirements Python 2.7

Kim Heecheol 65 Oct 06, 2022
Multimodal Temporal Context Network (MTCN)

Multimodal Temporal Context Network (MTCN) This repository implements the model proposed in the paper: Evangelos Kazakos, Jaesung Huh, Arsha Nagrani,

Evangelos Kazakos 13 Nov 24, 2022
RoadMap and preparation material for Machine Learning and Data Science - From beginner to expert.

ML-and-DataScience-preparation This repository has the goal to create a learning and preparation roadMap for Machine Learning Engineers and Data Scien

33 Dec 29, 2022