Sinkhorn Transformer - Practical implementation of Sparse Sinkhorn Attention

Overview

Sinkhorn Transformer



This is a reproduction of the work outlined in Sparse Sinkhorn Attention, with additional enhancements.

It includes a parameterized sorting network, using sinkhorn normalization to sample a permutation matrix that matches the most relevant buckets of keys to the buckets of queries.
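As a rough illustration of the sinkhorn normalization mentioned above, the sketch below alternately normalizes the rows and columns of a small score matrix in log space, so the result approaches a doubly stochastic matrix (a relaxed permutation). The helper name `log_sinkhorn`, the example scores, and the iteration count are assumptions for illustration and are not taken from this library.

```python
import torch

def log_sinkhorn(log_scores: torch.Tensor, n_iters: int = 20) -> torch.Tensor:
    """Alternately normalize rows and columns in log space (hypothetical helper)."""
    for _ in range(n_iters):
        # make every row sum to 1 in probability space
        log_scores = log_scores - torch.logsumexp(log_scores, dim=-1, keepdim=True)
        # make every column sum to 1 in probability space
        log_scores = log_scores - torch.logsumexp(log_scores, dim=-2, keepdim=True)
    return log_scores.exp()

# toy bucket-to-bucket relevance scores (made-up values)
scores = torch.tensor([[2.0, 0.0, 0.0],
                       [0.0, 1.0, 0.5],
                       [0.0, 0.5, 1.0]])
soft_perm = log_sinkhorn(scores, n_iters=20)
```

Each row of `soft_perm` can be read as how strongly one query bucket is matched to each key bucket; lowering a temperature on the scores would push it closer to a hard permutation.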

This work also brings in reversible networks and feed forward chunking (concepts introduced in Reformer) for further memory savings.
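Feed forward chunking can be sketched as follows: because the feedforward acts on each position independently, the sequence can be split into chunks, processed one chunk at a time, and concatenated again, which lowers peak activation memory without changing the result. The class name `ChunkedFeedForward` and the layer sizes are assumptions for illustration.

```python
import torch
from torch import nn

class ChunkedFeedForward(nn.Module):
    """Run a position-wise net on sequence chunks (hypothetical wrapper)."""
    def __init__(self, fn: nn.Module, chunks: int, dim: int = 1):
        super().__init__()
        self.fn = fn
        self.chunks = chunks
        self.dim = dim

    def forward(self, x):
        # split along the sequence dimension, apply the net, then stitch back together
        pieces = x.chunk(self.chunks, dim=self.dim)
        return torch.cat([self.fn(p) for p in pieces], dim=self.dim)

torch.manual_seed(0)
ff = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
x = torch.randn(2, 40, 16)

chunked_out = ChunkedFeedForward(ff, chunks=10)(x)
full_out = ff(x)  # same values, computed in one pass
```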

Open In Colab 204k tokens (demonstration purposes)

Install

$ pip install sinkhorn_transformer

Use

A Sinkhorn Transformer based language model

import torch
from sinkhorn_transformer import SinkhornTransformerLM

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    max_seq_len = 8192,
    bucket_size = 128,        # size of the buckets
    causal = False,           # auto-regressive or not
    n_sortcut = 2,            # use sortcut to reduce memory complexity to linear
    n_top_buckets = 2,        # sort specified number of key/value buckets to one query bucket. paper is at 1, defaults to 2
    ff_chunks = 10,           # feedforward chunking, from Reformer paper
    reversible = True,        # make network reversible, from Reformer paper
    emb_dropout = 0.1,        # embedding dropout
    ff_dropout = 0.1,         # feedforward dropout
    attn_dropout = 0.1,       # post attention dropout
    attn_layer_dropout = 0.1, # post attention layer dropout
    layer_dropout = 0.1,      # add layer dropout, from 'Reducing Transformer Depth on Demand' paper
    weight_tie = True,        # tie layer parameters, from Albert paper
    emb_dim = 128,            # embedding factorization, from Albert paper
    dim_head = 64,            # be able to fix the dimension of each head, making it independent of the embedding dimension and the number of heads
    ff_glu = True,            # use GLU in feedforward, from paper 'GLU Variants Improve Transformer'
    n_local_attn_heads = 2,   # replace N heads with local attention, suggested to work well from Routing Transformer paper
    pkm_layers = (4,7),       # specify layers to use product key memory. paper shows 1 or 2 modules near the middle of the transformer is best
    pkm_num_keys = 128,       # defaults to 128, but can be increased to 256 or 512 as memory allows
)

x = torch.randint(0, 20000, (1, 2048))
model(x) # (1, 2048, 20000)

A plain Sinkhorn Transformer, layers of sinkhorn attention

import torch
from sinkhorn_transformer import SinkhornTransformer

model = SinkhornTransformer(
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 128
)

x = torch.randn(1, 2048, 1024)
model(x) # (1, 2048, 1024)

Sinkhorn Encoder / Decoder Transformer

import torch
from sinkhorn_transformer import SinkhornTransformerLM

DE_SEQ_LEN = 4096
EN_SEQ_LEN = 4096

enc = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    heads = 8,
    bucket_size = 128,
    max_seq_len = DE_SEQ_LEN,
    reversible = True,
    return_embeddings = True
).cuda()

dec = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    causal = True,
    bucket_size = 128,
    max_seq_len = EN_SEQ_LEN,
    receives_context = True,
    context_bucket_size = 128,  # context key / values can be bucketed differently
    reversible = True
).cuda()

x = torch.randint(0, 20000, (1, DE_SEQ_LEN)).cuda()
y = torch.randint(0, 20000, (1, EN_SEQ_LEN)).cuda()

x_mask = torch.ones_like(x).bool().cuda()
y_mask = torch.ones_like(y).bool().cuda()

context = enc(x, input_mask=x_mask)
dec(y, context=context, input_mask=y_mask, context_mask=x_mask) # (1, 4096, 20000)

Autopadder

By default the model will complain if given an input that is not a multiple of the bucket size. To avoid having to make the same padding calculations each time, you can use the helper Autopadder class. It will take care of the input_mask for you as well, if given. Contextual key/values and mask are supported as well.

import torch
from sinkhorn_transformer import SinkhornTransformerLM
from sinkhorn_transformer import Autopadder

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    max_seq_len = 2048,
    bucket_size = 128,
    causal = True
)

model = Autopadder(model, pad_left=True) # autopadder will fetch the bucket size and autopad input

x = torch.randint(0, 20000, (1, 1117)) # odd sequence length
model(x) # (1, 1117, 20000)
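For reference, the padding arithmetic that Autopadder spares you can be sketched like this. The helper `pad_to_multiple` is hypothetical and only shows the calculation for a (batch, seq) tensor of token ids; it is not the library's implementation.

```python
import math
import torch
import torch.nn.functional as F

def pad_to_multiple(x, multiple, pad_left=True, value=0):
    """Pad dim 1 of a (batch, seq) tensor up to the next multiple (hypothetical helper)."""
    seq_len = x.shape[1]
    padded_len = math.ceil(seq_len / multiple) * multiple
    remainder = padded_len - seq_len
    # F.pad takes (left, right) amounts for the last dimension
    pad = (remainder, 0) if pad_left else (0, remainder)
    return F.pad(x, pad, value=value), remainder

x = torch.randint(0, 20000, (1, 1117))      # odd sequence length
padded, remainder = pad_to_multiple(x, 128)  # bucket size of 128
```

With a bucket size of 128, a length of 1117 is padded by 35 positions to 1152, and the model output would then be sliced back to the original length.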

Sinkhorn

This repository has diverged from the paper and now uses attention in place of the original sorting net + gumbel sinkhorn sampling. I have not found a noticeable difference in performance yet, and the new scheme allows me to generalize the network to flexible sequence lengths. If you would like to try Sinkhorn, please use the following settings, which only work for non-causal networks.

import torch
from sinkhorn_transformer import SinkhornTransformerLM

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 128,
    max_seq_len = 8192,
    use_simple_sort_net = True, # turn off attention sort net
    sinkhorn_iter = 7,          # number of sinkhorn iterations - default is set at reported best in paper
    n_sortcut = 2,              # use sortcut to reduce complexity to linear time
    temperature = 0.75,         # gumbel temperature - default is set at reported best in paper
    non_permutative = False,    # allow buckets of keys to be sorted to queries more than once
)

x = torch.randint(0, 20000, (1, 8192))
model(x) # (1, 8192, 20000)

Product Key Memory

To see the benefits of using PKM, the learning rate of the values must be set higher than that of the rest of the parameters (1e-2 is recommended).

You can follow the instructions here to set it correctly https://github.com/lucidrains/product-key-memory#learning-rates
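The general idea of giving one set of parameters a higher learning rate can be sketched with plain PyTorch optimizer parameter groups. The toy model below and the choice to match parameters by the name prefix `values` are assumptions for illustration; the linked instructions describe how to select the actual PKM value parameters.

```python
import torch
from torch import nn

class ToyMemoryModel(nn.Module):
    """Stand-in model; `values` mimics a memory value table (hypothetical layout)."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(32, 32)
        self.values = nn.EmbeddingBag(128, 32, mode='sum')

model = ToyMemoryModel()

# split parameters by name: memory values vs. everything else
value_params = [p for n, p in model.named_parameters() if n.startswith('values')]
other_params = [p for n, p in model.named_parameters() if not n.startswith('values')]

optimizer = torch.optim.Adam([
    {'params': other_params},              # falls back to the default lr below
    {'params': value_params, 'lr': 1e-2},  # higher lr for the memory values
], lr=1e-4)
```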

Issues

Decoding and sequence lengths

Sinkhorn, when trained on fixed length sequences, seems to have trouble decoding sequences from scratch, mainly due to the fact that the sorting net has trouble generalizing when the buckets are partially filled with padding tokens.

Fortunately, I think I have found a simple solution. During training, for causal networks, randomly truncate the sequences and force the sorting net to generalize. I have provided a flag (randomly_truncate_sequence) for the AutoregressiveWrapper instance to make this easy.

import torch
from sinkhorn_transformer import SinkhornTransformerLM, AutoregressiveWrapper

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 75,
    max_seq_len = 8192,
    causal = True
)

model = AutoregressiveWrapper(model)

x = torch.randint(0, 20000, (1, 8192))
loss = model(x, return_loss = True, randomly_truncate_sequence = True) # scalar training loss

I am open to suggestions if someone has found a better solution.
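The random truncation idea described above can be sketched independently of the wrapper: during training, keep only a random-length prefix of each batch so the sorting net sees partially filled sequences. The helper `randomly_truncate` and its `min_len` argument are hypothetical, and how the wrapper implements its flag internally is not shown here.

```python
import random
import torch

def randomly_truncate(x, min_len=1):
    """Keep a random-length prefix of a (batch, seq) tensor (hypothetical helper)."""
    seq_len = x.shape[1]
    new_len = random.randint(min_len, seq_len)  # inclusive on both ends
    return x[:, :new_len]

random.seed(0)
x = torch.randint(0, 20000, (1, 8192))
truncated = randomly_truncate(x)
```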

Causal sorting net

There is a potential problem with the causal sorting network: the decision of which past key/value buckets are sorted to a given bucket depends only on that bucket's first token and not the rest (due to the bucketing scheme and the need to prevent leakage of future to past).

I have attempted to alleviate this problem by rotating half the heads to the left by bucket size - 1, thereby promoting the last token to be first. This is also the reason why the AutoregressiveWrapper defaults to left padding during training, to always make sure that the last token in the sequence has a say in what to retrieve.

If anyone has found a cleaner solution, please let me know in the issues.
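The rotation trick described above can be sketched on a toy tensor. This assumes a (batch, heads, seq_len, dim_head) layout and uses `torch.roll`, which wraps around at the sequence ends; the helper name and the wrap-around behaviour are assumptions for illustration, not a description of the library's exact code.

```python
import torch

def rotate_half_heads_left(x, bucket_size):
    """Shift the sequence of the second half of the heads left by bucket_size - 1 (sketch)."""
    heads = x.shape[1]
    keep, rot = x[:, : heads // 2], x[:, heads // 2 :]
    # a negative shift moves elements toward index 0, wrapping around at the end
    rot = torch.roll(rot, shifts=-(bucket_size - 1), dims=2)
    return torch.cat([keep, rot], dim=1)

bucket_size = 4
# two heads, eight positions, each position labelled with its own index
x = torch.arange(8).view(1, 1, 8, 1).expand(1, 2, 8, 1).float()
out = rotate_half_heads_left(x, bucket_size)
```

In the rotated head, position 3 (the last token of the first bucket) now sits at the front of the first bucket, so it becomes the token that decides what that bucket retrieves.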

Alternatives

  1. Routing Transformer - https://github.com/lucidrains/routing-transformer
  2. Reformer - https://github.com/lucidrains/reformer-pytorch

Citations

@misc{tay2020sparse,
    title   = {Sparse Sinkhorn Attention},
    author  = {Yi Tay and Dara Bahri and Liu Yang and Donald Metzler and Da-Cheng Juan},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.11296}
}
@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@misc{lan2019albert,
    title       = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author      = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    year        = {2019},
    url         = {https://arxiv.org/abs/1909.11942}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{roy*2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy* and Mohammad Taghi Saffar* and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://openreview.net/forum?id=B1gjs6EtDr}
}
@inproceedings{fan2020reducing,
    title     ={Reducing Transformer Depth on Demand with Structured Dropout},
    author    ={Angela Fan and Edouard Grave and Armand Joulin},
    booktitle ={International Conference on Learning Representations},
    year      ={2020},
    url       ={https://openreview.net/forum?id=SylO2yStDr}
}
@misc{lample2019large,
    title   = {Large Memory Layers with Product Keys},
    author  = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
    year    = {2019},
    eprint  = {1907.05242},
    archivePrefix = {arXiv}
}
@misc{bhojanapalli2020lowrank,
    title   = {Low-Rank Bottleneck in Multi-head Attention Models},
    author  = {Srinadh Bhojanapalli and Chulhee Yun and Ankit Singh Rawat and Sashank J. Reddi and Sanjiv Kumar},
    year    = {2020},
    eprint  = {2002.07028}
}

Comments
  • Training falling on version 0.0.14 and 0.0.15

    Hi, I am testing model training on the new versions of the repo, and I have some trouble with 0.0.14 and 0.0.15. On 0.0.14, the model always returns nan on the forward pass, and version 0.0.15 leads to a CUDA error:

    RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
    

    Full error listing:

    <ipython-input-7-1329da5363de> in forward(self, inputs, labels)
          7   def forward(self, inputs, labels=None):
          8     loss_mx = labels != -100
    ----> 9     output = self.model(inputs)
         10     output = output[loss_mx].view(-1, tokenizer.vocab_size)
         11     labels = labels[loss_mx].view(-1)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, input_mask)
        376         x = self.to_token_emb(x)
        377         x = self.pos_emb(torch.arange(t, device=device)) + x
    --> 378         x = self.sinkhorn_transformer(x)
        379         return self.to_logits(x)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, input_mask)
        359 
        360     def forward(self, x, input_mask = None):
    --> 361         return self.layers(x)
        362 
        363 class SinkhornTransformerLM(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, **kwargs)
        330     def forward(self, x, **kwargs):
        331         x = torch.cat([x, x], dim=-1)
    --> 332         x = self.layers(x, **kwargs)
        333         return torch.stack(x.chunk(2, dim=-1)).sum(dim=0)
        334 
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, x, arg_route, **kwargs)
        128         block_kwargs = {'f_args': f_args, 'g_args': g_args}
        129 
    --> 130         return _ReversibleFunction.apply(x, blocks, block_kwargs)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(ctx, x, blocks, kwargs)
         98         ctx.kwargs = kwargs
         99         for block in blocks:
    --> 100             x = block(x, **kwargs)
        101         ctx.y = x.detach()
        102         ctx.blocks = blocks
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, x, f_args, g_args)
         51         with torch.no_grad():
         52             y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
    ---> 53             y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
         54 
         55         return torch.cat([y1, y2], dim=2)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, record_rng, set_rng, *args, **kwargs)
         25 
         26         if not set_rng:
    ---> 27             return self.net(*args, **kwargs)
         28 
         29         rng_devices = []
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x)
         91     def forward(self, x):
         92         chunks = x.chunk(self.chunks, dim = self.dim)
    ---> 93         return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
         94 
         95 class FeedForward(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in <listcomp>(.0)
         91     def forward(self, x):
         92         chunks = x.chunk(self.chunks, dim = self.dim)
    ---> 93         return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
         94 
         95 class FeedForward(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, **kwargs)
        112     def forward(self, x, **kwargs):
        113         x = self.norm(x)
    --> 114         return self.fn(x, **kwargs)
        115 
        116 class SortNet(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x)
        103 
        104     def forward(self, x):
    --> 105         return self.net(x)
        106 
        107 class PreNorm(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
         98     def forward(self, input):
         99         for module in self:
    --> 100             input = module(input)
        101         return input
        102 
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
         85 
         86     def forward(self, input):
    ---> 87         return F.linear(input, self.weight, self.bias)
         88 
         89     def extra_repr(self):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
       1591         ret = torch.addmm(bias, input, weight.t())
       1592     else:
    -> 1593         output = input.matmul(weight.t())
       1594         if bias is not None:
       1595             output += bias
    
    RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
    

    Also, version 0.0.11 (and all other versions from 0.0.8) works stably.

    opened by blizda 30
  • generation problem in a toy task

    Here is the full script for my toy task (x -> xx like "abc" to "abcabc")

    from sinkhorn_transformer import SinkhornTransformerLM
    from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper
    
    import random
    import tqdm
    import gzip
    import numpy as np
    import torch
    import torch.optim as optim
    from torch import nn
    from torch.nn import functional as F
    from torch.utils.data import DataLoader, Dataset
    
    # constants
    
    NUM_BATCHES = int(1e5)
    BATCH_SIZE = 4
    GRADIENT_ACCUMULATE_EVERY = 4
    LEARNING_RATE = 1e-4
    VALIDATE_EVERY  = 100
    GENERATE_EVERY  = 100
    ENC_SEQ_LEN=16
    DEC_SEQ_LEN=40
    NUM_TOKENS = 256 + 2
    BUCKET_SIZE = 8
    
    # helpers
    
    def top_k(logits, thres = 0.9):
        k = int((1 - thres) * logits.shape[-1])
        val, ind = torch.topk(logits, k)
        probs = torch.full_like(logits, float('-inf'))
        probs.scatter_(1, ind, val)
        return probs
    
    
    def cycle():
        while True:
            source = torch.randint(2, 258, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
    
            target = torch.cat((source, source), 1)
            prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
            target = torch.cat((prefix, target), axis=1)
    
            x_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool().cuda()
            y_mask = torch.ones(BATCH_SIZE, target.shape[1]).bool().cuda()
    
    
            yield (source, target, x_mask, y_mask)
    
    # instantiate model
    
    class MySinkhornTransformer(nn.Module):
        def __init__(self, num_tokens, dim, depth, heads, bucket_size, enc_max_seq_len, dec_max_seq_len):
            super().__init__()
            
            self.pad_token = 0
            self.sos_token = 1
    
            self.enc = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads, bucket_size=bucket_size, max_seq_len=enc_max_seq_len,
                                             reversible=True, return_embeddings=True)
            self.dec = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads, causal=True, bucket_size=bucket_size, max_seq_len=dec_max_seq_len, 
                                             receives_context=True, context_bucket_size=bucket_size, reversible=True)
            self.dec = AutoregressiveWrapper(self.dec, pad_value=num_tokens-2)
        
        @torch.no_grad()
        def generate(self, x, x_mask):
            context = self.enc(x, input_mask=x_mask)
            start_tokens = (torch.ones((x.shape[0],1)) * self.sos_token).long().cuda()
    
            return self.dec.generate(start_tokens, 32, context=context, context_mask=x_mask)
    
        def forward(self, x, y, x_mask, y_mask, return_loss):
            context = self.enc(x, input_mask=x_mask)
            return self.dec(y, context=context, input_mask=y_mask, context_mask=x_mask, return_loss=True)
    
    
    model = MySinkhornTransformer(num_tokens=NUM_TOKENS, dim=512, depth=1, heads=1, bucket_size=BUCKET_SIZE, enc_max_seq_len=ENC_SEQ_LEN, dec_max_seq_len=DEC_SEQ_LEN)
    model.cuda()
    # optimizer
    
    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    # training
    
    for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
        model.train()
    
        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            source, target, x_mask, y_mask = next(cycle())
            loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
            loss.backward()
    
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optim.step()
        optim.zero_grad()
    
        if i % VALIDATE_EVERY == 0:
            model.eval()
            with torch.no_grad():
                source, target, x_mask, y_mask = next(cycle())
                loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
                print(f'validation loss: {loss.item()}')
    
        if i % GENERATE_EVERY == 0:
            model.eval()
            
            source, target, x_mask, y_mask = next(cycle())
            
            sample = model.generate(x=source, x_mask=x_mask)
            print("input:  ", source[0])
            print("model output:  ", sample[0])
    

    After a few steps the loss becomes practically zero. I checked the logits during training and they seem to be OK, but during the generation phase the model outputs this pattern: "x,x,x,x,x,y,y,y,y,y", like "aaaabbbb" instead of "abcdabcd". I was wondering what the underlying issue might be. Do you have any idea?

    opened by py4 18
  • Sortcut variant and bucket size

    From the paper it looks to me that, for the sortcut variant, the queries are allowed to attend to all the key buckets post truncation in the non-causal case. If this is correct, won't the output be the same no matter what the bucket size is for the query?

    Based on my understanding of the paper, the authors only report results selecting 2 top key/value blocks with bucket sizes 8,16,32. They do not mention having different bucket sizes for query and key/value which I found in this repo.

    opened by jsmith1915 16
  • A noobish question about training...

    @lucidrains

    First of all, everything seems to be working now in Google Colab so thank you very much for fixing it.

    I have a quick question about training if you do not mind...

    I get very high loss results. Here is the example:

    training: 1%| | 509/100000 [1:55:22<387:23:57, 14.02s/it]training loss: 2.495532512664795

    Is this normal and I simply need to train more? Or does it mean that there is a problem somewhere?

    2.49 in 2 hours is way too much IMHO. I am pretty sure I am not doing it right, so your advice would be really appreciated.

    Thank you.

    P.S. I am running your wiki8 example in Google Colab Pro.

    opened by asigalov61 9
  • several question about implementation

    I am currently trying to implement Sparse Sinkhorn attention (non-causal, self-attention, pretraining for an MLM task) with TensorFlow, and I would appreciate it if you could answer several questions about your code.

    1. I did not quite understand from the paper what happens to the queries and keys in SortCut when we cut off most of the blocks. In your code it seems like you broadcast keys/queries like this: [number_blocks, block_len, size_per_head] -> [:n_sortcut, block_len, size_per_head] -> [1, n_sortcut * block_len, size_per_head] -> [number_blocks, n_sortcut * block_len, size_per_head]. I understand that PyTorch handles expand_dim automatically and memory is preserved, but I do not understand how the memory consumption stays linear: the resulting attention scores (dots) are still [number_blocks, block_len, n_sortcut * block_len], which is good but quadratic. Am I missing something?

    2. In the case of SortCut, what is the meaning of softmax(A)*Q? For example, for full attention the outputs are vectors that are weighted averages of other vectors; for local block attention the outputs are vectors that are weighted averages of other vectors in the block. How would you interpret the output of a SortCut attention layer?

    3. In your code you concatenate the regular keys and values with the permuted/sorted keys and values, but in the paper (part 3.2) it seems like simple addition. Why is it different from the paper? Actually, I can understand concatenation more than summation (because I am confused about which attention masks to use in this case: from the regular block or the permuted one).

    4. It seems like you are using concatenated queries and keys as the input to the SortNet. Maybe I completely missed the point, but I did not find anything about this in the article: there, the input sequence is used for the SortNet instead of its projections. In fact, for some reason, I cannot make it converge if I use queries and keys (maybe for completely unrelated reasons). Have you tried to compare the input sequence vs concat(q,k)?

    opened by w4-magnes 9
  • TypeError exception in AxialPositionalEncoding when using DataParallel

    Hello,

    I want to run SinkhornTransformerLM using multiple GPUs, so I'm wrapping the model into torch.nn.DataParallel. However, when I do this, I get an exception:

    Traceback (most recent call last):
      File "script.py", line 27, in <module>
        model(x)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 155, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 165, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
        output.reraise()
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/_utils.py", line 395, in reraise
        raise self.exc_type(msg)
    TypeError: Caught TypeError in replica 0 on device 0.
    Original Traceback (most recent call last):
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
        output = module(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 792, in forward
        x = self.axial_pos_emb(x) + x
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 243, in forward
        return pos_emb[:, :t]
    TypeError: 'int' object is not subscriptable
    

    Looking at the code, it would seem that self.weights does not get populated. To reproduce this error, I took the first example in README.md and changed

    model(x) # (1, 2048, 20000)
    

    to

    model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))).to('cuda')
    model(x)
    
    opened by kl0211 8
  • Do I need to pad?

    Excuse me for the noob question, I have sequences of different lengths and I use:

    model = AutoregressiveWrapper(model, ignore_index=PAD_ID, pad_value=PAD_ID)
    

    Do I need to pad as in this example or not?

    def __getitem__(self, index):
        seq_tokens = self.examples[index]
        input_ids = torch.tensor(seq_tokens, dtype=torch.long)
        input_ids = F.pad(input_ids, (seq_len - len(input_ids), 0), value=self.pad_value)
        return input_ids
    
    opened by timsoraro 8
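
For reference, the tuple passed to `F.pad` in the snippet above is `(left, right)` for the last dimension, so `(seq_len - len(input_ids), 0)` pads on the left. Below is a minimal pure-Python sketch of that behavior. The helpers `left_pad` and `padding_mask` and the value `PAD_ID = 0` are hypothetical, and the sketch does not settle whether padding is required by the wrapper.

```python
PAD_ID = 0  # hypothetical padding token id


def left_pad(tokens, seq_len, pad_value=PAD_ID):
    """Left-pad a list of token ids to seq_len, like F.pad(x, (n, 0), value=pad_value)."""
    if len(tokens) > seq_len:
        raise ValueError(f"sequence length {len(tokens)} exceeds seq_len {seq_len}")
    return [pad_value] * (seq_len - len(tokens)) + list(tokens)


def padding_mask(padded, pad_value=PAD_ID):
    """True where a real token sits, False at padding positions."""
    return [tok != pad_value for tok in padded]


padded = left_pad([5, 6, 7], seq_len=6)
mask = padding_mask(padded)
print(padded)
print(mask)
```

The mask is only meaningful if the padding id never occurs as a real token.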
  • Performance of the different attention variants in the repo

    Really great work!!

    I have a few questions.

    - Does the attention sort net work better than, or on par with, the simple sort net?
    - Does the attention sort net also work well for the sortcut variant?
    - How does the sortcut variant perform compared to vanilla sparse Sinkhorn attention?

    It would be very helpful if you could share some plots/numbers from your experiments comparing the performance of the different variants such as attention sort net, routing based attention etc.

    opened by jsmith1915 7
  • Some questions about dropout

    Hi again @lucidrains, I have some quick questions about dropout in the Sinkhorn Transformer. My Linformer implementation (which, as you know, is based on this repo) was overfitting my dataset, so I wanted to ask whether some of the design choices here were intentional:

    1. In the original Transformer, dropout was performed after each sublayer, before the residual connection. I noticed that you only have this after the SinkhornSelfAttention class, but not after the FeedForward class. Is this intentional?
    2. Speaking of the FeedForward class, you insert dropout after the first linear layer. I couldn't find this anywhere in the literature. Were you able to find a reference explaining why this is effective? I put it into my implementation and it seems to help, but I don't know where the idea came from.
    3. On a similar note, do you know why the dots tensor in the self-attention classes is dropped out? Again, I put it in my Linformer and it seems to work, but I can't find a reference for it in the literature.
    4. Finally, the original Transformer also applied dropout to the input tokens, like so (from the SinkhornTransformerLM class):
        def forward(self, x, **kwargs):
            _, t, device = *x.shape, x.device
            assert t <= self.max_seq_len, f'sequence length {t} is greater than maximum sequence length {self.max_seq_len}'
    
            x = self.to_token_emb(x)
            x = self.axial_pos_emb(x) + x
            # Dropout would go here
            x = self.sinkhorn_transformer(x, **kwargs)
            return self.to_logits(x)
    

    Should they be dropped out here as well?

    I have now updated my repo so that all four of these dropout options exist. I'll let you know if this helps with overfitting.

    Thank you for your time!

    opened by tatp22 6
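
The placements discussed in this issue can be sketched in a few lines of PyTorch. This is an illustrative layout under stated assumptions, not the library's code: the class names `FeedForward` and `ResidualSublayer`, the GELU activation, and the 0.1 rates are all chosen for the example. Dropout on the attention dots (point 3) is omitted to keep the sketch short.

```python
import torch
from torch import nn


class FeedForward(nn.Module):
    """Hypothetical feedforward block with dropout between the two linear layers."""

    def __init__(self, dim, mult=4, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Dropout(dropout),  # point 2: after the first linear layer and activation
            nn.Linear(dim * mult, dim),
        )

    def forward(self, x):
        return self.net(x)


class ResidualSublayer(nn.Module):
    """Hypothetical wrapper: dropout on the sublayer output, then the residual add."""

    def __init__(self, fn, dropout=0.1):
        super().__init__()
        self.fn = fn
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # point 1: dropout is applied before the residual connection
        return x + self.dropout(self.fn(x))


dim = 16
emb_dropout = nn.Dropout(0.1)  # point 4: applied to token + positional embeddings
block = ResidualSublayer(FeedForward(dim))

x = torch.randn(2, 8, dim)  # stands in for embedded tokens
out = block(emb_dropout(x))
print(out.shape)
```

Calling `block.eval()` turns every `nn.Dropout` into the identity, so the same module can be used unchanged at inference time.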
  • Training crashes due to inplace operation

    Hi there, I am trying to train this model on my custom dataset. Training starts, but after many iterations it crashes with this error:

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2, 128, 64]], which is output 0 of ViewBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
    

    After enabling anomaly detection, here is the error:

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2, 128, 64]], which is output 0 of ViewBackward, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
    

    The traceback is not useful, as it only shows the error at the line loss.backward() and inside the torch.autograd module. Can you help me find where the issue might be?

    Thanks

    opened by NaxAlpha 6
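
For readers unfamiliar with this class of error, here is a minimal, self-contained reproduction in PyTorch. It is unrelated to the reporter's model and makes no claim about the cause in this issue. It only shows how an in-place write to a tensor that autograd saved for the backward pass triggers the same message, and how the out-of-place form avoids it.

```python
import torch


def backward_after_inplace_edit():
    """Return the error raised when a saved tensor is modified in place."""
    a = torch.ones(3, requires_grad=True)
    b = a.exp()   # exp saves its output for the backward pass
    b.add_(1)     # in-place edit bumps b's version counter
    try:
        b.sum().backward()
    except RuntimeError as err:
        return str(err)
    return None


def backward_without_inplace_edit():
    """Same computation with an out-of-place add; backward succeeds."""
    a = torch.ones(3, requires_grad=True)
    b = a.exp()
    c = b + 1     # creates a new tensor, b is left untouched
    c.sum().backward()
    return a.grad


error_message = backward_after_inplace_edit()
grad = backward_without_inplace_edit()
print(error_message)
print(grad)
```

Searching a model for methods ending in an underscore (such as `add_` or `masked_fill_`) and for augmented assignments like `+=` on tensors is a common way to narrow down candidates.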
  • Positional Embedding

    Hi! Found out that SinkhornTransformerLM uses two positional encodings simultaneously:

    https://github.com/lucidrains/sinkhorn-transformer/blob/2b65e24085562a4e308251398007e2ca9b86d7cc/sinkhorn_transformer/sinkhorn_transformer.py#L778-L779

    https://github.com/lucidrains/sinkhorn-transformer/blob/2b65e24085562a4e308251398007e2ca9b86d7cc/sinkhorn_transformer/sinkhorn_transformer.py#L792-L793

    I guess pos_emb can be removed, since it introduces memory overhead and defeats the purpose of Axial Positional Encoding, which is designed to reduce the number of positional encoding parameters.

    Is there a special reasoning behind that?

    opened by ilya16 4
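
To make the parameter argument concrete, the sketch below counts parameters for a full learned positional table versus an axial factorization. The 64 x 128 factorization of 8192 positions and the per-axis dimensions are illustrative assumptions, not values read from the repository, and the function names are hypothetical.

```python
def absolute_pos_params(max_seq_len, dim):
    """One learned vector of size dim per position."""
    return max_seq_len * dim


def axial_pos_params(axial_shape, axial_dims):
    """One small table per axis: rows along that axis times its embedding width."""
    return sum(n * d for n, d in zip(axial_shape, axial_dims))


max_seq_len, dim = 8192, 1024
axial_shape = (64, 128)  # assumed factorization, 64 * 128 == 8192
assert axial_shape[0] * axial_shape[1] == max_seq_len

full = absolute_pos_params(max_seq_len, dim)
# Summed variant: each axis table has the full model width.
summed = axial_pos_params(axial_shape, (dim, dim))
# Concatenated variant: the axis widths add up to the model width.
concatenated = axial_pos_params(axial_shape, (512, 512))

print(full, summed, concatenated)
```

Under these assumptions, keeping a full table alongside the axial one means the total is dominated by the full table.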
  • A wrapper of SinkhornTransformerEncDec

    Could you please code up a wrapper that removes much of the manual work involved in writing a generic SinkhornTransformer encoder/decoder architecture?

    Thanks a lot! halexan

    opened by halexan 0
Owner
Phil Wang
Working with Attention. It's all we need