Ponder(ing) Transformer

Implementation of a Transformer that learns to adapt the number of computational steps it takes depending on the difficulty of the input sequence, using the scheme from the PonderNet paper. It will also try to abstract out a pondering module that can be used with any block that returns an output along with a halting probability.
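
As a rough illustration of that abstraction, below is a minimal sketch of such a pondering loop (hypothetical names, not this repository's API), assuming a block that maps a hidden state to a (new_hidden, halt_logit) pair:

import torch
from torch import nn

# hypothetical sketch of the pondering abstraction described above,
# not this repository's API; `block` is assumed to map a hidden state
# to (new_hidden, halt_logit)
class PonderWrapper(nn.Module):
    def __init__(self, block, max_steps = 12):
        super().__init__()
        self.block = block
        self.max_steps = max_steps

    def forward(self, x):
        un_halted = x.new_ones(x.shape[0]) # probability of not having halted yet
        halting_probs = []

        for step in range(self.max_steps):
            x, halt_logit = self.block(x)
            lambda_n = halt_logit.sigmoid() # conditional halt probability at this step

            if step == self.max_steps - 1:
                p_n = un_halted             # force halting at the final step
            else:
                p_n = un_halted * lambda_n
                un_halted = un_halted * (1 - lambda_n)

            halting_probs.append(p_n)

        # returns the final hidden state and the halting distribution (batch, steps)
        return x, torch.stack(halting_probs, dim = 1)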

This repository would not have been possible without repeated viewings of Yannic Kilcher's educational video.

Install

$ pip install ponder-transformer

Usage

import torch
from ponder_transformer import PonderTransformer

model = PonderTransformer(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 512
)

mask = torch.ones(1, 512).bool()         # attention mask over the full sequence

x = torch.randint(0, 20000, (1, 512))    # input token ids
y = torch.randint(0, 20000, (1, 512))    # target token ids

loss = model(x, labels = y, mask = mask) # passing labels makes the model return the scalar loss
loss.backward()
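
For reference, the loss returned above follows the PonderNet objective from the paper: the task loss at each pondering step, weighted by the halting distribution, plus a KL regularizer toward a geometric prior. Below is a minimal sketch of that objective with hypothetical tensor names (the paper's formulation, not necessarily this library's exact internals):

import torch.nn.functional as F

# hedged sketch of the PonderNet objective, L = L_rec + beta * L_reg
# step_losses   : (batch, n) task loss evaluated at each pondering step
# halting_probs : (batch, n) probability p_n of halting at each step
# log_prior     : (n,) log of the truncated geometric prior p_G(lambda_p)
def ponder_objective(step_losses, halting_probs, log_prior, beta = 0.01):
    l_rec = (halting_probs * step_losses).sum(dim = -1).mean() # expected task loss
    l_reg = F.kl_div(                                          # KL(p_n || p_G)
        log_prior.unsqueeze(0).expand_as(halting_probs),
        halting_probs,
        reduction = 'batchmean'
    )
    return l_rec + beta * l_reg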

Now you can set the model to .eval() mode, and it will terminate early once all samples in the batch have emitted a halting signal.

import torch
from ponder_transformer import PonderTransformer

model = PonderTransformer(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 512,
    causal = True
)

x = torch.randint(0, 20000, (2, 512))
mask = torch.ones(2, 512).bool()

model.eval() # setting to eval makes it return the logits as well as the halting indices

logits, layer_indices = model(x, mask = mask) # (2, 512, 20000), (2,)

# layer_indices contains, for each batch element, the layer at which it exited
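
Since layer_indices reports the exit layer per batch element, you can, for example, measure how much pondering was actually used (a hypothetical follow-up, continuing from the snippet above):

print(layer_indices)                # e.g. tensor([3, 7]), the exit layer per batch element
print(layer_indices.float().mean()) # average pondering depth across the batch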

Citations

@misc{banino2021pondernet,
    title   = {PonderNet: Learning to Ponder}, 
    author  = {Andrea Banino and Jan Balaguer and Charles Blundell},
    year    = {2021},
    eprint  = {2107.05407},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

Comments
  • Evaluating PonderNet on more pondering steps than trained on.

    As the paper says,

    In evaluation, and under known temporal or computational limitations, N can be set naively as a constant (or not set any limit, i.e. N → ∞). For training, we found that a more effective (and interpretable) way of parameterizing N is by defining a minimum cumulative probability of halting. N is then the smallest value of n such that Σ_{j=1}^{n} p_j > 1 − ε, with the hyper-parameter ε positive near 0 (in our experiments 0.05).

    From that I infer that pondering can be run for more steps than the model was trained on. How can that be done with this implementation?

    edit: I was going through the paper again, and I think what the paper means is that the maximum number of pondering steps N should be re-evaluated at every training step: the model should be run until the condition is met or a pre-defined maximum number of steps is reached, and the step at which the cumulative-sum condition is met is set as N (with the cumulative probabilities normalised by one of the methods). That value of N is then used to calculate the geometric prior for the KL divergence (without normalising the prior geometric term).

    i.e. if the number of pondering steps is initially set to M, the model recurs for k steps, i.e. until the condition is met or for at most M steps. N is then calculated by first computing the probabilities p_0 to p_k, normalising them through one of the methods, taking their cumulative sum, and finding where that sum exceeds the threshold; that index becomes N. After that, the geometric prior values are calculated with the defined hyper-parameter for a length of N, and used in the KL-divergence term against the halting probabilities truncated to N steps (a rough sketch follows below).

    λ_p is a hyper-parameter that defines a geometric prior distribution p_G(λ_p) on the halting policy (truncated at N)
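
    A minimal sketch of that procedure, under the assumptions stated above (the names and shapes here are hypothetical, not this repository's API):

    import torch
    import torch.nn.functional as F

    eps = 0.05     # minimum cumulative halting probability threshold (the paper uses 0.05)
    lambda_p = 0.2 # hyper-parameter of the geometric prior

    # stand-in for the (already normalised) per-step halting probabilities p_0 .. p_{k-1}
    halting_probs = torch.rand(2, 16).softmax(dim = -1)

    # N = smallest n such that the cumulative halting probability exceeds 1 - eps
    cumsum = halting_probs.cumsum(dim = -1)
    N = int((cumsum > 1 - eps).float().argmax(dim = -1).max()) + 1

    # truncated geometric prior p_G(lambda_p) over the first N steps
    steps = torch.arange(N)
    prior = lambda_p * (1 - lambda_p) ** steps # left unnormalised, per the note above

    # KL term against the halting probabilities truncated to N steps
    p_trunc = halting_probs[:, :N]
    kl = F.kl_div(prior.log().unsqueeze(0).expand_as(p_trunc), p_trunc, reduction = 'batchmean')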

    opened by Vbansal21 0
  • Can PonderNet be used for ImageNet?

    I plan to do a project on the complexity of tasks on image datasets like ImageNet and CIFAR-100. If I use a vision transformer, can I implement my project with this?

    opened by fryegg 2