Implementation of H-Transformer-1D, Hierarchical Attention for Sequence Learning

Overview

H-Transformer-1D

Implementation of H-Transformer-1D, a Transformer that uses hierarchical attention for sequence learning at subquadratic cost.

For now, the H-Transformer only acts as a long-context encoder.

Install

$ pip install h-transformer-1d

Usage

import torch
from h_transformer_1d import HTransformer1D

model = HTransformer1D(
    num_tokens = 256,          # number of tokens
    dim = 512,                 # dimension
    depth = 2,                 # depth
    max_seq_len = 8192,        # maximum sequence length
    heads = 8,                 # heads
    dim_head = 64,             # dimension per head
    block_size = 128           # block size
)

x = torch.randint(0, 256, (1, 8000))   # variable sequence length
mask = torch.ones((1, 8000)).bool()    # variable mask length

# network will automatically pad to power of 2, do hierarchical attention, etc

logits = model(x, mask = mask) # (1, 8000, 256)
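
The comment above says the network pads the input to a power of 2. As a rough illustration of that arithmetic only (the helper name `next_power_of_two` is made up here and is not part of the library), the padded length for the 8000-token example could be computed like this:

```python
def next_power_of_two(n: int) -> int:
    """Smallest power of two that is >= n (hypothetical helper, n >= 1)."""
    return 1 << (n - 1).bit_length()

seq_len = 8000
padded_len = next_power_of_two(seq_len)   # 8192
pad_amount = padded_len - seq_len         # 192 padding positions

print(padded_len, pad_amount)
```

A length that is already a power of two is left unchanged by this helper.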

Citations

@misc{zhu2021htransformer1d,
    title   = {H-Transformer-1D: Fast One-Dimensional Hierarchical Attention for Sequences}, 
    author  = {Zhenhai Zhu and Radu Soricut},
    year    = {2021},
    eprint  = {2107.11906},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Masking not working in training, thanks

    Hi, I have tried to train the model on a GPU with masking enabled. Line 94, t = torch.flip(t, dims = (2,)), raises RuntimeError: "flip_cuda" not implemented for 'Bool', even after I tried moving the mask to the CPU.

    Any ideas on how to solve the problem? Thanks a lot.
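
One possible workaround, sketched here as an assumption rather than as the fix adopted in the repository, is to cast a boolean tensor to an integer dtype before flipping and cast it back afterwards. The helper name `flip_bool_safe` is hypothetical:

```python
import torch

def flip_bool_safe(t: torch.Tensor, dims) -> torch.Tensor:
    """Flip a tensor along dims, routing bool tensors through uint8.

    Hypothetical helper: it avoids calling torch.flip directly on a bool
    tensor, which the error above reports as unsupported on CUDA.
    """
    if t.dtype == torch.bool:
        return torch.flip(t.to(torch.uint8), dims=dims).bool()
    return torch.flip(t, dims=dims)

mask = torch.tensor([[[True, False, False]]])
flipped = flip_bool_safe(mask, dims=(2,))
print(flipped)
```

The result keeps the bool dtype, so downstream calls such as `masked_fill` can use it unchanged.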

    opened by junyongyou 6
  • Application to sequence classification?

    Hi,

    Forgive the naive question, I am trying to make sense of this paper but it's tough going. If I understand correctly, this attention mechanism focuses mainly on nearby tokens and only attends to distant tokens via a hierarchical, low-rank approximation. In that case, can the usual sequence classification approach of having a global [CLS] token that can attend to all other tokens (and vice versa) still work? If not, how can this attention mechanism handle the text classification tasks in the long range arena benchmark?

    Cheers for whatever insights you can share, and thanks for the great work!

    opened by trpstra 4
  • eos token does not work in batch mode generation

    When generating sequences with the current code, it seems that the eos_token only works when generating one sequence at a time: https://github.com/lucidrains/h-transformer-1d/blob/main/h_transformer_1d/autoregressive_wrapper.py#L59
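
A minimal sketch of how eos handling could work per row in a batch, assuming generated tokens are held in a `(batch, seq)` integer tensor. The function names and the `pad_value` parameter are illustrative and are not taken from `autoregressive_wrapper.py`:

```python
import torch

def truncate_after_eos(tokens: torch.Tensor, eos_token: int, pad_value: int = 0) -> torch.Tensor:
    """Overwrite everything after the first eos in each row with pad_value."""
    is_eos = (tokens == eos_token).int()
    # Number of eos tokens seen strictly before each position, per row.
    seen_before = torch.cumsum(is_eos, dim=-1) - is_eos
    return tokens.masked_fill(seen_before > 0, pad_value)

def all_rows_finished(tokens: torch.Tensor, eos_token: int) -> bool:
    """True once every row contains at least one eos token."""
    return bool((tokens == eos_token).any(dim=-1).all())

out = torch.tensor([[5, 2, 7, 2],
                    [3, 4, 5, 6]])
cleaned = truncate_after_eos(out, eos_token=2)
done = all_rows_finished(out, eos_token=2)
print(cleaned, done)
```

In a sampling loop, `all_rows_finished` could serve as the stopping condition, with `truncate_after_eos` applied once at the end.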

    opened by tmphex 4
  • RuntimeError: Tensor type unknown to einops <class 'torch.return_types.max'>

    lib/python3.7/site-packages/einops/_backends.py", line 52, in get_backend
        raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))

    RuntimeError: Tensor type unknown to einops <class 'torch.return_types.max'>

    I understand why this gets raised. Could it be a pytorch version problem? Mine is 1.6.0
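
For context, `Tensor.max` called with a `dim` argument returns a named tuple of `(values, indices)` rather than a tensor, which is the type named in the error message. The snippet below only illustrates that behavior; whether it is the cause of the error in this repository is an assumption:

```python
import torch

t = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])

result = t.max(dim=-1)       # named tuple, not a tensor
values = result.values       # tensor holding the row-wise maxima
indices = result.indices     # tensor holding the argmax positions

print(type(result).__name__, values, indices)
```

Passing `result.values` (a plain tensor) to a function that expects a tensor avoids handing it the tuple.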

    opened by wajihullahbaig 4
  • Algorithm Mismatch

    Paper Implementation

    In the implementation, we get blocked Q, K, V tensors by level with the code below.

    https://github.com/lucidrains/h-transformer-1d/blob/110cab0038898560d72d460bfef8ca8b7f17f0a5/h_transformer_1d/h_transformer_1d.py#L164-L179

    It then returns the final matrix-matrix product of Equation 29 or 69 using the for loop below.

    https://github.com/lucidrains/h-transformer-1d/blob/110cab0038898560d72d460bfef8ca8b7f17f0a5/h_transformer_1d/h_transformer_1d.py#L234-L247

    What is the problem?

    However, according to the current code, it is not possible to include information about the level 0 white blocks in the figure below. (Equation 70 of the paper includes the corresponding attention matrix entries.)

    fig2

    I think you should also add an off-diagonal term of near-interaction (level 0) to match Equation 69!

    opened by jinmang2 3
  • I have some questions about implementation details

    Thanks for making your implementation public. I have three questions about your h-transformer 1d implementation.

    1. The number of levels M

    https://github.com/lucidrains/h-transformer-1d/blob/63063d5bb036b56a7205aadc5c8198da02d698f6/h_transformer_1d/h_transformer_1d.py#L105-L114

    In the paper, eq (32) gives a guide on how M is determined.

    img1

    In your implementation,

    • n is the sequence length, which is not padded
    • bsz is the block size (the same as N_r, the numerical rank of the off-diagonal blocks)
    • Because code line 111 already contains level 0, M is equal to int(log2(n // bsz)) - 1

    However, in Section 6.1, Constructing Hierarchical Attention, I found that the sequence length (L) must be a multiple of 2. In my opinion, L in eq (31) is equal to 2^{M+1}. In the implementation, n is the length of the unpadded sequence, so one level of M is missing.

    Since x is the sequence padded by processing below, https://github.com/lucidrains/h-transformer-1d/blob/63063d5bb036b56a7205aadc5c8198da02d698f6/h_transformer_1d/h_transformer_1d.py#L83-L91

    I think the above implementation should be modified as below

    107    num_levels = int(log2(x.size(1) // bsz)) - 1 
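
To make the difference concrete, here is a small arithmetic check using the README values (sequence length 8000, block size 128). The function `num_levels` simply mirrors the expression quoted above, and the padded length 8192 assumes padding up to the next power of two:

```python
from math import log2

def num_levels(seq_len: int, bsz: int) -> int:
    """Mirror of the expression int(log2(seq_len // bsz)) - 1 from the issue."""
    return int(log2(seq_len // bsz)) - 1

n = 8000            # unpadded length
padded_len = 8192   # n padded up to the next power of two (assumption)
bsz = 128

levels_unpadded = num_levels(n, bsz)          # 8000 // 128 = 62 -> int(log2(62)) = 5 -> 4
levels_padded = num_levels(padded_len, bsz)   # 8192 // 128 = 64 -> log2(64) = 6 -> 5

print(levels_unpadded, levels_padded)
```

With these numbers, the unpadded length gives one level fewer than the padded one, which matches the concern raised here.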
    

    2. Super- and Sub-diagonal blocks of the coarsened matrix \tilde{A} as level-l

    https://github.com/lucidrains/h-transformer-1d/blob/63063d5bb036b56a7205aadc5c8198da02d698f6/h_transformer_1d/h_transformer_1d.py#L190-L198

    Ys contains the Y and A values computed by calculate_Y_and_A. For example,

    # Ys
    [
        (batch_size*n_heads, N_b(2), N_r), (batch_size*n_heads, N_b(2)),  # level 2, (Y(2), tilde_A(2))
        (batch_size*n_heads, N_b(1), N_r), (batch_size*n_heads, N_b(1)),  # level 1, (Y(1), tilde_A(1))
        (batch_size*n_heads, N_b(0), N_r), (batch_size*n_heads, N_b(0)),  # level 0, (Y(0), tilde_A(0))
    ]
    

    In eq (29), Y is calculated as Y = AV = Y(0) + P(0)(Y(1) + P(1)Y(2)). However, in code line 190, Y is calculated using only the level-0 and level-1 blocks, no matter how many levels M there are: Y = AV = Y(0) + P(0)Y(1).

    Does increasing the level cause performance degradation issues in implementation? I'm so curious!
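
The nested form Y(0) + P(0)(Y(1) + P(1)Y(2)) can be evaluated for any number of levels by starting at the coarsest level and working down. The sketch below uses plain Python lists and assumes that each P(l) simply duplicates every coarse row for its two finer children. That choice of interpolation, and the function names, are assumptions for illustration, not the repository's code:

```python
def expand(rows):
    """Assumed interpolation P(l): repeat each coarse row twice."""
    return [list(r) for r in rows for _ in range(2)]

def add(a, b):
    """Element-wise sum of two equally shaped lists of rows."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def combine_levels(ys_coarse_to_fine):
    """Evaluate Y(0) + P(0)(Y(1) + P(1)(Y(2) + ...)) from coarsest to finest."""
    acc = ys_coarse_to_fine[0]
    for finer in ys_coarse_to_fine[1:]:
        acc = add(finer, expand(acc))
    return acc

y2 = [[1]]                          # level 2: 1 block
y1 = [[10], [20]]                   # level 1: 2 blocks
y0 = [[100], [200], [300], [400]]   # level 0: 4 blocks

y = combine_levels([y2, y1, y0])
print(y)  # [[111], [211], [321], [421]]
```

Every level contributes to the final rows here, whereas stopping after Y(0) + P(0)Y(1) would drop the level-2 term.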


    3. Comparison with Luna: Linear Unified Nested Attention

    h-transformer significantly exceeded the scores of BigBird and Luna on LRA. However, while reading the paper I regretted that there was no comparison of computation time with Luna and other sub-quadratic methods. Is this algorithm much faster than other sub-quadratic methods? And how does it compare to Luna?


    Thanks again for the implementation release! The idea of calculating the off-diagonal blocks with flip was amazing and I learned a lot. Thank you!! 😄

    opened by jinmang2 3
  • Add Norm Missing

    I am using the code now, and I wonder whether add & norm is implemented. I can only find the layer norm, but no add operation. Here is the code in h-transformer-1d.py, line 489 ... Is this a bug? Thanks @Lucidrains

    for ind in range(depth):
        attn = attn_class(dim, dim_head = dim_head, heads = heads, block_size = block_size, pos_emb = self.pos_emb, **attn_kwargs)
        ff = FeedForward(dim, mult = ff_mult)

        if shift_tokens:
            attn, ff = map(lambda t: PreShiftTokens(shift_token_ranges, t), (attn, ff))

        attn, ff = map(lambda t: PreNorm(dim, t), (attn, ff))
        layers.append(nn.ModuleList([attn, ff]))
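
In pre-norm designs, the "add" is often applied where the layers are executed rather than where they are constructed. The sketch below shows that pattern in isolation. `PreNorm` and `ResidualStack` are written from scratch for illustration and are assumptions about the general technique, not a copy of how this repository runs its layers:

```python
import torch
from torch import nn

class PreNorm(nn.Module):
    """Illustrative wrapper: apply LayerNorm, then the wrapped function."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x))

class ResidualStack(nn.Module):
    """Illustrative executor: the residual add happens in forward."""
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList([nn.ModuleList(list(pair)) for pair in layers])

    def forward(self, x):
        for attn, ff in self.layers:
            x = x + attn(x)   # add after (pre-normed) attention
            x = x + ff(x)     # add after (pre-normed) feedforward
        return x

dim = 4
stack = ResidualStack([(PreNorm(dim, nn.Linear(dim, dim)), PreNorm(dim, nn.Linear(dim, dim)))])
out = stack(torch.randn(2, 3, dim))
print(out.shape)
```

If the layers in the snippet above are consumed by a loop of this shape, the construction code would contain only the norm, as observed.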
    
    opened by wwx13 2
  • Mask not working

    def forward(self, x, mask = None):
        b, n, device = *x.shape, x.device
        assert n <= self.max_seq_len, 'sequence length must be less than the maximum sequence length'
        x = self.token_emb(x)
        x = self.layers(x)
        return self.to_logits(x)
    

    I think masking does not work: forward accepts mask but never passes it on to the layers.

    opened by wwx13 2
  • One simple question

    Hi, Phil!

    One simple question (my math is not good): https://github.com/lucidrains/h-transformer-1d/blob/7c11d036d53926495ec0917a34a1aad7261892b5/train.py#L65

    Why not use randint(0, self.data.size(0) - self.seq_len + 1), since the upper bound is excluded?
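
The off-by-one reasoning can be checked with a tiny example. With an exclusive upper bound, the valid start positions for a window of length `window_len` in data of length `data_len` are `0 .. data_len - window_len`. Note that if the sampler reads `seq_len + 1` tokens per sample (for example, inputs plus shifted targets), the window length is `seq_len + 1`. That detail is an assumption about train.py, not something shown in this thread:

```python
def valid_starts(data_len: int, window_len: int) -> range:
    """All start indices whose window fits inside the data (hypothetical helper)."""
    return range(0, data_len - window_len + 1)

data_len, window_len = 10, 4
starts = valid_starts(data_len, window_len)

last_start = starts[-1]                # 6
last_window_end = last_start + window_len   # 10, i.e. exactly the end of the data

print(list(starts), last_window_end)
```

So an exclusive bound of `data_len - window_len + 1` uses every valid position, while a smaller bound is still safe but never samples the last few windows.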

    opened by CiaoHe 2
  • Mini-batching (b > 1) does not work with masking

    When using x and mask with a batch size larger than 1, the following error arises:

    import torch
    from h_transformer_1d import HTransformer1D
    
    model = HTransformer1D(
        num_tokens = 256,          # number of tokens
        dim = 512,                 # dimension
        depth = 2,                 # depth
        causal = False,            # autoregressive or not
        max_seq_len = 8192,        # maximum sequence length
        heads = 8,                 # heads
        dim_head = 64,             # dimension per head
        block_size = 128           # block size
    )
    
    batch_size = 2
    x = torch.randint(0, 256, (batch_size, 8000))   # variable sequence length
    mask = torch.ones((batch_size, 8000)).bool()    # variable mask length
    
    # network will automatically pad to power of 2, do hierarchical attention, etc
    
    logits = model(x, mask = mask) # (2, 8000, 256)
    

    Gives the following error:

    ~/git/h-transformer-1d/h_transformer_1d/h_transformer_1d.py in masked_aggregate(tensor, mask, dim, average)
         19     diff_len = len(tensor.shape) - len(mask.shape)
         20     mask = mask[(..., *((None,) * diff_len))]
    ---> 21     tensor = tensor.masked_fill(~mask, 0.)
         22 
         23     total_el = mask.sum(dim = dim)
    
    RuntimeError: The size of tensor a (2) must match the size of tensor b (16) at non-singleton dimension 0
    

    It seems the tensor has size heads * batch in dimension 0, rather than the batch size that the mask has.
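
One way to reconcile the two shapes, sketched under the assumption that heads are folded into the batch dimension in batch-major order (as in a `b n (h d) -> (b h) n d` rearrangement), is to repeat each mask row once per head. The helper name is hypothetical:

```python
import torch

def expand_mask_for_heads(mask: torch.Tensor, heads: int) -> torch.Tensor:
    """Repeat each batch row of a (batch, seq) mask `heads` times along dim 0."""
    return mask.repeat_interleave(heads, dim=0)

heads = 8
mask = torch.tensor([[True, False, True],
                     [False, True, True]])      # (batch=2, seq=3)
expanded = expand_mask_for_heads(mask, heads)   # (16, 3)

tensor = torch.ones(2 * heads, 3, 4)            # (batch * heads, seq, dim)
filled = tensor.masked_fill(~expanded[..., None], 0.)
print(expanded.shape, filled.shape)
```

If the heads were folded in the other order (`(h b)`), `mask.repeat(heads, 1)` would be the matching choice instead.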

    opened by jaak-s 2
  • Example in README does not work

    Executing the example:

    import torch
    from h_transformer_1d import HTransformer1D
    
    model = HTransformer1D(
        num_tokens = 256,          # number of tokens
        dim = 512,                 # dimension
        depth = 2,                 # depth
        causal = False,            # autoregressive or not
        max_seq_len = 8192,        # maximum sequence length
        heads = 8,                 # heads
        dim_head = 64,             # dimension per head
        block_size = 128           # block size
    )
    
    x = torch.randint(0, 256, (1, 8000))   # variable sequence length
    mask = torch.ones((1, 8000)).bool()    # variable mask length
    
    # network will automatically pad to power of 2, do hierarchical attention, etc
    
    logits = model(x, mask = mask) # (1, 8000, 256)
    

    Gives the following error:

    ~/miniconda3/lib/python3.7/site-packages/rotary_embedding_torch/rotary_embedding_torch.py in apply_rotary_emb(freqs, t, start_index)
         43     assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
         44     t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
    ---> 45     t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
         46     return torch.cat((t_left, t, t_right), dim = -1)
         47 
    
    RuntimeError: The size of tensor a (8192) must match the size of tensor b (8000) at non-singleton dimension 1
    
    opened by jaak-s 2
  • Fix indexing

    I am fixing a few apparent bugs in the code. The upshot is that the attention now supports a block size of the (next largest power of two) of the input length, and for this value of the block size it becomes exact. This allows one to look at the systematic error in the output as a function of decreased block size (and memory usage).

    I've found this module to reduce memory consumption by a factor of two, but the approximation quickly becomes too inaccurate with decreasing block size to use it as a drop-in replacement for an existing (full) attention layer.

    This repository shows how to compute the full attention with linear memory complexity: https://github.com/CHARM-Tx/linear_mem_attention_pytorch

    opened by jglaser 0
  • Approximated values are off

    I wrote a simple test to check the output of the hierarchical transformer self attention against the BERT self attention from huggingface transformers.

    import torch
    import torch.nn as nn
    import math
    
    from h_transformer_1d.h_transformer_1d import HAttention1D
    
    def transpose_for_scores(x, num_attention_heads, attention_head_size):
        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
    
    def bert_self_attention(query, key, value, attention_mask=None, num_attention_heads=1):
            dim_head = query.size()[-1] // num_attention_heads
            all_head_size = dim_head*num_attention_heads
    
            query_layer = transpose_for_scores(query, num_attention_heads, dim_head)
            key_layer = transpose_for_scores(key, num_attention_heads, dim_head)
            value_layer = transpose_for_scores(value, num_attention_heads, dim_head)
    
            # Take the dot product between "query" and "key" to get the raw attention scores.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    
            attention_scores = attention_scores / math.sqrt(dim_head)
    
            if attention_mask is not None:
                # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
                attention_scores = attention_scores + attention_mask
    
            # Normalize the attention scores to probabilities.
            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
    
            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            #attention_probs = self.dropout(attention_probs)
    
            context_layer = torch.matmul(attention_probs, value_layer)
    
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (all_head_size,)
            context_layer = context_layer.view(*new_context_layer_shape)
    
            return context_layer, attention_probs
    
    if __name__ == "__main__":
        query = torch.tensor([[[0.1,0.2],[-0.5,0.7],[-0.5,-0.75],[.123,.456]]])
    #    query = torch.tensor([[[0.1,0.2],[-0.5,0.7]]])
        key = value = query
    
        n_heads = 1
        attn, probs = bert_self_attention(query, key, value, num_attention_heads=n_heads)
        print('bert_self_attention out: ', attn)
    
        block_size = 1
        for _ in range(0,2):
            dim_head = query.size()[-1]//n_heads
            h_attn = HAttention1D(
                dim=query.size()[-1],
                heads=n_heads,
                dim_head=dim_head,
                block_size=block_size
            )
    
            h_attn.to_qkv = torch.nn.Identity()
            h_attn.to_out = torch.nn.Identity()
    
            qkv = torch.stack([query, key, value], dim=2)
            qkv = torch.flatten(qkv, start_dim=2)
    
            attn_scores = h_attn(qkv)
            print('hattention_1d: (block_size = {})'.format(block_size), attn_scores)
    
            block_size *= 2
    

    This is the output I get

    bert_self_attention:  tensor([[[-0.1807,  0.1959],
             [-0.2096,  0.2772],
             [-0.2656, -0.0568],
             [-0.1725,  0.2442]]])
    hattention_1d: (block_size = 1) tensor([[[-0.2000,  0.4500],
             [-0.2000,  0.4500],
             [-0.1885, -0.1470],
             [-0.1885, -0.1470]]])
    

    before it errors out with

    assert num_levels >= 0, 'number of levels must be at least greater than 0'
    

    Some of the values are off in absolute magnitude by more than a factor of two.

    Looking at the code, this line seems problematic: https://github.com/lucidrains/h-transformer-1d/blob/8afd75cc6bc41754620bb6ab3737176cb69bdf93/h_transformer_1d/h_transformer_1d.py#L172

    I believe it should read

    num_levels = int(log2(pad_to_len // bsz)) - 1
    

    If I make that change, the approximated attention output is much closer to the exact one:

    bert_self_attention out:  tensor([[[-0.1807,  0.1959],
             [-0.2096,  0.2772],
             [-0.2656, -0.0568],
             [-0.1725,  0.2442]]])
    hattention_1d: (block_size = 1) tensor([[[-0.2590,  0.2020],
             [-0.2590,  0.2020],
             [-0.2590,  0.2020],
             [-0.2590,  0.2020]]])
    hattention_1d: (block_size = 2) tensor([[[-0.1808,  0.1972],
             [-0.1980,  0.2314],
             [-0.2438,  0.0910],
             [-0.1719,  0.2413]]])
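
To put a number on "much closer", the maximum absolute error against the exact output can be computed directly from the values printed above. This is only arithmetic on the reported outputs; nothing is re-run:

```python
exact = [[-0.1807, 0.1959], [-0.2096, 0.2772], [-0.2656, -0.0568], [-0.1725, 0.2442]]

# Reported outputs, copied from the printouts in this issue.
before_fix_bs1 = [[-0.2000, 0.4500], [-0.2000, 0.4500], [-0.1885, -0.1470], [-0.1885, -0.1470]]
after_fix_bs1 = [[-0.2590, 0.2020]] * 4
after_fix_bs2 = [[-0.1808, 0.1972], [-0.1980, 0.2314], [-0.2438, 0.0910], [-0.1719, 0.2413]]

def max_abs_error(a, b):
    """Largest element-wise absolute difference between two row lists."""
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))

err_before_bs1 = max_abs_error(exact, before_fix_bs1)
err_after_bs1 = max_abs_error(exact, after_fix_bs1)
err_after_bs2 = max_abs_error(exact, after_fix_bs2)

print(f"{err_before_bs1:.4f} {err_after_bs1:.4f} {err_after_bs2:.4f}")  # 0.3912 0.2588 0.1478
```

On these four tokens the worst-case error drops with the change and drops further at block size 2, although the third row is still noticeably off.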
    
    opened by jglaser 1
Owner
Phil Wang
Working with Attention. It's all we need