Implementation of a Transformer that Ponders, using the scheme from the PonderNet paper

Overview

Ponder(ing) Transformer

Implementation of a Transformer that learns to adapt the number of computational steps it takes depending on the difficulty of the input sequence, using the scheme from the PonderNet paper. The plan is also to abstract out a pondering module that can be used with any block that returns an output together with a halting probability.
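
As a rough illustration of the PonderNet scheme, the sketch below turns per-step conditional halting probabilities into an unconditional halting distribution, p_n = λ_n ∏_{j<n} (1 − λ_j). It is plain Python rather than code from this repository, and the names `halting_distribution` and `expected_steps` are hypothetical.

```python
def halting_distribution(lambdas):
    """Convert conditional halting probabilities into p_n.

    lambdas[n] is the probability of halting at step n given that the
    model has not halted at any earlier step.
    """
    probs = []
    not_halted = 1.0  # probability of reaching the current step
    for lam in lambdas:
        probs.append(lam * not_halted)
        not_halted *= 1.0 - lam
    return probs


def expected_steps(probs):
    """Expected number of pondering steps (steps counted from 1)."""
    return sum((n + 1) * p for n, p in enumerate(probs))


# Assumed example values; the final step is forced to halt (lambda = 1).
lambdas = [0.1, 0.5, 1.0]
probs = halting_distribution(lambdas)
avg_steps = expected_steps(probs)
```

Forcing the last λ to 1 makes the distribution sum to one, which is one simple way to handle a fixed maximum number of steps.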

This repository would not have been possible without repeated viewings of Yannic's educational video.

Install

$ pip install ponder-transformer

Usage

import torch
from ponder_transformer import PonderTransformer

model = PonderTransformer(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 512
)

mask = torch.ones(1, 512).bool()

x = torch.randint(0, 20000, (1, 512))
y = torch.randint(0, 20000, (1, 512))

loss = model(x, labels = y, mask = mask)
loss.backward()

Now you can set the model to .eval() mode, and it will terminate early once all samples in the batch have emitted a halting signal:

import torch
from ponder_transformer import PonderTransformer

model = PonderTransformer(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 512,
    causal = True
)

x = torch.randint(0, 20000, (2, 512))
mask = torch.ones(2, 512).bool()

model.eval() # setting to eval makes it return the logits as well as the halting indices

logits, layer_indices = model(x, mask = mask) # (2, 512, 20000), (2,)

# layer_indices contains, for each batch element, the layer at which it exited
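
The early-termination behaviour described above can be pictured with the following minimal sketch. It is not the library's actual code: it assumes each layer emits one halting probability per batch element, samples a halting decision from it, and stops as soon as every element has halted. The function name `early_exit` and its inputs are hypothetical.

```python
import random


def early_exit(halt_probs_per_layer, rng):
    """Simulate batch-wise early exit.

    halt_probs_per_layer[layer][b] is the (assumed) halting probability
    of batch element b at that layer. Returns the exit layer of each
    element and the number of layers that were actually run.
    """
    num_layers = len(halt_probs_per_layer)
    batch_size = len(halt_probs_per_layer[0])
    exit_layer = [None] * batch_size
    layers_run = 0
    for layer, lambdas in enumerate(halt_probs_per_layer):
        layers_run += 1
        is_last = layer == num_layers - 1
        for b, lam in enumerate(lambdas):
            # An element halts when its sample fires, or at the final layer.
            if exit_layer[b] is None and (is_last or rng.random() < lam):
                exit_layer[b] = layer
        if all(idx is not None for idx in exit_layer):
            break  # every sample has halted, so skip the remaining layers
    return exit_layer, layers_run


rng = random.Random(0)
# Both elements halt at the first layer, so only one layer is run.
indices_fast, run_fast = early_exit([[1.0, 1.0], [0.0, 0.0], [0.0, 0.0]], rng)
# The second element never samples a halt and exits at the last layer.
indices_slow, run_slow = early_exit([[1.0, 0.0], [0.0, 0.0], [0.0, 0.0]], rng)
```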

Citations

@misc{banino2021pondernet,
    title   = {PonderNet: Learning to Ponder}, 
    author  = {Andrea Banino and Jan Balaguer and Charles Blundell},
    year    = {2021},
    eprint  = {2107.05407},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

Comments
  • Evaluating ponder-net on more pondering-steps than trained on.

    As the paper says,

    In evaluation, and under known temporal or computational limitations, N can be set naively as a constant (or not set any limit, i.e. N → ∞). For training, we found that a more effective (and interpretable) way of parameterizing N is by defining a minimum cumulative probability of halting. N is then the smallest value of n such that $\sum_{j=1}^{n} p_j > 1 - \varepsilon$, with the hyper-parameter ε positive near 0 (in our experiments 0.05).

    From that I infer that pondering can be run for more steps than the model was trained on. How can this be done with this implementation?

    Edit: I went through the paper again, and I think what it means is that max_num_pondering_steps (N) should be re-evaluated at every training step. The model should be run until the condition is met or a predefined maximum number of steps is reached; the step at which cumsum_probs meets the condition is then taken as N, with cumsum_probs normalised by one of the methods. That value of N is then used to compute the geometric prior for the kl_div term (without normalising the prior).

    That is, if the number of pondering steps is initially set to M, the model recurs for k steps, i.e. until the condition is met or M steps have been run. N is then found by computing the probabilities p_0 to p_k, normalising them with one of the methods, taking their cumulative sum, and finding where that sum exceeds the threshold. After that, the geometric prior is computed with the chosen hyper-parameter for a sequence length of N and used in the KL divergence term against the halting probabilities truncated to N steps.

    λp is a hyper-parameter that defines a geometric prior distribution pG(λp) on the halting policy (truncated at N)

    opened by Vbansal21 0
  • Can PonderNet be used for ImageNet?

    I plan to do a project on the complexity of tasks on image datasets such as ImageNet and CIFAR-100. If I use a vision transformer, can I implement my project with this?

    opened by fryegg 2
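
The procedure discussed in the first comment above (choose N as the first step where the cumulative halting probability exceeds 1 − ε, then compare the halting probabilities truncated to N steps against a geometric prior) can be sketched as follows. This is plain Python under stated assumptions, not this repository's code: the helper names are hypothetical, the prior is left unnormalised as the comment suggests, and the normalisation of the halting probabilities that the comment mentions is left out.

```python
import math


def smallest_n(halting_probs, eps=0.05):
    """Smallest n (counted from 1) with p_1 + ... + p_n > 1 - eps.

    Falls back to the number of available steps if the threshold
    is never crossed.
    """
    total = 0.0
    for n, p in enumerate(halting_probs, start=1):
        total += p
        if total > 1.0 - eps:
            return n
    return len(halting_probs)


def truncated_geometric_prior(lambda_p, n):
    """Geometric prior lambda_p * (1 - lambda_p)^k for k = 0..n-1, unnormalised."""
    return [lambda_p * (1.0 - lambda_p) ** k for k in range(n)]


def kl_divergence(p, q):
    """Sum of p_i * log(p_i / q_i) over entries with p_i > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


# Assumed example: halting probabilities that are themselves geometric.
halting_probs = [0.5 ** k for k in range(1, 9)]
n = smallest_n(halting_probs, eps=0.05)
prior = truncated_geometric_prior(0.5, n)
kl = kl_divergence(halting_probs[:n], prior)
```

In this example the halting probabilities already match the prior, so the KL term vanishes; any other halting behaviour gives a positive penalty.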
Releases(0.0.8)
Owner
Phil Wang
Working with Attention. It's all we need