Sinkhorn Transformer - Practical implementation of Sparse Sinkhorn Attention

Overview


This is a reproduction of the work outlined in Sparse Sinkhorn Attention, with additional enhancements.

It includes a parameterized sorting network, using sinkhorn normalization to sample a permutation matrix that matches the most relevant buckets of keys to the buckets of queries.
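As a rough illustration of the Sinkhorn normalization step, here is a minimal PyTorch sketch. It is not this library's code. It assumes each bucket is summarized by summing its tokens, and it uses a hypothetical random projection `W` as a stand-in for the learned sorting network. Repeated row and column normalization in log space turns the bucket-to-bucket scores into an approximately doubly stochastic matrix `P`, which then soft-permutes the key buckets.

```python
import torch

def sinkhorn(log_alpha: torch.Tensor, n_iters: int = 7) -> torch.Tensor:
    """Alternately normalize rows and columns in log space.

    log_alpha: (..., n, n) unnormalized log-scores between query buckets (rows)
    and key buckets (columns). Returns an approximately doubly stochastic matrix.
    """
    for _ in range(n_iters):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-1, keepdim=True)  # rows sum to 1
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-2, keepdim=True)  # columns sum to 1
    return log_alpha.exp()

# toy example: a batch of 2 sequences of length 32, split into 4 buckets of 8, dim 16
torch.manual_seed(0)
bucket_size = 8
x = torch.randn(2, 32, 16)
buckets = x.reshape(2, 32 // bucket_size, bucket_size, 16)  # (2, 4, 8, 16)
summaries = buckets.sum(dim=2)                              # (2, 4, 16) pooled bucket representations
W = torch.randn(16, 4) * 0.1                                # hypothetical stand-in for the sort net
scores = summaries @ W                                      # (2, 4, 4): query bucket i vs key bucket j
P = sinkhorn(scores, n_iters=50)

# soft-permute the key buckets: each query bucket receives a mixture of key buckets
sorted_buckets = torch.einsum('bij,bjnd->bind', P, buckets)  # (2, 4, 8, 16)
```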

This work also brings in reversible networks and feed forward chunking (concepts introduced in the Reformer paper) to bring about further memory savings.
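Feedforward chunking and reversibility can each be sketched in a few lines. This is a minimal illustration under assumptions, not the library's implementation. `Chunk` splits the input along the sequence dimension and concatenates the per-chunk results, which is valid because the feedforward is position-wise. `ReversibleBlock` follows the coupling y1 = x1 + f(x2), y2 = x2 + g(y1), whose inputs can be recomputed exactly from its outputs. The real memory savings come from a custom autograd function that recomputes activations during the backward pass, which is omitted here.

```python
import torch
from torch import nn

class Chunk(nn.Module):
    """Apply fn to chunks of the input along `dim` and concatenate the results.
    For position-wise modules this matches the unchunked call while lowering peak memory."""
    def __init__(self, chunks, fn, dim=1):
        super().__init__()
        self.chunks, self.fn, self.dim = chunks, fn, dim

    def forward(self, x):
        return torch.cat([self.fn(c) for c in x.chunk(self.chunks, dim=self.dim)], dim=self.dim)

class ReversibleBlock(nn.Module):
    """y1 = x1 + f(x2), y2 = x2 + g(y1); the inverse recovers (x1, x2) from (y1, y2)."""
    def __init__(self, f, g):
        super().__init__()
        self.f, self.g = f, g

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        y1 = x1 + self.f(x2)
        y2 = x2 + self.g(y1)
        return torch.cat([y1, y2], dim=-1)

    def inverse(self, y):
        y1, y2 = y.chunk(2, dim=-1)
        x2 = y2 - self.g(y1)
        x1 = y1 - self.f(x2)
        return torch.cat([x1, x2], dim=-1)

torch.manual_seed(0)
dim = 16
ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
x = torch.randn(2, 64, dim)

with torch.no_grad():
    chunked = Chunk(8, ff, dim=1)(x)   # feedforward applied to 8 sequence chunks
    full = ff(x)                       # same feedforward applied in one go
    block = ReversibleBlock(f=nn.Linear(dim, dim), g=Chunk(8, ff, dim=1))
    x_in = torch.cat([x, x], dim=-1)   # duplicate the input into two streams
    y = block(x_in)
    x_rec = block.inverse(y)           # recompute the input from the output
```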

Open In Colab: 204k tokens (demonstration purposes)

Install

$ pip install sinkhorn_transformer

Use

A Sinkhorn Transformer based language model

import torch
from sinkhorn_transformer import SinkhornTransformerLM

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    max_seq_len = 8192,
    bucket_size = 128,        # size of the buckets
    causal = False,           # auto-regressive or not
    n_sortcut = 2,            # use sortcut to reduce memory complexity to linear
    n_top_buckets = 2,        # sort specified number of key/value buckets to one query bucket. paper is at 1, defaults to 2
    ff_chunks = 10,           # feedforward chunking, from Reformer paper
    reversible = True,        # make network reversible, from Reformer paper
    emb_dropout = 0.1,        # embedding dropout
    ff_dropout = 0.1,         # feedforward dropout
    attn_dropout = 0.1,       # post attention dropout
    attn_layer_dropout = 0.1, # post attention layer dropout
    layer_dropout = 0.1,      # add layer dropout, from 'Reducing Transformer Depth on Demand' paper
    weight_tie = True,        # tie layer parameters, from Albert paper
    emb_dim = 128,            # embedding factorization, from Albert paper
    dim_head = 64,            # be able to fix the dimension of each head, making it independent of the embedding dimension and the number of heads
    ff_glu = True,            # use GLU in feedforward, from paper 'GLU Variants Improve Transformer'
    n_local_attn_heads = 2,   # replace N heads with local attention, suggested to work well from Routing Transformer paper
    pkm_layers = (4,7),       # specify layers to use product key memory. paper shows 1 or 2 modules near the middle of the transformer is best
    pkm_num_keys = 128,       # defaults to 128, but can be increased to 256 or 512 as memory allows
)

x = torch.randint(0, 20000, (1, 2048))
model(x) # (1, 2048, 20000)
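To see why sortcut keeps memory linear, here is a simplified sketch of the SortCut idea. It was written for this README and is not taken from the library. Every query attends only to the first `n_sortcut` key/value buckets after sorting, so the score tensor has shape (batch, seq_len, n_sortcut * bucket_size) and grows linearly with sequence length for a fixed bucket size. The bucket order `perm` is a hypothetical stand-in (a random permutation) for the sorting network's output, and masking, multiple heads and local attention are left out.

```python
import torch

def sortcut_attention(q, k, v, perm, bucket_size, n_sortcut):
    """q, k, v: (batch, seq_len, dim). perm: (batch, num_buckets) key-bucket order.
    Every query attends only to the first n_sortcut key/value buckets in that order."""
    b, n, d = k.shape
    nb = n // bucket_size
    kb = k.reshape(b, nb, bucket_size, d)
    vb = v.reshape(b, nb, bucket_size, d)
    idx = perm[:, :n_sortcut]                                        # (b, n_sortcut)
    gather_idx = idx[:, :, None, None].expand(-1, -1, bucket_size, d)
    k_cut = kb.gather(1, gather_idx).reshape(b, n_sortcut * bucket_size, d)
    v_cut = vb.gather(1, gather_idx).reshape(b, n_sortcut * bucket_size, d)
    dots = q @ k_cut.transpose(-1, -2) * d ** -0.5                   # (b, n, n_sortcut * bucket_size)
    return dots.softmax(dim=-1) @ v_cut, dots

torch.manual_seed(0)
q, k, v = (torch.randn(1, 1024, 64) for _ in range(3))
perm = torch.randperm(1024 // 128).unsqueeze(0)                      # hypothetical sort-net output
out, dots = sortcut_attention(q, k, v, perm, bucket_size=128, n_sortcut=2)
```

Doubling the sequence length to 2048 with the same settings doubles the first score dimension, while the last one stays at 2 * 128 = 256.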

A plain Sinkhorn Transformer, layers of sinkhorn attention

import torch
from sinkhorn_transformer import SinkhornTransformer

model = SinkhornTransformer(
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 128
)

x = torch.randn(1, 2048, 1024)
model(x) # (1, 2048, 1024)

Sinkhorn Encoder / Decoder Transformer

import torch
from sinkhorn_transformer import SinkhornTransformerLM

DE_SEQ_LEN = 4096
EN_SEQ_LEN = 4096

enc = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    heads = 8,
    bucket_size = 128,
    max_seq_len = DE_SEQ_LEN,
    reversible = True,
    return_embeddings = True
).cuda()

dec = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    causal = True,
    bucket_size = 128,
    max_seq_len = EN_SEQ_LEN,
    receives_context = True,
    context_bucket_size = 128,  # context key / values can be bucketed differently
    reversible = True
).cuda()

x = torch.randint(0, 20000, (1, DE_SEQ_LEN)).cuda()
y = torch.randint(0, 20000, (1, EN_SEQ_LEN)).cuda()

x_mask = torch.ones_like(x).bool().cuda()
y_mask = torch.ones_like(y).bool().cuda()

context = enc(x, input_mask=x_mask)
dec(y, context=context, input_mask=y_mask, context_mask=x_mask) # (1, 4096, 20000)

Autopadder

By default, the model will raise an error if given an input whose length is not a multiple of the bucket size. To avoid having to make the same padding calculations each time, you can use the helper Autopadder class. It will also take care of the input_mask for you, if given. Contextual keys/values and their mask are supported as well.

import torch
from sinkhorn_transformer import SinkhornTransformerLM
from sinkhorn_transformer import Autopadder

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    max_seq_len = 2048,
    bucket_size = 128,
    causal = True
)

model = Autopadder(model, pad_left=True) # autopadder will fetch the bucket size and autopad input

x = torch.randint(0, 20000, (1, 1117)) # odd sequence length
model(x) # (1, 1117, 20000)
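The padding arithmetic that Autopadder saves you from looks roughly like the sketch below. This is an assumption about the general approach, not the class's actual code, and `pad_to_multiple` is a hypothetical helper. It rounds the length up to the next multiple of the bucket size, pads on the left, and builds a boolean mask marking the real tokens. For the 1117-token example above with a bucket size of 128, that means 35 padding tokens and a padded length of 1152.

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(x, bucket_size, pad_value=0, pad_left=True):
    """x: (batch, seq_len) token ids. Pads seq_len up to the next multiple of bucket_size.
    Returns (padded tokens, boolean mask with True for real tokens, number of pad tokens)."""
    seq_len = x.shape[-1]
    pad = (-seq_len) % bucket_size                 # 0 when already a multiple
    padding = (pad, 0) if pad_left else (0, pad)
    padded = F.pad(x, padding, value=pad_value)
    mask = torch.ones(padded.shape, dtype=torch.bool)
    if pad_left:
        mask[..., :pad] = False
    else:
        mask[..., seq_len:] = False
    return padded, mask, pad

x = torch.randint(0, 20000, (1, 1117))            # odd sequence length, as above
padded, mask, pad = pad_to_multiple(x, bucket_size=128)
# after the forward pass, strip the padding again, e.g. logits[:, pad:] when left padding
```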

Sinkhorn

This repository has diverged from the paper and now uses attention in place of the original sorting net + Gumbel-Sinkhorn sampling. I have not found a noticeable difference in performance yet, and the new scheme allows me to generalize the network to flexible sequence lengths. If you would like to try Sinkhorn, please use the following settings, which only work for non-causal networks.

import torch
from sinkhorn_transformer import SinkhornTransformerLM

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 128,
    max_seq_len = 8192,
    use_simple_sort_net = True, # turn off attention sort net
    sinkhorn_iter = 7,          # number of sinkhorn iterations - default is set at reported best in paper
    n_sortcut = 2,              # use sortcut to reduce complexity to linear time
    temperature = 0.75,         # gumbel temperature - default is set at reported best in paper
    non_permutative = False,    # allow buckets of keys to be sorted to queries more than once
)

x = torch.randint(0, 20000, (1, 8192))
model(x) # (1, 8192, 20000)
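For reference, the Gumbel-Sinkhorn sampling that these settings re-enable can be sketched as follows. This is a minimal sketch assuming the standard formulation: add Gumbel noise to the sort scores, divide by the temperature, then alternate row and column normalization. `gumbel_sinkhorn` is a hypothetical helper, and its defaults simply mirror `temperature = 0.75` and `sinkhorn_iter = 7` above. Lower temperatures push the result closer to a hard permutation matrix.

```python
import torch

def gumbel_sinkhorn(log_alpha, temperature=0.75, n_iters=7, add_noise=True):
    """log_alpha: (..., n, n) sort scores between query buckets (rows) and key buckets (columns)."""
    if add_noise:
        # Gumbel(0, 1) noise is -log(-log(u)); keep u away from 0 and 1 to avoid infinities
        u = torch.rand_like(log_alpha).clamp(1e-6, 1 - 1e-6)
        log_alpha = log_alpha - torch.log(-torch.log(u))
    log_alpha = log_alpha / temperature
    for _ in range(n_iters):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-1, keepdim=True)  # normalize rows
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-2, keepdim=True)  # normalize columns
    return log_alpha.exp()

torch.manual_seed(0)
scores = torch.randn(2, 8, 8)                          # e.g. 8 buckets per sequence, batch of 2
soft = gumbel_sinkhorn(scores)                         # temperature 0.75, 7 iterations
sharper = gumbel_sinkhorn(scores, temperature=0.1)     # closer to a hard permutation
choice = sharper.argmax(dim=-1)                        # key bucket picked per query bucket (not guaranteed unique)
```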

Product Key Memory

To see the benefits of using PKM, the learning rate of the values must be set higher than that of the rest of the parameters (1e-2 is recommended).

You can follow the instructions here to set it correctly: https://github.com/lucidrains/product-key-memory#learning-rates
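One general way to give the PKM values their own learning rate is through PyTorch optimizer parameter groups. The sketch below is illustrative only. `ToyPKM` is a hypothetical stand-in, and it assumes the value table is an `nn.EmbeddingBag`, which may not match how the product-key-memory package structures its modules. The linked instructions describe the supported way to select those parameters.

```python
import torch
from torch import nn

class ToyPKM(nn.Module):
    # hypothetical stand-in: the value table is ASSUMED to be an nn.EmbeddingBag named `values`
    def __init__(self, dim=32, num_keys=16):
        super().__init__()
        self.keys = nn.Parameter(torch.randn(num_keys, dim))
        self.values = nn.EmbeddingBag(num_keys ** 2, dim, mode='sum')

model = nn.Sequential(nn.Linear(32, 32), ToyPKM())

# split parameters: the (assumed) PKM value tables get their own, higher learning rate
value_params = [p for m in model.modules() if isinstance(m, nn.EmbeddingBag) for p in m.parameters()]
value_ids = {id(p) for p in value_params}
rest_params = [p for p in model.parameters() if id(p) not in value_ids]

optimizer = torch.optim.Adam([
    {'params': rest_params},               # falls back to the default lr below
    {'params': value_params, 'lr': 1e-2},  # recommended learning rate for PKM values
], lr=1e-4)
```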

Issues

Decoding and sequence lengths

Sinkhorn, when trained on fixed-length sequences, seems to have trouble decoding sequences from scratch, mainly because the sorting net has trouble generalizing when buckets are partially filled with padding tokens.

Fortunately, I think I have found a simple solution: during training, for causal networks, randomly truncate the sequences to force the sorting net to generalize. I have provided a flag (randomly_truncate_sequence) on the AutoregressiveWrapper to make this easy.

import torch
from sinkhorn_transformer import SinkhornTransformerLM, AutoregressiveWrapper

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 75,
    max_seq_len = 8192,
    causal = True
)

model = AutoregressiveWrapper(model)

x = torch.randint(0, 20000, (1, 8192))
loss = model(x, return_loss = True, randomly_truncate_sequence = True) # scalar loss
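The idea behind random truncation can be sketched as below. This is an assumption about the approach, not the wrapper's code, and `randomly_truncate` is a hypothetical helper. It keeps a prefix of random length, so that once the sequence is padded to a multiple of the bucket size the last bucket is usually only partially filled, which is the situation the sorting net has to handle when decoding.

```python
import torch

def randomly_truncate(seq, min_len=2, generator=None):
    """seq: (batch, seq_len) token ids. Returns a prefix whose length is drawn uniformly
    from [min_len, seq_len]."""
    seq_len = seq.shape[-1]
    keep = int(torch.randint(min_len, seq_len + 1, (1,), generator=generator))
    return seq[:, :keep]

g = torch.Generator().manual_seed(0)
x = torch.randint(0, 20000, (1, 8192))
x_trunc = randomly_truncate(x, generator=g)

bucket_size = 75
filled = x_trunc.shape[-1] % bucket_size  # tokens in the last bucket; 0 means that bucket is full
```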

I am open to suggestions if someone has found a better solution.

Causal sorting net

There is a potential problem with the causal sorting network: the decision of which past key/value buckets get sorted to a given bucket depends only on that bucket's first token and not on the rest (a consequence of the bucketing scheme and of preventing leakage from future to past).

I have attempted to alleviate this problem by rotating half the heads to the left by bucket size - 1, thereby promoting the last token of each bucket to be first. This is also why the AutoregressiveWrapper defaults to left padding during training: it ensures that the last token in the sequence always has a say in what to retrieve.

If anyone has found a cleaner solution, please let me know in the issues.

Alternatives

  1. Routing Transformer - https://github.com/lucidrains/routing-transformer
  2. Reformer - https://github.com/lucidrains/reformer-pytorch

Citations

@misc{tay2020sparse,
    title   = {Sparse Sinkhorn Attention},
    author  = {Yi Tay and Dara Bahri and Liu Yang and Donald Metzler and Da-Cheng Juan},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.11296}
}
@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@misc{lan2019albert,
    title       = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author      = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    year        = {2019},
    url         = {https://arxiv.org/abs/1909.11942}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{roy*2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy* and Mohammad Taghi Saffar* and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://openreview.net/forum?id=B1gjs6EtDr}
}
@inproceedings{fan2020reducing,
    title     = {Reducing Transformer Depth on Demand with Structured Dropout},
    author    = {Angela Fan and Edouard Grave and Armand Joulin},
    booktitle = {International Conference on Learning Representations},
    year      = {2020},
    url       = {https://openreview.net/forum?id=SylO2yStDr}
}
@misc{lample2019large,
    title   = {Large Memory Layers with Product Keys},
    author  = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
    year    = {2019},
    eprint  = {1907.05242},
    archivePrefix = {arXiv}
}
@misc{bhojanapalli2020lowrank,
    title   = {Low-Rank Bottleneck in Multi-head Attention Models},
    author  = {Srinadh Bhojanapalli and Chulhee Yun and Ankit Singh Rawat and Sashank J. Reddi and Sanjiv Kumar},
    year    = {2020},
    eprint  = {2002.07028}
}

Comments
  • Training failing on versions 0.0.14 and 0.0.15

    Hi, I have been testing model training on new versions of the repo, and I have some trouble with 0.0.14 and 0.0.15. On 0.0.14, the model always returns NaN on the forward pass; version 0.0.15 leads to a CUDA error:

    RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
    

    Full error listing:

    <ipython-input-7-1329da5363de> in forward(self, inputs, labels)
          7   def forward(self, inputs, labels=None):
          8     loss_mx = labels != -100
    ----> 9     output = self.model(inputs)
         10     output = output[loss_mx].view(-1, tokenizer.vocab_size)
         11     labels = labels[loss_mx].view(-1)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, input_mask)
        376         x = self.to_token_emb(x)
        377         x = self.pos_emb(torch.arange(t, device=device)) + x
    --> 378         x = self.sinkhorn_transformer(x)
        379         return self.to_logits(x)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, input_mask)
        359 
        360     def forward(self, x, input_mask = None):
    --> 361         return self.layers(x)
        362 
        363 class SinkhornTransformerLM(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, **kwargs)
        330     def forward(self, x, **kwargs):
        331         x = torch.cat([x, x], dim=-1)
    --> 332         x = self.layers(x, **kwargs)
        333         return torch.stack(x.chunk(2, dim=-1)).sum(dim=0)
        334 
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, x, arg_route, **kwargs)
        128         block_kwargs = {'f_args': f_args, 'g_args': g_args}
        129 
    --> 130         return _ReversibleFunction.apply(x, blocks, block_kwargs)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(ctx, x, blocks, kwargs)
         98         ctx.kwargs = kwargs
         99         for block in blocks:
    --> 100             x = block(x, **kwargs)
        101         ctx.y = x.detach()
        102         ctx.blocks = blocks
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, x, f_args, g_args)
         51         with torch.no_grad():
         52             y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
    ---> 53             y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
         54 
         55         return torch.cat([y1, y2], dim=2)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, record_rng, set_rng, *args, **kwargs)
         25 
         26         if not set_rng:
    ---> 27             return self.net(*args, **kwargs)
         28 
         29         rng_devices = []
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x)
         91     def forward(self, x):
         92         chunks = x.chunk(self.chunks, dim = self.dim)
    ---> 93         return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
         94 
         95 class FeedForward(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in <listcomp>(.0)
         91     def forward(self, x):
         92         chunks = x.chunk(self.chunks, dim = self.dim)
    ---> 93         return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
         94 
         95 class FeedForward(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, **kwargs)
        112     def forward(self, x, **kwargs):
        113         x = self.norm(x)
    --> 114         return self.fn(x, **kwargs)
        115 
        116 class SortNet(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x)
        103 
        104     def forward(self, x):
    --> 105         return self.net(x)
        106 
        107 class PreNorm(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
         98     def forward(self, input):
         99         for module in self:
    --> 100             input = module(input)
        101         return input
        102 
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
         85 
         86     def forward(self, input):
    ---> 87         return F.linear(input, self.weight, self.bias)
         88 
         89     def extra_repr(self):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
       1591         ret = torch.addmm(bias, input, weight.t())
       1592     else:
    -> 1593         output = input.matmul(weight.t())
       1594         if bias is not None:
       1595             output += bias
    
    RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
    

    Also, version 0.0.11 (and all other versions since 0.0.8) works stably.

    opened by blizda 30
  • generation problem in a toy task

    Here is the full script for my toy task (x -> xx like "abc" to "abcabc")

    from sinkhorn_transformer import SinkhornTransformerLM
    from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper
    
    import random
    import tqdm
    import gzip
    import numpy as np
    import torch
    import torch.optim as optim
    from torch import nn
    from torch.nn import functional as F
    from torch.utils.data import DataLoader, Dataset
    
    # constants
    
    NUM_BATCHES = int(1e5)
    BATCH_SIZE = 4
    GRADIENT_ACCUMULATE_EVERY = 4
    LEARNING_RATE = 1e-4
    VALIDATE_EVERY  = 100
    GENERATE_EVERY  = 100
    ENC_SEQ_LEN=16
    DEC_SEQ_LEN=40
    NUM_TOKENS = 256 + 2
    BUCKET_SIZE = 8
    
    # helpers
    
    def top_k(logits, thres = 0.9):
        k = int((1 - thres) * logits.shape[-1])
        val, ind = torch.topk(logits, k)
        probs = torch.full_like(logits, float('-inf'))
        probs.scatter_(1, ind, val)
        return probs
    
    
    def cycle():
        while True:
            source = torch.randint(2, 258, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
    
            target = torch.cat((source, source), 1)
            prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
            target = torch.cat((prefix, target), axis=1)
    
            x_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool().cuda()
            y_mask = torch.ones(BATCH_SIZE, target.shape[1]).bool().cuda()
    
    
            yield (source, target, x_mask, y_mask)
    
    # instantiate model
    
    class MySinkhornTransformer(nn.Module):
        def __init__(self, num_tokens, dim, depth, heads, bucket_size, enc_max_seq_len, dec_max_seq_len):
            super().__init__()
            
            self.pad_token = 0
            self.sos_token = 1
    
            self.enc = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads, bucket_size=bucket_size, max_seq_len=enc_max_seq_len,
                                             reversible=True, return_embeddings=True)
            self.dec = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads, causal=True, bucket_size=bucket_size, max_seq_len=dec_max_seq_len, 
                                             receives_context=True, context_bucket_size=bucket_size, reversible=True)
            self.dec = AutoregressiveWrapper(self.dec, pad_value=num_tokens-2)
        
        @torch.no_grad()
        def generate(self, x, x_mask):
            context = self.enc(x, input_mask=x_mask)
            start_tokens = (torch.ones((x.shape[0],1)) * self.sos_token).long().cuda()
    
            return self.dec.generate(start_tokens, 32, context=context, context_mask=x_mask)
    
        def forward(self, x, y, x_mask, y_mask, return_loss):
            context = self.enc(x, input_mask=x_mask)
            return self.dec(y, context=context, input_mask=y_mask, context_mask=x_mask, return_loss=True)
    
    
    model = MySinkhornTransformer(num_tokens=NUM_TOKENS, dim=512, depth=1, heads=1, bucket_size=BUCKET_SIZE, enc_max_seq_len=ENC_SEQ_LEN, dec_max_seq_len=DEC_SEQ_LEN)
    model.cuda()
    # optimizer
    
    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    # training
    
    for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
        model.train()
    
        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            source, target, x_mask, y_mask = next(cycle())
            loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
            loss.backward()
    
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optim.step()
        optim.zero_grad()
    
        if i % VALIDATE_EVERY == 0:
            model.eval()
            with torch.no_grad():
                source, target, x_mask, y_mask = next(cycle())
                loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
                print(f'validation loss: {loss.item()}')
    
        if i % GENERATE_EVERY == 0:
            model.eval()
            
            source, target, x_mask, y_mask = next(cycle())
            
            sample = model.generate(x=source, x_mask=x_mask)
            print("input:  ", source[0])
            print("model output:  ", sample[0])
    

    After a few steps the loss becomes practically zero. I checked the logits during training and they seem to be OK, but during the generation phase the model outputs this pattern: "x,x,x,x,x,y,y,y,y,y", e.g. "aaaabbbb" instead of "abcdabcd". I was wondering what the underlying issue might be. Do you have any idea?

    opened by py4 18
  • Sortcut variant and bucket size

    From the paper, it looks to me that for the sortcut variant in the non-causal case, the queries are allowed to attend to all the key buckets remaining after truncation. If this is correct, won't the output be the same no matter what the bucket size for the queries is?

    Based on my understanding of the paper, the authors only report results selecting the top 2 key/value blocks with bucket sizes 8, 16 and 32. They do not mention having different bucket sizes for queries and keys/values, which I found in this repo.

    opened by jsmith1915 16
  • A noobish question about training...

    @lucidrains

    First of all, everything seems to be working now in Google Colab so thank you very much for fixing it.

    I have a quick question about training if you do not mind...

    I get very high loss results. Here is the example:

    training: 1%| | 509/100000 [1:55:22<387:23:57, 14.02s/it]training loss: 2.495532512664795

    Is this normal and I simply need to train more? Or does it mean that there is a problem somewhere?

    2.49 after 2 hours seems way too high IMHO. I am pretty sure I am not doing it right, so your advice would be really appreciated.

    Thank you.

    P.S. I am running your wiki8 example in Google Colab Pro.

    opened by asigalov61 9
  • several questions about implementation

    I am currently trying to implement Sparse Sinkhorn attention (non-causal, self-attention, pretrained on an MLM task) in TensorFlow, and I would appreciate it if you could answer several questions about your code.

    1. I did not quite understand from the paper what happens to the queries and keys in SortCut when we cut off most of the blocks. In your code it seems like you broadcast keys/queries like this: [number_blocks, block_len, size_per_head] -> [:n_sortcut, block_len, size_per_head] -> [1, n_sortcut * block_len, size_per_head] -> [number_blocks, n_sortcut * block_len, size_per_head]. I understand that PyTorch handles expand_dim automatically and memory is preserved, but I do not understand how the memory consumption stays linear: the resulting attention scores (dots) are still [number_blocks, block_len, num_sortcut * block_len], which is good but quadratic. Am I missing something?

    2. In the case of SortCut, what is the meaning of softmax(A)*Q? For example, for full attention, the outputs are vectors that are weighted averages of other vectors; for local block attention, the outputs are vectors that are weighted averages of other vectors in the block. How would you interpret the output of a SortCut attention layer?

    3. In your code you concatenate the regular keys and values with the permuted/sorted keys and values, but in the paper (section 3.2) it seems like a simple addition. Why is it different from the paper? Actually, concatenation makes more sense to me than summation (because with summation I am confused about which attention mask to use: the one from the regular block or the permuted one).

    4. It seems like you are using concatenated queries and keys as the input to the SortNet. Maybe I completely missed the point, but I did not find anything about this in the paper, where the input sequence is used for the SortNet instead of its projections. In fact, for some reason, I cannot make it converge if I use queries and keys (maybe for completely unrelated reasons). Have you tried comparing the input sequence vs concat(q, k)?

    opened by w4-magnes 9
  • TypeError exception in AxialPositionalEncoding when using DataParallel

    Hello,

    I want to run SinkhornTransformerLM using multiple GPUs, so I'm wrapping the model into torch.nn.DataParallel. However, when I do this, I get an exception:

    Traceback (most recent call last):
      File "script.py", line 27, in <module>
        model(x)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 155, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 165, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
        output.reraise()
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/_utils.py", line 395, in reraise
        raise self.exc_type(msg)
    TypeError: Caught TypeError in replica 0 on device 0.
    Original Traceback (most recent call last):
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
        output = module(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 792, in forward
        x = self.axial_pos_emb(x) + x
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 243, in forward
        return pos_emb[:, :t]
    TypeError: 'int' object is not subscriptable
    

    Looking at the code, it would seem that self.weights does not get populated. To reproduce this error, I took the first example in README.md and changed

    model(x) # (1, 2048, 20000)
    

    to

    model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))).to('cuda')
    model(x)
    
    opened by kl0211 8
  • Do I need to pad?

    Excuse me for the noob question. I have sequences of different lengths and I use:

    model = AutoregressiveWrapper(model, ignore_index=PAD_ID, pad_value=PAD_ID)
    

    Do I need to pad as in this example or not?

    def __getitem__(self, index):
        seq_tokens = self.examples[index]
        input_ids = torch.tensor(seq_tokens, dtype=torch.long)
        input_ids = F.pad(input_ids, (seq_len - len(input_ids), 0), value=self.pad_value)
        return input_ids
    
    opened by timsoraro 8
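This page does not say whether `AutoregressiveWrapper` pads internally. Independently of that, stacking variable-length sequences into one batch tensor requires a common length, and `ignore_index` keeps pad positions out of the loss.

Below is a minimal sketch, assuming left padding and a hypothetical `PAD_ID = 0`. The helper names `left_pad_batch` and `masked_token_loss` are illustrative and not part of the library.

```python
import torch
import torch.nn.functional as F

PAD_ID = 0  # hypothetical pad token id; use your tokenizer's actual pad id

def left_pad_batch(sequences, seq_len, pad_value=PAD_ID):
    """Keep each token list's last seq_len tokens, left-pad, and stack."""
    rows = []
    for tokens in sequences:
        ids = torch.tensor(tokens[-seq_len:], dtype=torch.long)
        rows.append(F.pad(ids, (seq_len - len(ids), 0), value=pad_value))
    return torch.stack(rows)  # shape: (batch, seq_len)

def masked_token_loss(logits, targets, pad_id=PAD_ID):
    """Mean cross-entropy over non-pad targets only."""
    return F.cross_entropy(
        logits.reshape(-1, logits.shape[-1]),
        targets.reshape(-1),
        ignore_index=pad_id,  # pad positions add nothing to the sum or the average
    )

batch = left_pad_batch([[5, 6, 7], [8, 9]], seq_len=4)
print(batch.tolist())  # [[0, 5, 6, 7], [0, 0, 8, 9]]
```

Ignoring pads in the loss does not stop the model from seeing pad tokens as context. That would additionally require masking them in attention.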
  • Performance of the different attention variants in the repo

    Really great work!!

    I have a few questions.

    - Does the attention sort net work better than, or on par with, the simple sort net?
    - Does the attention sort net also work well for the sortcut variant?
    - How does the sortcut variant perform compared to vanilla sparse Sinkhorn attention?

    It would be very helpful if you could share some plots/numbers from your experiments comparing the performance of the different variants such as attention sort net, routing based attention etc.

    opened by jsmith1915 7
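As background for comparing the sort nets: the overview at the top of this page describes using Sinkhorn normalization to turn bucket-matching scores into a (soft) permutation matrix.

Below is a minimal log-space sketch of that normalization. How each sort net in the repo produces its scores is not shown here. The function name and the `temperature` parameter are illustrative assumptions, not the repo's API.

```python
import torch

def sinkhorn_normalize(scores, n_iters=20, temperature=1.0):
    """Sketch of log-space Sinkhorn normalization (not the repo's exact code).

    Alternately normalizes the rows and columns of exp(scores / temperature),
    pushing the matrix toward a doubly stochastic one; as the temperature
    drops, the result approaches a hard permutation matrix.
    """
    log_p = scores / temperature
    for _ in range(n_iters):
        log_p = log_p - torch.logsumexp(log_p, dim=-1, keepdim=True)  # rows sum to 1
        log_p = log_p - torch.logsumexp(log_p, dim=-2, keepdim=True)  # columns sum to 1
    return log_p.exp()

torch.manual_seed(0)
scores = torch.randn(4, 4)  # e.g. query-bucket x key-bucket relevance scores
p = sinkhorn_normalize(scores, n_iters=100)
print(p.sum(dim=-1), p.sum(dim=-2))  # both close to all ones
```

Working in log space with `logsumexp` avoids overflow when the temperature is small and the scores become large.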
  • Some questions about dropout

    Hi again @lucidrains, I have some quick questions about dropout in the Sinkhorn Transformer. I've been using my Linformer implementation (which, as you know, is based on this repo), but it was overfitting my dataset, so I wanted to ask whether some of the dropout design choices here were intentional:

    1. In the original Transformer, dropout was performed after each sublayer, before the residual connection. I noticed that you only have this after the SinkhornSelfAttention class, but not after the FeedForward class. Is this intentional?
    2. Speaking of the FeedForward class, you insert dropout after the first linear layer. I couldn't find this anywhere in the literature; were you able to find a reference explaining why this is effective? I put it into my implementation and it seems to help, but I just don't know where this idea came from.
    3. On a similar note, do you know why the dots tensor in the self-attention classes is dropped out? Again, I put it in my Linformer and it seems to work, but I can't find a reference to this in the literature.
    4. Finally, the original Transformer also applied dropout to the input embeddings (the sum of the token and positional embeddings). In the SinkhornTransformerLM class, that would go here:
        def forward(self, x, **kwargs):
            _, t, device = *x.shape, x.device
            assert t <= self.max_seq_len, f'sequence length {t} is greater than maximum sequence length {self.max_seq_len}'

            x = self.to_token_emb(x)
            x = self.axial_pos_emb(x) + x
            # dropout would go here
            x = self.sinkhorn_transformer(x, **kwargs)
            return self.to_logits(x)
    

    Should the embeddings be dropped out here as well?

    I have now updated my repo so that all four of these dropout options exist. I'll let you know if this helps with the overfitting.

    Thank you for your time!

    opened by tatp22 6
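The four dropout sites asked about above can be seen side by side in one minimal block. This is a hypothetical sketch, not the repo's implementation. All class and function names are illustrative, and the attention has no projections or masking. It also assumes that dropout on the attention `dots` is applied after the softmax, as is common.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForwardSketch(nn.Module):
    def __init__(self, dim, mult=4, inner_dropout=0.1, out_dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Dropout(inner_dropout),  # site 2: after the first linear layer (+ activation)
            nn.Linear(dim * mult, dim),
            nn.Dropout(out_dropout),    # site 1: on the sublayer output, before the residual add
        )

    def forward(self, x):
        return self.net(x)

def attend_sketch(q, k, v, dropout_p=0.1, training=True):
    dots = q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5
    attn = dots.softmax(dim=-1)
    attn = F.dropout(attn, p=dropout_p, training=training)  # site 3: on the attention weights
    return attn @ v

class TinyBlockLMSketch(nn.Module):
    def __init__(self, num_tokens, dim, max_seq_len, emb_dropout=0.1, attn_out_dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)
        self.emb_dropout = nn.Dropout(emb_dropout)
        self.attn_out_dropout = nn.Dropout(attn_out_dropout)
        self.ff = FeedForwardSketch(dim)
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, x):
        positions = torch.arange(x.shape[1], device=x.device)
        h = self.token_emb(x) + self.pos_emb(positions)
        h = self.emb_dropout(h)  # site 4: on the summed token + positional embeddings
        attn_out = attend_sketch(h, h, h, training=self.training)
        h = h + self.attn_out_dropout(attn_out)  # site 1 again, for the attention sublayer
        h = h + self.ff(h)
        return self.to_logits(h)

model = TinyBlockLMSketch(num_tokens=100, dim=32, max_seq_len=16)
tokens = torch.randint(0, 100, (2, 10))
model.eval()  # every dropout site becomes a no-op
print(model(tokens).shape)  # torch.Size([2, 10, 100])
```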
  • Training crashes due to inplace operation

    Hi there, I am training this model on my custom dataset. Training starts, but after many iterations it crashes with this error:

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2, 128, 64]], which is output 0 of ViewBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
    

    After enabling anomaly detection, here is the error:

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2, 128, 64]], which is output 0 of ViewBackward, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
    

    The traceback is not useful, as it only shows the error at loss.backward() and inside the torch.autograd module. Can you help me figure out where the issue might be?

    Thanks

    opened by NaxAlpha 6
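The message means that a tensor autograd saved for the backward pass (here, the output of a view) was modified in place after it was saved. The failure only surfaces at `backward()`, which is why the traceback stops there.

The tiny reproduction below is unrelated to this repo's code. It shows the pattern and the usual fix, switching to the out-of-place form. Which in-place op triggers the error in this model is not identified here.

```python
import torch

def buggy_step():
    x = torch.randn(2, 3, requires_grad=True)
    v = (x * 2).view(6)  # a view of a non-leaf tensor; autograd records a view node
    y = v * v            # mul saves v for its backward pass
    v.add_(1)            # in-place edit bumps v's version counter
    y.sum().backward()   # RuntimeError: saved v is at version 1, expected version 0

def fixed_step():
    x = torch.randn(2, 3, requires_grad=True)
    v = (x * 2).view(6)
    y = v * v
    v = v + 1            # out-of-place: creates a new tensor, the saved v is untouched
    y.sum().backward()
    return x, x.grad     # d/dx sum((2x)^2) = 8x
```

Typical fixes in model code are replacing `a += b` with `a = a + b`, or calling `.clone()` on a tensor before modifying it in place.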
  • Positional Embedding

    Hi! I found that SinkhornTransformerLM uses two positional encodings simultaneously:

    https://github.com/lucidrains/sinkhorn-transformer/blob/2b65e24085562a4e308251398007e2ca9b86d7cc/sinkhorn_transformer/sinkhorn_transformer.py#L778-L779

    https://github.com/lucidrains/sinkhorn-transformer/blob/2b65e24085562a4e308251398007e2ca9b86d7cc/sinkhorn_transformer/sinkhorn_transformer.py#L792-L793

    I think pos_emb can be removed: it introduces memory overhead and defeats the purpose of Axial Positional Encoding, which is designed to reduce the number of positional encoding parameters.

    Is there a particular reason behind that?

    opened by ilya16 4
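To put rough numbers on the overhead, the sketch below compares parameter counts. It uses the README example's `max_seq_len = 8192` and `dim = 1024` as illustrative values and an assumed axial shape of (64, 128). It also assumes each axis gets a full-width table that is summed. Some axial variants split `dim` across axes instead, which shrinks the axial count further.

```python
def learned_pos_emb_params(max_seq_len, dim):
    """A full learned table: one dim-sized vector per position."""
    return max_seq_len * dim

def axial_pos_emb_params(axial_shape, dim):
    """Assumes one full-width table per axis, summed (not concatenated)."""
    rows, cols = axial_shape
    return (rows + cols) * dim

full = learned_pos_emb_params(8192, 1024)
axial = axial_pos_emb_params((64, 128), 1024)  # 64 * 128 == 8192 positions
print(full, axial)  # 8388608 196608
```

Under these assumptions, the full table carries about 42 times as many parameters as the axial factorization covering the same 8192 positions.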
  • A wrapper of SinkhornTransformerEncDec

    Could you please code up a wrapper that removes much of the manual work of writing a generic SinkhornTransformer encoder/decoder architecture?

    Thanks a lot! halexan

    opened by halexan 0
Owner
Phil Wang
Working with Attention. It's all we need