GPT - gMLP
This repository attempts to tackle long-context autoregressive language modeling (GPT) using variations of gMLP. Specifically, it will contain a variant that applies gMLP within local sliding windows. The goal is to train on context lengths of 4096 and above, efficiently and with good results, on a single GPU.
The name GPT is technically a misnomer, since the architecture contains no attention (transformer) layers at all.
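To make "gMLP within local sliding windows" concrete, here is a minimal pure-Python sketch. It assumes the spatial gating unit described in the gMLP paper: split each token's channels into halves u and v, mix v along the sequence axis with an n x n projection W, and output u * (W v + b). For autoregressive modeling, W is restricted to a causal sliding window. The function names are hypothetical, the LayerNorm the paper applies to v is omitted, and this repository may well use chunked windows rather than the per-token sliding mask shown here.

```python
import random


def causal_window_mask(n: int, window: int) -> list[list[bool]]:
    """mask[t][s] is True when position t may read position s: s is not in the
    future (s <= t) and lies within the last `window` positions (t - window < s)."""
    return [[t - window < s <= t for s in range(n)] for t in range(n)]


def spatial_gating_unit(z, weight, bias, window):
    """Hypothetical sketch of a causal, windowed gMLP spatial gating unit.

    z:      n x d list of token vectors (d even). The first half of each vector
            is the gate input u; the second half is v, mixed across positions.
    weight: n x n spatial projection. Entries outside the window mask are skipped.
    bias:   length-n list, one bias per output position.
    Returns an n x (d // 2) list computing u * (masked_weight @ v + bias).
    """
    n, d = len(z), len(z[0])
    half = d // 2
    mask = causal_window_mask(n, window)
    out = []
    for t in range(n):
        row = []
        for c in range(half):
            # Mix channel c of v over the allowed (past, in-window) positions only.
            mixed = sum(weight[t][s] * z[s][half + c] for s in range(n) if mask[t][s])
            # Gate: multiply element-wise by u at the same position and channel.
            row.append(z[t][c] * (mixed + bias[t]))
        out.append(row)
    return out


random.seed(0)
n, d, window = 6, 4, 2
z = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
# Near-zero weight and unit bias, so the gate starts out close to passing u through.
weight = [[random.uniform(-1e-3, 1e-3) for _ in range(n)] for _ in range(n)]
bias = [1.0] * n
out = spatial_gating_unit(z, weight, bias, window)
```

Because masked-out entries of `weight` are skipped rather than relied on to be zero, changing a future token never changes earlier outputs. The near-zero `weight` and all-ones `bias` mirror the initialization the gMLP paper recommends, under which the unit initially behaves roughly like the identity on u.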
Install
$ pip install g-mlp-gpt
Usage
import torch
from g_mlp_gpt import gMLPGPT
model = gMLPGPT(
    num_tokens = 20000,
    dim = 512,
    depth = 4,
    seq_len = 1024,
    window = (128, 256, 512, 1024) # window sizes for each depth
)
x = torch.randint(0, 20000, (1, 1000))
logits = model(x) # (1, 1000, 20000)
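Growing the window with depth lets later layers reach further back. The following back-of-the-envelope sketch estimates how far information can travel through the stack. It assumes each layer is a per-token causal sliding window that reads positions t - w + 1 through t, which this repository does not confirm. It also ignores any axial setting. The helper name is hypothetical.

```python
def stacked_receptive_field(windows):
    """Number of positions that the final output at position t can depend on,
    assuming per-token causal sliding windows. Each layer adds w - 1 positions
    of look-back on top of the current position."""
    return 1 + sum(w - 1 for w in windows)


seq_len = 1024
windows = (128, 256, 512, 1024)
field = stacked_receptive_field(windows)
print(field)             # 1917
print(field >= seq_len)  # True
```

Under this assumption, the windows in the usage example give a receptive field of 1917 positions, which is enough for the last token to depend on the whole 1024-token context even though early layers only look 128 tokens back.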
16k context length
import torch
from g_mlp_gpt import gMLPGPT
model = gMLPGPT(
    num_tokens = 20000,
    dim = 512,
    seq_len = 16384,
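    depth = 12,
    reversible = True, # reversible layers, to reduce activation memory
    window = (128, 128, 256, 512, 1024, 1024, 2048, 2048, 4096, 4096, 8192, 8192), # window size for each layer
    axial = (1, 1, 1, 1, 1, 1, 2, 2, 4, 4, 8, 8) # axial setting for each layer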
).cuda()
x = torch.randint(0, 20000, (1, 16384)).cuda()
logits = model(x) # (1, 16384, 20000)
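The per-layer tuples must line up with `depth`: the 16k example lists 12 window sizes and 12 axial values, so it needs 12 layers. Below is a small hypothetical sanity check, not part of `g_mlp_gpt`. It assumes one window and one axial value per layer, and that no window should exceed `seq_len`.

```python
def check_layer_config(depth, seq_len, window, axial):
    """Hypothetical sanity check (not part of g_mlp_gpt): expects exactly one
    window size and one axial value per layer, and no window longer than seq_len."""
    if len(window) != depth or len(axial) != depth:
        raise ValueError(
            f"depth={depth}, but got {len(window)} window sizes and {len(axial)} axial values"
        )
    too_long = [w for w in window if w > seq_len]
    if too_long:
        raise ValueError(f"window sizes longer than seq_len={seq_len}: {too_long}")


windows = (128, 128, 256, 512, 1024, 1024, 2048, 2048, 4096, 4096, 8192, 8192)
axial = (1, 1, 1, 1, 1, 1, 2, 2, 4, 4, 8, 8)
check_layer_config(depth=12, seq_len=16384, window=windows, axial=axial)  # passes silently

try:
    check_layer_config(depth=8, seq_len=16384, window=windows, axial=axial)
except ValueError as err:
    print(err)  # depth=8, but got 12 window sizes and 12 axial values
```

Running a check like this before building the model catches mismatched tuples early, rather than relying on whatever error the model constructor raises.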
Citations
@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}