Implementation of a Transformer that Ponders, using the scheme from the PonderNet paper

Overview

Ponder(ing) Transformer

Implementation of a Transformer that learns to adapt the number of computational steps it takes depending on the difficulty of the input sequence, using the scheme from the PonderNet paper. Will also try to abstract out a pondering module that can be used with any block that returns an output with the halting probability.

This repository would not have been possible without repeated viewings of Yannic's educational video

Install

$ pip install ponder-transformer

Usage

import torch
from ponder_transformer import PonderTransformer

model = PonderTransformer(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 512
)

mask = torch.ones(1, 512).bool()

x = torch.randint(0, 20000, (1, 512))
y = torch.randint(0, 20000, (1, 512))

loss = model(x, labels = y, mask = mask)
loss.backward()

Now you can set the model to .eval() mode and it will terminate early when all samples of the batch have emitted a halting signal

import torch
from ponder_transformer import PonderTransformer

model = PonderTransformer(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 512,
    causal = True
)

x = torch.randint(0, 20000, (2, 512))
mask = torch.ones(2, 512).bool()

model.eval() # setting to eval makes it return the logits as well as the halting indices

logits, layer_indices = model(x,  mask = mask) # (2, 512, 20000), (2)

# layer indices will contain, for each batch element, which layer they exited

Citations

@misc{banino2021pondernet,
    title   = {PonderNet: Learning to Ponder}, 
    author  = {Andrea Banino and Jan Balaguer and Charles Blundell},
    year    = {2021},
    eprint  = {2107.05407},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
You might also like...
Implementation of the Transformer variant proposed in
Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"

FLASH - Pytorch Implementation of the Transformer variant proposed in the paper Transformer Quality in Linear Time Install $ pip install FLASH-pytorch

Third party Pytorch implement of Image Processing Transformer (Pre-Trained Image Processing Transformer arXiv:2012.00364v2)

ImageProcessingTransformer Third party Pytorch implement of Image Processing Transformer (Pre-Trained Image Processing Transformer arXiv:2012.00364v2)

Episodic Transformer (E.T.) is a novel attention-based architecture for vision-and-language navigation. E.T. is based on a multimodal transformer that encodes language inputs and the full episode history of visual observations and actions. CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped
CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped

CSWin-Transformer This repo is the official implementation of "CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows". Th

3D-Transformer: Molecular Representation with Transformer in 3D Space

3D-Transformer: Molecular Representation with Transformer in 3D Space

This repository builds a basic vision transformer from scratch so that one beginner can understand the theory of vision transformer.

vision-transformer-from-scratch This repository includes several kinds of vision transformers from scratch so that one beginner can understand the the

Transformer - Transformer in PyTorch

Transformer 完成进度 Embeddings and PositionalEncoding with example. MultiHeadAttent

Transformer Huffman coding - Complete Huffman coding through transformer

Transformer_Huffman_coding Complete Huffman coding through transformer 2022/2/19

Comments
  • Evaluating ponder-net on more pondering-steps than trained on.

    Evaluating ponder-net on more pondering-steps than trained on.

    As the paper says,

    In evaluation, and under known temporal or computational limitations, N can be set naively as a constant (or not set any limit, i.e. N → ∞). For training, we found that a more effective (and interpretable) way of parameterizing N is by defining a minimum cumulative probability of halting. N is then the smallest value of n such that sum( p_sub_ j > 1 − ε)over(j=1, n) , with the hyper-parameter ε positive near 0 (in our experiments 0.05).

    from that I infer that pondering can be done to more steps than trained on. How can be done so with this implementation?

    edit: I was going through the paper again,and I think what the paper means is that the max_num_pondering_steps:N should be re evaluated at every training-step, the model should be run till the condition is met or a pre-defined num of max steps is reached, and where the cumsum_probs condition will be met will be set as 'N', with the cumsum_probs normalised with one of the methods. Then that value of 'N' will be used to calc prior geom for the kl_div (and not normalising the prior geom term).

    i.e. if the num of pondering steps are initially set to 'M', then the model will recur for 'k' steps - i.e. till the condition is met or for 'M' num of max steps; then 'N' will be calculated by first calculating the probabilities - p_0 to p_k - then normalizing through one of the methods, then calculate cumulative-sum of those probabilities, and checking where the sum is greater than threshold, and assigning it the value 'N'. After that, calculating prior geometric values with the defined hyper-parameter, for 'N' seq-len, and using this in the kl-div term against the halting probs truncated to 'N' steps.

    λp is a hyper-parameter that defines a geometric prior distribution pG(λp) on the halting policy (truncated at N)

    opened by Vbansal21 0
  • Can pondernet used for imagenet?

    Can pondernet used for imagenet?

    I plan to do a project on the complexity of tasks on image dataset like imagenet, cifar 100. If I use a vision transformer, then can I implement my project?

    opened by fryegg 2
Releases(0.0.8)
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
Robbing the FED: Directly Obtaining Private Data in Federated Learning with Modified Models

Robbing the FED: Directly Obtaining Private Data in Federated Learning with Modified Models This repo contains a barebones implementation for the atta

16 Dec 04, 2022
CARL provides highly configurable contextual extensions to several well-known RL environments.

CARL (context adaptive RL) provides highly configurable contextual extensions to several well-known RL environments.

AutoML-Freiburg-Hannover 51 Dec 28, 2022
Invariant Causal Prediction for Block MDPs

MISA Abstract Generalization across environments is critical to the successful application of reinforcement learning algorithms to real-world challeng

Meta Research 41 Sep 17, 2022
Cross-platform CLI tool to generate your Github profile's stats and summary.

ghs Cross-platform CLI tool to generate your Github profile's stats and summary. Preview Hop on to examples for other usecases. Jump to: Installation

HackerRank 134 Dec 20, 2022
A curated list of awesome resources related to Semantic Search🔎 and Semantic Similarity tasks.

A curated list of awesome resources related to Semantic Search🔎 and Semantic Similarity tasks.

224 Jan 04, 2023
Boosted CVaR Classification (NeurIPS 2021)

Boosted CVaR Classification Runtian Zhai, Chen Dan, Arun Sai Suggala, Zico Kolter, Pradeep Ravikumar NeurIPS 2021 Table of Contents Quick Start Train

Runtian Zhai 4 Feb 15, 2022
Library for machine learning stacking generalization.

stacked_generalization Implemented machine learning *stacking technic[1]* as handy library in Python. Feature weighted linear stacking is also availab

114 Jul 19, 2022
PyTorch implementation of "Optimization Planning for 3D ConvNets"

Optimization-Planning-for-3D-ConvNets Code for the ICML 2021 paper: Optimization Planning for 3D ConvNets. Authors: Zhaofan Qiu, Ting Yao, Chong-Wah N

Zhaofan Qiu 2 Jan 12, 2022
S-attack library. Official implementation of two papers "Are socially-aware trajectory prediction models really socially-aware?" and "Vehicle trajectory prediction works, but not everywhere".

S-attack library: A library for evaluating trajectory prediction models This library contains two research projects to assess the trajectory predictio

VITA lab at EPFL 71 Jan 04, 2023
[arXiv] What-If Motion Prediction for Autonomous Driving ❓🚗💨

WIMP - What If Motion Predictor Reference PyTorch Implementation for What If Motion Prediction [PDF] [Dynamic Visualizations] Setup Requirements The W

William Qi 96 Dec 29, 2022
Predicts an answer in yes or no.

Oui-ou-non-prediction Predicts an answer in 'yes' or 'no'. It is based on the game 'effeuiller la marguerite' in which the person plucks flower petals

Ananya Gupta 1 Jan 15, 2022
Image segmentation with private İstanbul Dataset

Image Segmentation This repo was created for academic research and test result. Repo will update after academic article online. This repo contains wei

İrem KÖMÜRCÜ 9 Dec 11, 2022
Point cloud processing tool library.

Point Cloud ToolBox This point cloud processing tool library can be used to process point clouds, 3d meshes, and voxels. Environment python 3.7.5 Dep

ZhangXinyun 40 Dec 09, 2022
Trying to understand alias-free-gan.

alias-free-gan-explanation Trying to understand alias-free-gan in my own way. [Chinese Version 中文版本] CC-BY-4.0 License. Tzu-Heng Lin motivation of thi

Tzu-Heng Lin 12 Mar 17, 2022
Python implementation of 3D facial mesh exaggeration using the techniques described in the paper: Computational Caricaturization of Surfaces.

Python implementation of 3D facial mesh exaggeration using the techniques described in the paper: Computational Caricaturization of Surfaces.

Wonjong Jang 8 Nov 01, 2022
Network Enhancement implementation in pytorch

network_enahncement_pytorch Network Enhancement implementation in pytorch Research paper Network Enhancement: a general method to denoise weighted bio

Yen 1 Nov 12, 2021
A object detecting neural network powered by the yolo architecture and leveraging the PyTorch framework and associated libraries.

Yolo-Powered-Detector A object detecting neural network powered by the yolo architecture and leveraging the PyTorch framework and associated libraries

Luke Wilson 1 Dec 03, 2021
UnFlow: Unsupervised Learning of Optical Flow with a Bidirectional Census Loss

UnFlow: Unsupervised Learning of Optical Flow with a Bidirectional Census Loss This repository contains the TensorFlow implementation of the paper UnF

Simon Meister 270 Nov 06, 2022
Discovering Explanatory Sentences in Legal Case Decisions Using Pre-trained Language Models.

Statutory Interpretation Data Set This repository contains the data set created for the following research papers: Savelka, Jaromir, and Kevin D. Ashl

17 Dec 23, 2022
PyTorch implementation of Graph Convolutional Networks in Feature Space for Image Deblurring and Super-resolution, IJCNN 2021.

GCResNet PyTorch implementation of Graph Convolutional Networks in Feature Space for Image Deblurring and Super-resolution, IJCNN 2021. The code will

11 May 19, 2022