Hierarchical Transformer Memory (HTM) - Pytorch
Implementation of Hierarchical Transformer Memory (HTM) for PyTorch. The DeepMind paper that introduces it (Lampinen et al., 2021, cited below) proposes a simple method that lets transformers attend efficiently to memories of the past. See also the original JAX repository.
Install
$ pip install htm-pytorch
Usage
import torch
from htm_pytorch import HTMAttention
attn = HTMAttention(
    dim = 512,
    heads = 8,               # number of heads for within-memory attention
    dim_head = 64,           # dimension per head for within-memory attention
    topk_mems = 8,           # how many top-scoring memory chunks to select for attention
    mem_chunk_size = 32,     # number of tokens in each memory chunk
    add_pos_enc = True       # whether to add positional encoding to the memories
)
queries = torch.randn(1, 128, 512)     # queries
memories = torch.randn(1, 20000, 512)  # memories, of any size
mask = torch.ones(1, 20000).bool()     # memory mask
attended = attn(queries, memories, mask = mask) # (1, 128, 512)
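To make the roles of `mem_chunk_size` and `topk_mems` more concrete, here is a rough, dependency-free sketch of the two-level idea described in the paper: score chunk summaries, keep the top-k chunks, then attend within those chunks only. It is not the code inside `HTMAttention`. It assumes non-overlapping chunks, mean-pooled chunk summaries, and a single head, and it leaves out learned projections, positional encodings, and masking. All function names here are hypothetical.

```python
import math

def chunk(seq, size):
    # Split a list into consecutive, non-overlapping chunks of at most `size` items.
    return [seq[i:i + size] for i in range(0, len(seq), size)]

def mean_vector(vectors):
    # Element-wise mean, used here as the chunk summary (an assumption of this sketch).
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, vectors):
    # Plain scaled dot-product attention of one query over a list of vectors,
    # with the same vectors acting as keys and values (no learned projections).
    scale = len(query) ** -0.5
    weights = softmax([dot(query, v) * scale for v in vectors])
    return [sum(w * v[d] for w, v in zip(weights, vectors)) for d in range(len(query))]

def hierarchical_attend(query, memories, mem_chunk_size, topk_mems):
    chunks = chunk(memories, mem_chunk_size)
    summaries = [mean_vector(c) for c in chunks]

    # Coarse level: score every chunk summary against the query.
    scale = len(query) ** -0.5
    scores = [dot(query, s) * scale for s in summaries]

    # Keep only the indices of the k highest-scoring chunks.
    k = min(topk_mems, len(chunks))
    top = sorted(range(len(chunks)), key=lambda i: scores[i], reverse=True)[:k]

    # Fine level: attend within each selected chunk, then combine the
    # per-chunk results weighted by the (renormalized) chunk scores.
    chunk_weights = softmax([scores[i] for i in top])
    per_chunk = [attend(query, chunks[i]) for i in top]
    out = [sum(w * o[d] for w, o in zip(chunk_weights, per_chunk)) for d in range(len(query))]
    return out, top

# Toy example: two chunks of four 2-d memories each, one query.
memories = [[1.0, 0.0]] * 4 + [[0.0, 1.0]] * 4
out, selected = hierarchical_attend([0.0, 5.0], memories, mem_chunk_size=4, topk_mems=1)
print(selected, out)
```

Under the same non-overlapping chunking assumption, the 20000 memory tokens in the usage example above form 625 chunks of 32 tokens, and with `topk_mems = 8` each query attends within only 8 of them, or 256 tokens, after scoring the 625 summaries.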
If you want the entire HTM block, which applies a layernorm to the input and adds a skip connection around the attention, import HTMBlock instead.
import torch
from htm_pytorch import HTMBlock
block = HTMBlock(
    dim = 512,
    topk_mems = 8,
    mem_chunk_size = 32
)
queries = torch.randn(1, 128, 512)
memories = torch.randn(1, 20000, 512)
mask = torch.ones(1, 20000).bool()
out = block(queries, memories, mask = mask) # (1, 128, 512)
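The layernorm-plus-skip-connection wrapping can be pictured with the small sketch below. It is only an illustration of the general pre-norm residual pattern, written without PyTorch. It assumes the normalization is applied to the queries only and that the unnormalized queries are added back to the attention output; the layernorm here has no learned scale or bias. The names `layer_norm`, `prenorm_residual_block`, and `attention_fn` are hypothetical and not part of `htm_pytorch`.

```python
import math

def layer_norm(x, eps=1e-5):
    # Normalize one vector to zero mean and unit variance (no learned gain or bias).
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def prenorm_residual_block(queries, memories, attention_fn):
    # For each query: normalize, attend to the memories, then add the original query back.
    outputs = []
    for q in queries:
        attended = attention_fn(layer_norm(q), memories)
        outputs.append([a + b for a, b in zip(q, attended)])
    return outputs

# Stand-in attention that returns zeros, so the block reduces to the identity
# and the skip connection is easy to see.
zero_attention = lambda q, mems: [0.0] * len(q)
queries = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
result = prenorm_residual_block(queries, memories=[], attention_fn=zero_attention)
print(result)
```

With a real attention function in place of `zero_attention`, the output keeps the shape of the queries, which matches the `(1, 128, 512)` shape noted in the example above.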
Citations
@misc{lampinen2021mental,
    title   = {Towards mental time travel: a hierarchical memory for reinforcement learning agents}, 
    author  = {Andrew Kyle Lampinen and Stephanie C. Y. Chan and Andrea Banino and Felix Hill},
    year    = {2021},
    eprint  = {2105.14039},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}