Implementation of Perceiver: General Perception with Iterative Attention in TensorFlow

Overview

This Python package implements Perceiver: General Perception with Iterative Attention by Andrew Jaegle et al. in TensorFlow. The model builds on Transformers, but the input data enters only through a cross-attention mechanism (see figure), which allows it to scale to hundreds of thousands of inputs, like ConvNets. This also partly solves the quadratic compute and memory bottleneck of Transformers.
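As a rough illustration of that bottleneck, the sketch below counts query-key pairs in a single attention layer, ignoring heads, channels and constant factors. The 256 latents match num_latents in the Usage example; treating each pixel of a 224 x 224 image as one input is an assumption made only for the arithmetic.

```python
# Rough count of attention-score entries (query-key pairs) for one
# attention layer. Heads, channels and constant factors are ignored.

def self_attention_pairs(n_inputs: int) -> int:
    # Classical transformer: every input attends to every input.
    return n_inputs * n_inputs


def cross_attention_pairs(n_latents: int, n_inputs: int) -> int:
    # Perceiver-style cross-attention: each latent attends to every input.
    return n_latents * n_inputs


n = 224 * 224  # one 224x224 image flattened to pixels (assumed 1 pixel = 1 input)
m = 256        # number of latents, as num_latents in the Usage example

full = self_attention_pairs(n)
cross = cross_attention_pairs(m, n)
print(f"self-attention pairs:  {full:,}")
print(f"cross-attention pairs: {cross:,}")
print(f"reduction factor: {full // cross}x")
```

For one 224 x 224 image this gives 2,517,630,976 pairs for full self-attention versus 12,845,056 for cross-attention into 256 latents, a factor of 196, which is simply N / M.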

Yannic Kilcher's video was very helpful.

Installation

Run the following to install:

pip install perceiver

Developing perceiver

To install perceiver, along with tools you need to develop and test, run the following in your virtualenv:

git clone https://github.com/Rishit-dagli/Perceiver.git
# or clone your own fork

cd Perceiver
pip install -e ".[dev]"

A bit about Perceiver

The Perceiver model aims to handle arbitrary configurations of different modalities using a single transformer-based architecture. Transformers are flexible and make few assumptions about their inputs, but they also scale quadratically with the number of inputs in both memory and computation. This model proposes a mechanism that makes it possible to handle high-dimensional inputs while retaining the expressivity and flexibility to deal with arbitrary input configurations.

The idea is to introduce a small set of latent units that forms an attention bottleneck through which the inputs must pass. This avoids the quadratic scaling of the all-to-all attention in a classical transformer. The model can be seen as performing a fully end-to-end clustering of the inputs, with the latent units as the cluster centres, by means of a highly asymmetric cross-attention layer. To retain spatial information, the authors compensate for the lack of explicit grid structure in their model by associating Fourier feature position encodings with the inputs.
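A minimal NumPy sketch of the asymmetric cross-attention step, assuming a single head and no normalization or feedforward layers; the sizes and random weights are hypothetical and this is not the package's own code. The M latents act as queries and the N inputs as keys and values, so the score matrix is M x N rather than N x N. Each latent's softmax row is a soft assignment over the inputs, which is the sense in which the latents resemble cluster centres.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attend(latents, inputs, w_q, w_k, w_v):
    """Single-head cross-attention: latents (M, D) query inputs (N, C).

    Returns an (M, d_head) output, whose size depends on the number of
    latents M and not on the number of inputs N, plus the (M, N) weights.
    """
    q = latents @ w_q                         # (M, d_head)
    k = inputs @ w_k                          # (N, d_head)
    v = inputs @ w_v                          # (N, d_head)
    scores = q @ k.T / np.sqrt(q.shape[-1])   # (M, N): M*N entries, not N*N
    weights = softmax(scores, axis=-1)        # each latent's row sums to 1
    return weights @ v, weights


rng = np.random.default_rng(0)
M, D = 8, 16      # hypothetical small latent array
N, C = 1000, 3    # e.g. 1000 pixels with 3 channels each
d_head = 4

latents = rng.normal(size=(M, D))   # learned in the real model; random here
inputs = rng.normal(size=(N, C))
w_q = rng.normal(size=(D, d_head))
w_k = rng.normal(size=(C, d_head))
w_v = rng.normal(size=(C, d_head))

out, weights = cross_attend(latents, inputs, w_q, w_k, w_v)
print(out.shape)      # (8, 4)
print(weights.shape)  # (8, 1000)
```

Growing N only widens the (M, N) weight matrix; the latent output stays (M, d_head), so the subsequent latent self-attention layers cost O(M^2) regardless of input size.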

Usage

from perceiver import Perceiver
import tensorflow as tf

model = Perceiver(
    input_channels = 3,          # number of channels for each token of the input
    input_axis = 2,              # number of axes of the input data (2 for images, 3 for video)
    num_freq_bands = 6,          # number of frequency bands, with original value (2 * K + 1)
    max_freq = 10.,              # maximum frequency, hyperparameter depending on how fine the data is
    depth = 6,                   # depth of the network
    num_latents = 256,           # number of latents
    latent_dim = 512,            # latent dimension
    cross_heads = 1,             # number of heads for cross-attention (the paper uses 1)
    latent_heads = 8,            # number of heads for latent self-attention, 8
    cross_dim_head = 64,         # dimension of each cross-attention head
    latent_dim_head = 64,        # dimension of each latent self-attention head
    num_classes = 1000,          # number of output classes
    attn_dropout = 0.,           # dropout rate in the attention layers
    ff_dropout = 0.,             # dropout rate in the feedforward layers
)

img = tf.random.normal([1, 224, 224, 3]) # one random ImageNet-sized image (batch, height, width, channels)
model(img) # output shape: (1, 1000)
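The num_freq_bands and max_freq arguments control the Fourier position features. The sketch below builds such features for a 2D grid in NumPy, assuming the construction used in common open-source Perceiver ports: coordinates scaled to [-1, 1], sin and cos at num_freq_bands frequencies spaced linearly from 1 to max_freq / 2, plus the raw coordinate. This package's exact formula may differ, and the helper names fourier_encode and position_encoding are hypothetical.

```python
import numpy as np


def fourier_encode(x, max_freq, num_bands):
    """Encode coordinates x in [-1, 1] as Fourier features (assumed scheme).

    Each coordinate becomes 2 * num_bands + 1 features: sin and cos at
    num_bands frequencies, plus the raw coordinate itself.
    """
    x = x[..., None]                                    # (..., 1)
    scales = np.linspace(1.0, max_freq / 2, num_bands)  # (num_bands,)
    angles = x * scales * np.pi                         # (..., num_bands)
    return np.concatenate([np.sin(angles), np.cos(angles), x], axis=-1)


def position_encoding(shape, max_freq, num_bands):
    """Fourier position features for every cell of a grid of `shape`."""
    axes = [np.linspace(-1.0, 1.0, n) for n in shape]
    grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)  # (*shape, n_axes)
    enc = fourier_encode(grid, max_freq, num_bands)              # (*shape, n_axes, 2*num_bands+1)
    return enc.reshape(*shape, -1)                               # concatenate per-axis features


img = np.random.default_rng(0).normal(size=(224, 224, 3))
pos = position_encoding(img.shape[:2], max_freq=10.0, num_bands=6)
tokens = np.concatenate([img, pos], axis=-1)      # append position features to pixels
tokens = tokens.reshape(-1, tokens.shape[-1])     # flatten the grid into a token list
print(pos.shape)     # (224, 224, 26)
print(tokens.shape)  # (50176, 29)
```

Under these assumptions, input_axis = 2 and num_freq_bands = 6 give 2 * (2 * 6 + 1) = 26 position features per pixel, so each of the 50,176 tokens carries 3 + 26 = 29 channels.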

About the notebooks

perceiver_example

This notebook installs the perceiver package and shows an example of running it on a single ImageNet image ([1, 224, 224, 3]) with 1000 classes to demonstrate how the model works.

Want to Contribute 🙋‍♂️ ?

Awesome! If you want to contribute to this project, you're always welcome! See Contributing Guidelines. You can also take a look at open issues to get more information about current or upcoming tasks.

Want to discuss? 💬

Have any questions or doubts, or want to share your opinions and views? You're always welcome to start a discussion.

Citations

@misc{jaegle2021perceiver,
    title   = {Perceiver: General Perception with Iterative Attention},
    author  = {Andrew Jaegle and Felix Gimeno and Andrew Brock and Andrew Zisserman and Oriol Vinyals and Joao Carreira},
    year    = {2021},
    eprint  = {2103.03206},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • error with tf2.4.1

    Hello Rishit,

    Thank you for your Perceiver implementation! I have two notes, though I am not very familiar with TF2. You define and call a tf.keras.Sequential model here (https://github.com/Rishit-dagli/Perceiver/blob/4d3b9b0514da4fb623d178e3e70df1836ebad5ba/perceiver/perceiver.py#L106). For my version of TF, at least, this throws an error; I think it should be defined once in __init__ and then just called in call.

    And just above it, you compute data but then you don't pass it to self.model. Is that correct?

    bug 
    opened by abred 3
  • Training code

    Hi there,

    I've spent the last few days trying to set up standard MNIST training using the Perceiver code provided here. So far, I haven't found any setup in which the model actually learns anything. A major problem has been that the model is written without support for model.fit() and the functional API.

    Do you happen to have any training example code for your model which you could provide here in this repo? MNIST as the default starting point would be nice, but anything would do the job as well :)

    question 
    opened by tpetri94 2
  • Create a FeedForward layer

    Create a simple FeedForward layer as a tf.keras.layers.Layer, which should essentially contain a Dense layer with the modified GELU activation (#2). Optionally, I could also include a Dropout layer and another Dense layer whose number of units equals the input dimension.

    opened by Rishit-dagli 0
  • Implement a PreNorm layer

    Create a normalization layer by subclassing tf.keras.layers.Layer. It should essentially figure out the right axis and apply layer normalization along it.

    opened by Rishit-dagli 0
  • Don't pin TensorFlow version to a specific number

    Hello,

    In setup.py you should change "tensorflow~=2.4.0" to "tensorflow>=2.4.0" to ensure that any version from the minimal one upward can be used.

    bug 
    opened by ebursztein 0
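The tf2.4.1 issue suggests building sublayers once in __init__ and only calling them in call, and the FeedForward and PreNorm issues describe two small layers. A minimal sketch combining those suggestions, assuming plain tf.nn.gelu in place of the modified GELU referenced as #2 (which isn't described here) and a hypothetical hidden-width multiplier mult=4; it is not the package's actual implementation.

```python
import tensorflow as tf


class PreNorm(tf.keras.layers.Layer):
    """Layer-normalize over the last (feature) axis, then apply fn."""

    def __init__(self, fn, **kwargs):
        super().__init__(**kwargs)
        # Sublayers are created once here, not inside call().
        self.norm = tf.keras.layers.LayerNormalization(axis=-1)
        self.fn = fn

    def call(self, x, training=False):
        return self.fn(self.norm(x), training=training)


class FeedForward(tf.keras.layers.Layer):
    """Dense (GELU) -> Dropout -> Dense back to the input dimension."""

    def __init__(self, dim, mult=4, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        # The Sequential model is built once and reused on every call.
        self.net = tf.keras.Sequential([
            tf.keras.layers.Dense(dim * mult, activation=tf.nn.gelu),  # standard GELU (assumption)
            tf.keras.layers.Dropout(dropout),
            tf.keras.layers.Dense(dim),  # output width equals the input dimension
        ])

    def call(self, x, training=False):
        return self.net(x, training=training)


block = PreNorm(FeedForward(dim=512, dropout=0.1))
x = tf.random.normal([2, 256, 512])  # (batch, num_latents, latent_dim)
y = block(x, training=True)
print(y.shape)  # (2, 256, 512)
```

Because every sublayer is created in __init__, Keras tracks its weights once and later calls reuse them, instead of building a fresh Sequential model (with fresh weights) on every forward pass.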
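To see how the version pins discussed in the last issue differ, the packaging library's SpecifierSet can test candidate versions against each specifier. This is an illustrative check only and is not part of the perceiver package.

```python
from packaging.specifiers import SpecifierSet

compatible = SpecifierSet("~=2.4.0")  # equivalent to >=2.4.0, ==2.4.*
minimum = SpecifierSet(">=2.4.0")     # 2.4.0 and every later release
strict = SpecifierSet(">2.4.0")       # later releases only, excludes 2.4.0

for version in ["2.4.0", "2.4.1", "2.5.0"]:
    # Columns: version, allowed by ~=, allowed by >=, allowed by >
    print(version, version in compatible, version in minimum, version in strict)
```

The compatible-release pin rejects 2.5.0, ">=2.4.0" accepts 2.4.0 and everything after it, and ">2.4.0" rejects 2.4.0 itself.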
Releases (v0.1.2)
Owner
Rishit Dagli
High School, TEDx, 2x TED-Ed speaker | International Speaker | Microsoft Student Ambassador | Mentor, @TFUGMumbai | Organize @KotlinMumbai