Implementation of Perceiver: General Perception with Iterative Attention, in TensorFlow

Overview

This Python package implements Perceiver: General Perception with Iterative Attention by Andrew Jaegle et al. in TensorFlow. The model builds on top of Transformers such that the data only enters through the cross-attention mechanism (see figure), allowing it to scale to hundreds of thousands of inputs, like ConvNets. This, in part, also addresses the quadratic compute and memory bottleneck of Transformers.

Yannic Kilcher's video was very helpful.

Installation

Run the following to install:

pip install perceiver

Developing perceiver

To install perceiver, along with the tools you need to develop and test it, run the following in your virtualenv:

git clone https://github.com/Rishit-dagli/Perceiver.git
# or clone your own fork

cd perceiver
pip install -e .[dev]

A bit about Perceiver

The Perceiver model aims to deal with arbitrary configurations of different modalities using a single Transformer-based architecture. Transformers are flexible and make few assumptions about their inputs, but they also scale quadratically with the number of inputs in terms of both memory and computation. The Perceiver proposes a mechanism that makes it possible to deal with high-dimensional inputs while retaining the expressivity and flexibility to deal with arbitrary input configurations.

The idea is to introduce a small set of latent units that form an attention bottleneck through which the inputs must pass. This avoids the quadratic scaling problem of the all-to-all attention of a classical Transformer. The model can be seen as performing a fully end-to-end clustering of the inputs, with the latent units as the cluster centres, leveraging a highly asymmetric cross-attention layer. For spatial information, the authors compensate for the lack of explicit grid structure in the model by associating Fourier feature encodings with the inputs.
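
Conceptually, each block lets the small latent array attend to the much larger input array, and the latents are then refined with ordinary self-attention. Below is a minimal sketch of that asymmetric step using tf.keras.layers.MultiHeadAttention with purely illustrative sizes; it is not the package's internal code.

import tensorflow as tf

# Illustrative sizes: a few hundred latents attend to ~50k input tokens.
num_latents, latent_dim = 256, 512
num_inputs, input_dim = 224 * 224, 29                     # e.g. RGB + Fourier features per pixel

latents = tf.random.normal([1, num_latents, latent_dim])  # learned latent array
inputs = tf.random.normal([1, num_inputs, input_dim])     # flattened input array

# Asymmetric cross-attention: queries come from the latents, keys/values from the
# inputs, so the cost is O(num_latents * num_inputs) rather than O(num_inputs ** 2).
cross_attn = tf.keras.layers.MultiHeadAttention(num_heads=1, key_dim=64)
latents = latents + cross_attn(query=latents, value=inputs, key=inputs)

# The latents are then processed with ordinary (cheap, 256 x 256) self-attention.
self_attn = tf.keras.layers.MultiHeadAttention(num_heads=8, key_dim=64)
latents = latents + self_attn(query=latents, value=latents, key=latents)
print(latents.shape)  # (1, 256, 512)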

Usage

from perceiver import Perceiver
import tensorflow as tf

model = Perceiver(
    input_channels = 3,          # number of channels for each token of the input
    input_axis = 2,              # number of axes for the input data (2 for images, 3 for video)
    num_freq_bands = 6,          # number of frequency bands K in the Fourier position encoding (2 * K + 1 features per axis)
    max_freq = 10.,              # maximum frequency, hyperparameter depending on how fine the data is
    depth = 6,                   # depth of net
    num_latents = 256,           # number of latents
    latent_dim = 512,            # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 64,         # dimension per cross attention head
    latent_dim_head = 64,        # dimension per latent self attention head
    num_classes = 1000,          # output number of classes
    attn_dropout = 0.,           # dropout for the attention layers
    ff_dropout = 0.,             # dropout for the feedforward layers
)

img = tf.random.normal([1, 224, 224, 3]) # a batch with one random ImageNet-sized image
model(img) # output of shape (1, 1000)
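
The num_freq_bands and max_freq arguments control the Fourier feature position encoding. As a rough sketch of how such an encoding is commonly built (illustrative, not necessarily byte-for-byte what this package does), each normalized coordinate in [-1, 1] is expanded into sin/cos features at num_freq_bands frequencies plus the raw coordinate, i.e. 2 * K + 1 channels per axis:

import math
import tensorflow as tf

def fourier_encode(x, max_freq=10., num_bands=6):
    # x holds normalized positions in [-1, 1]; returns 2 * num_bands + 1 channels each.
    x = tf.expand_dims(x, -1)                                  # (..., 1)
    freqs = tf.linspace(1.0, max_freq / 2, num_bands)          # (num_bands,)
    scaled = x * freqs * math.pi                               # (..., num_bands)
    return tf.concat([tf.sin(scaled), tf.cos(scaled), x], -1)  # (..., 2 * num_bands + 1)

positions = tf.linspace(-1.0, 1.0, 224)    # normalized pixel coordinates along one axis
print(fourier_encode(positions).shape)     # (224, 13)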

About the notebooks

perceiver_example

This notebook installs the perceiver package and shows an example of running it on a single ImageNet-sized image ([1, 224, 224, 3]) with 1000 classes to demonstrate how the model works.

Want to Contribute 🙋‍♂️?

Awesome! If you want to contribute to this project, you're always welcome! See the Contributing Guidelines. You can also take a look at the open issues for more information about current or upcoming tasks.

Want to discuss? 💬

Have any questions, doubts, or want to share your views? You're always welcome. You can start a discussion.

Citations

@misc{jaegle2021perceiver,
    title   = {Perceiver: General Perception with Iterative Attention},
    author  = {Andrew Jaegle and Felix Gimeno and Andrew Brock and Andrew Zisserman and Oriol Vinyals and Joao Carreira},
    year    = {2021},
    eprint  = {2103.03206},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • error with tf2.4.1

    Hello Rishit,

    Thank you for your Perceiver implementation! I have two notes, though I am not very familiar with TF2. You define and call a tf.keras.Sequential model here https://github.com/Rishit-dagli/Perceiver/blob/4d3b9b0514da4fb623d178e3e70df1836ebad5ba/perceiver/perceiver.py#L106 For my version of TF at least this throws an error; I think it should be defined once in __init__ and then just called in call (see the sketch after this comment).

    And just above it, you compute data but then you don't pass it to self.model. Is that correct?

    bug 
    opened by abred 3
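
    A minimal sketch of the pattern suggested above, assuming a generic classifier head (the class name is a hypothetical stand-in for the layer in perceiver.py): build the tf.keras.Sequential once in __init__ and only call it in call.

    import tensorflow as tf

    class ClassifierHead(tf.keras.layers.Layer):  # hypothetical stand-in
        def __init__(self, num_classes):
            super().__init__()
            # Built once at construction time; creating the Sequential inside
            # call() would recreate its weights on every trace and can error.
            self.model = tf.keras.Sequential([
                tf.keras.layers.LayerNormalization(),
                tf.keras.layers.Dense(num_classes),
            ])

        def call(self, x):
            return self.model(x)  # only call the pre-built model here
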
  • Training code

    Hi there,

    I've tried to set up a standard MNIST training run over the last few days using the Perceiver code provided here. So far, I've not been able to come up with any solution where the model actually learns anything. A major problem so far has been that the way the model is written gives no support for model.fit() or the Keras functional API.

    Do you happen to have any training example code for your model which you could provide here in this repo? MNIST as the default starting point would be nice, but anything would do the job as well :)

    question 
    opened by tpetri94 2
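
    In lieu of model.fit() support, a rough sketch of a custom training loop on MNIST follows; the hyperparameters are illustrative and untuned, and it assumes the model accepts single-channel input and returns unnormalized logits.

    import tensorflow as tf
    from perceiver import Perceiver

    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    x_train = (x_train[..., None] / 255.0).astype("float32")   # (60000, 28, 28, 1)
    ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(1024).batch(32)

    model = Perceiver(input_channels=1, input_axis=2, num_freq_bands=6, max_freq=10.,
                      depth=2, num_latents=64, latent_dim=128, cross_heads=1,
                      latent_heads=4, cross_dim_head=32, latent_dim_head=32,
                      num_classes=10, attn_dropout=0., ff_dropout=0.)
    optimizer = tf.keras.optimizers.Adam(1e-4)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    for images, labels in ds.take(100):   # a few steps, just to show the loop
        with tf.GradientTape() as tape:
            logits = model(images)
            loss = loss_fn(labels, logits)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
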
  • Create a FeedForward layer

    Create a simple FeedForward layer as a tf.keras.layers.Layer. It should essentially contain a Dense layer with the modified GELU activation (#2), optionally a dropout layer, and another Dense layer whose number of neurons equals the input dimension.

    opened by Rishit-dagli 0
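
    A rough sketch of what such a layer could look like (names and defaults are illustrative, using the standard GELU rather than the modified one tracked in #2):

    import tensorflow as tf

    class FeedForward(tf.keras.layers.Layer):  # illustrative, not the repo's final API
        def __init__(self, dim, hidden_dim, dropout=0.0):
            super().__init__()
            self.dense1 = tf.keras.layers.Dense(hidden_dim, activation=tf.keras.activations.gelu)
            self.drop = tf.keras.layers.Dropout(dropout)
            self.dense2 = tf.keras.layers.Dense(dim)  # back to the input dimension

        def call(self, x, training=False):
            return self.dense2(self.drop(self.dense1(x), training=training))
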
  • Implement a PreNorm layer

    Create a normalization layer subclassing tf.keras.layers.Layer. This should essentially figure out the right axis and apply layer normalization on it.

    opened by Rishit-dagli 0
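
    A minimal sketch of one common way to write such a layer (illustrative; it wraps an arbitrary function with layer normalization over the last axis):

    import tensorflow as tf

    class PreNorm(tf.keras.layers.Layer):  # illustrative sketch
        def __init__(self, fn):
            super().__init__()
            self.norm = tf.keras.layers.LayerNormalization(axis=-1)
            self.fn = fn  # e.g. an attention or feed-forward layer

        def call(self, x, **kwargs):
            return self.fn(self.norm(x), **kwargs)
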
  • Don't pin TensorFlow version to a specific number

    Hello,

    In setup.py you should change "tensorflow~=2.4.0" to "tensorflow>2.4.0" to ensure any version above the minimal one is used.

    bug 
    opened by ebursztein 0
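
    For context, the relevant change in setup.py would look roughly like this (using >= so that 2.4.0 itself still qualifies; the rest of the file is omitted):

    # setup.py (excerpt)
    install_requires = [
        "tensorflow>=2.4.0",  # was "tensorflow~=2.4.0", which excludes TF 2.5 and later
    ]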