Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch

Overview

Perceiver - Pytorch

Install

$ pip install perceiver-pytorch

Usage

import torch
from perceiver_pytorch import Perceiver

model = Perceiver(
    num_fourier_features = 6,    # number of fourier features, with original value (2 * K + 1)
    depth = 48,                  # depth of net; in the paper they went deep, making up for the lack of attention
    num_latents = 6,             # number of latents, or induced set points, or centroids. different papers give it different names
    cross_dim = 512,             # cross attention dimension
    latent_dim = 512,            # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 64,         # dimension per cross attention head
    latent_dim_head = 64,        # dimension per latent self attention head
    num_classes = 1000,          # output number of classes
    attn_dropout = 0.,           # dropout on the attention weights
    ff_dropout = 0.,             # dropout in the feedforward layers
    weight_tie_layers = False    # whether to weight tie layers (optional, as indicated in the diagram)
)

img = torch.randn(1, 224 * 224) # 1 imagenet image, pixelized

model(img) # (1, 1000)

Citations

@misc{jaegle2021perceiver,
    title   = {Perceiver: General Perception with Iterative Attention},
    author  = {Andrew Jaegle and Felix Gimeno and Andrew Brock and Andrew Zisserman and Oriol Vinyals and Joao Carreira},
    year    = {2021},
    eprint  = {2103.03206},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • Latent averaging to the logits?

    I read through the paper last night and came away confused about a few things. I looked through your code hoping for some clarity.

    One issue that doesn't seem to be explained in the paper (or I am missing it) is how the authors go from a set of latents to the logits used at the classification head. You implemented this by taking the mean of the latent set:

    https://github.com/lucidrains/perceiver-pytorch/blob/main/perceiver_pytorch/perceiver_pytorch.py#L203

    Is this actually how the authors convert to logits?

    opened by neonbjb 7
  • PerceiverAR?

    Hey @lucidrains - love this repo, and I'm still trying to wrap my head around the various differences between the Perceiver architectures. How hard would it be to extend PerceiverIO to PerceiverAR? What fundamentally needs to change?

    opened by siddk 5
  • Not using the classification head in Perceiver

    Hi @lucidrains, thank you for your great work!

    I'd like to use the Perceiver (not PerceiverIO) without the classification head (average and projection). Do you think we could add an option to avoid using it? I can do a PR if you want.

    Thanks!

    opened by gegallego 4
  • Decoder Attention Module needs a FF network as well in perceiver_io.py script

    Hi,

    According to the architectural details in the Perceiver IO paper (https://arxiv.org/abs/2107.14795), the decoder attention block contains a cross-attention block (equation 4), which is already implemented in the perceiver_io.py script (Line 151), followed by a feedforward network, given by equation (6) in the paper, which is not present in that script. I am not aware of the repercussions of omitting the FF in the decoder module, but it might be a good idea to include it in the implementation. Something like self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) would do the job. Experimentally, the authors found that omitting equation (5) is helpful.

    opened by Hritikbansal 4
  • Positional encoding are already part of the input

    Hello! First of all, thank you for this implementation.

    My inputs already have the proper positional encoding as part of the channel axis. Would it be possible to add a feature to deactivate the default implementation of the positional encoding?

    Thank you!

    opened by Atlis 4
  • x = self.latents + self.pos_emb

    self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
    self.pos_emb = nn.Parameter(torch.randn(num_latents, latent_dim))
    ...
    x = self.latents + self.pos_emb
    

    I'm not very familiar with PyTorch, but does this make sense? What is intended when two trainable weight matrices are simply summed, and that is the only place where both latents and pos_emb appear? It looks like they could be replaced with a single matrix.

    opened by galchinsky 4
  • Fourier encoding is not similar to the paper

    First of all, thanks for sharing the code !

    I have a follow-up question to #4.

    In the paper, the authors describe the encoding [sin(f_k π x_d), cos(f_k π x_d)], where f_k is the kth band of a bank of frequencies spaced log-linearly between 1 and µ/2. Could you point out how you arrived at the 1/2**i scaling in the code?

    https://github.com/lucidrains/perceiver-pytorch/blob/6ae733773d29cb29383f3ac7b45af8cb6bd2c0dc/perceiver_pytorch/perceiver_pytorch.py#L28-L35

    Thanks!

    opened by cheneeheng 4
  • Fourier encoding should be for position coordinates instead of byte array

    The fourier_encode function, as implemented, takes a byte array x as input and encodes it directly with sin/cos before concatenating the result with the input.

    As I understand the NeRF position encodings, they encode the x/y/etc. position coordinates, and not a transformation of the data itself. From the Perceiver paper:

    We parametrize the frequency encoding to take the values [sin(f_k π x_d), cos(f_k π x_d)], where the frequencies f_k is the kth band of a bank of frequencies spaced log-linearly between 1 and µ/2... For example, by allowing the network to resolve the maximum frequency present in an input array, we can encourage it to learn to compare the values of bytes at any positions in the input array. x_d is the value of the input position along the dth dimension (e.g. for images d = 2 and for video d = 3). x_d takes values in [−1, 1] for each dimension. We concatenate the raw positional value x_d to produce the final representation of position. This results in a positional encoding of size d(2K + 1).

    NeRF position encoding examples:

    • https://github.com/bmild/nerf/blob/20a91e764a28816ee2234fcadb73bd59a613a44c/run_nerf_helpers.py#L22
    • https://github.com/ankurhanda/nerf2D

    opened by eridgd 4
  • Positional encoding frequency bands should be linearly spaced

    A small bug, but as alluded to in this comment by @marcdumon, it seems as though the frequency bands are indeed spaced linearly in the official JAX implementation.

    opened by djl11 2
  • Bug in fourier_encode (?)

    Thank you for this great implementation. I'm learning a lot from it!

    I think I found a problem in the fourier_encode method. In this line: https://github.com/lucidrains/perceiver-pytorch/blob/b33aced4e1b266aeb1383e03ab63f0a9951f9126/perceiver_pytorch/perceiver_pytorch.py#L36

    the scales are always the same, regardless of the value of the base parameter. Example:

    max_freq = 10, num_bands=6, base = 2
    => scales = [1.0000, 1.3797, 1.9037, 2.6265, 3.6239, 5.0000]
    
    max_freq = 10, num_bands=6, base = 10
    => scales = [1.0000, 1.3797, 1.9037, 2.6265, 3.6239, 5.0000]
    
    opened by marcdumon 2
  • Attention softmax is applied to incorrect dimension?

    I am studying multi-head attention. While reading through [1], I found that the attention softmax is applied over the last dimension of the similarity tensor sim:

            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
    
            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
    
            if exists(mask):
                <removed>
    
            # attention, what we cannot get enough of
            attn = sim.softmax(dim = -1)
    

    If I understand correctly, sim has shape ((b*h), n1, n2), and the softmax is computed over the last dimension, n2. Shouldn't the softmax be applied to matrices holding all the similarity values of a single head (i.e. with shape (n1, n2))?

    [1] https://github.com/lucidrains/perceiver-pytorch/blob/main/perceiver_pytorch/perceiver_io.py#L97

    opened by breuderink 2
  • Issue defining base in fourier_encode for experimental.py, gated.py, mixed_latents.py

    Hey Lucid, love the work. It appears you deprecated base in fourier_encode in https://github.com/lucidrains/perceiver-pytorch/commit/144b0d9716a7212b5fd6d95a2267c4d4a08b56a7

    However, experimental.py, gated.py and mixed_latents.py still try to define base within their forward passes:

    • https://github.com/lucidrains/perceiver-pytorch/blob/abbb5d5949d3509c57749bd134f5068f2761aac7/perceiver_pytorch/experimental.py#L122
    • https://github.com/lucidrains/perceiver-pytorch/blob/2d59df42ebb0b7538af77d584f5ae5b50759618b/perceiver_pytorch/mixed_latents.py#L85
    • https://github.com/lucidrains/perceiver-pytorch/blob/2d59df42ebb0b7538af77d584f5ae5b50759618b/perceiver_pytorch/gated.py#L103

    Thanks again, keep up the great work.

    opened by TannerLaBorde 0
  • Audio + Text data?

    Can someone please guide me on how to process both audio and text (.txt) data through the Perceiver simultaneously for multimodal learning?

    Example code would be nice.

    Thanks

    opened by Sidz1812 1
  • just a suggestion

    Hi, I'd like to start by thanking you for such great work and so many great implementations. I have a small suggestion: in all your modules, try adding an if __name__ == "__main__": block, so that someone who only wants to use one file/module can easily try it without going through the whole implementation. For example, I am trying to use this one; with an if __name__ == "__main__": block, I could easily run a random input through it and see how it works. This would greatly increase usability.

    Keep up the great work :)

    opened by seyeeet 4
  • What should I change if I want to use data with input size 720*184

    Thanks for sharing this code. I was wondering what I should change to use data that can be converted into images with an input size of 720*184. Thanks in advance!

    opened by Oussamab21 0
  • Question regarding queries dimensionality in Perceiver IO

    Hi @lucidrains,

    I think I may be missing something: why do we define the Perceiver IO queries tensor with a batch dimension (i.e. queries = torch.randn(1, 128, 32))? Was this just to make the code work nicely? Shouldn't we be using queries = torch.randn(128, 32)? I expect to use the same embedding for all of my batch elements, which, IIUC, is what your code is doing.

    opened by pcicales 3
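On the comment "Latent averaging to the logits?": the linked line averages the latent set before classifying. Below is a minimal PyTorch sketch of that pattern, assuming latents of shape (batch, num_latents, latent_dim). The class name MeanPoolClassifier, the LayerNorm before the projection and the return_embeddings flag are illustrative assumptions, not this repository's API. The flag also shows one way to skip the head, as requested in "Not using the classification head in Perceiver".

```python
import torch
from torch import nn


class MeanPoolClassifier(nn.Module):
    """Hypothetical head: average the latent set, normalize, then project to class logits."""

    def __init__(self, latent_dim: int, num_classes: int):
        super().__init__()
        self.norm = nn.LayerNorm(latent_dim)  # assumption: normalize before projecting
        self.proj = nn.Linear(latent_dim, num_classes)

    def forward(self, latents: torch.Tensor, return_embeddings: bool = False) -> torch.Tensor:
        # latents: (batch, num_latents, latent_dim)
        pooled = latents.mean(dim=1)  # average over the latent axis -> (batch, latent_dim)
        if return_embeddings:
            return pooled  # hypothetical flag: skip the classification head entirely
        return self.proj(self.norm(pooled))  # (batch, num_classes)


head = MeanPoolClassifier(latent_dim=512, num_classes=1000)
latents = torch.randn(2, 256, 512)
logits = head(latents)                                 # (2, 1000)
embeddings = head(latents, return_embeddings=True)     # (2, 512)
```

Because the mean is taken over the latent axis, the output does not depend on the order of the latents, which fits their role as an unordered set.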
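For "Decoder Attention Module needs a FF network as well in perceiver_io.py script": a sketch of a decoder block with the feedforward step of equation (6) added after the cross-attention. It uses torch.nn.MultiheadAttention rather than this repository's Attention and PreNorm classes, and the names DecoderBlock and query_residual are hypothetical. Following the comment, the query residual is off by default; this assumes that "equation (5)" is the residual connection around the cross-attention, which should be checked against the paper.

```python
import torch
from torch import nn


class DecoderBlock(nn.Module):
    """Hypothetical Perceiver IO-style decoder: cross-attention (eq. 4), then feedforward (eq. 6)."""

    def __init__(self, queries_dim: int, latent_dim: int, heads: int = 1,
                 ff_mult: int = 4, query_residual: bool = False):
        super().__init__()
        self.query_residual = query_residual
        self.norm_queries = nn.LayerNorm(queries_dim)
        self.norm_latents = nn.LayerNorm(latent_dim)
        # Queries attend to the latents; kdim/vdim let the latent width differ from queries_dim.
        self.cross_attn = nn.MultiheadAttention(
            queries_dim, heads, kdim=latent_dim, vdim=latent_dim, batch_first=True
        )
        # Pre-norm feedforward: the step the comment says is missing.
        self.ff = nn.Sequential(
            nn.LayerNorm(queries_dim),
            nn.Linear(queries_dim, queries_dim * ff_mult),
            nn.GELU(),
            nn.Linear(queries_dim * ff_mult, queries_dim),
        )

    def forward(self, queries: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
        # queries: (batch, num_queries, queries_dim); latents: (batch, num_latents, latent_dim)
        context = self.norm_latents(latents)
        x, _ = self.cross_attn(self.norm_queries(queries), context, context)
        if self.query_residual:
            x = x + queries  # assumed to be eq. (5); skipped by default
        return x + self.ff(x)  # feedforward with its own residual, eq. (6)


block = DecoderBlock(queries_dim=32, latent_dim=512)
out = block(torch.randn(2, 128, 32), torch.randn(2, 256, 512))  # (2, 128, 32)
```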
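For the "x = self.latents + self.pos_emb" comment, the intuition can be checked directly. Both parameters only ever appear as their sum, so they receive identical gradients, and a single matrix initialized to the sum computes the same function. The small loss below is arbitrary and only serves to produce gradients; the sizes are illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)
num_latents, latent_dim = 4, 8

latents = nn.Parameter(torch.randn(num_latents, latent_dim))
pos_emb = nn.Parameter(torch.randn(num_latents, latent_dim))

# Downstream layers only ever see the sum of the two parameters.
x = latents + pos_emb
loss = (x.tanh() ** 2).sum()  # arbitrary loss, just to get gradients
loss.backward()

# d(a + b)/da and d(a + b)/db are both the identity, so the gradients match exactly.
print(torch.equal(latents.grad, pos_emb.grad))  # True

# A single parameter initialized to the sum represents the same function.
merged = nn.Parameter((latents + pos_emb).detach())
```

Under plain SGD without weight decay, the pair therefore behaves like one matrix trained with twice the learning rate.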
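On the three comments about the Fourier frequency bands ("Fourier encoding is not similar to the paper", "Positional encoding frequency bands should be linearly spaced" and "Bug in fourier_encode (?)"): the scales in the bug report are exactly 5 ** (i / 5) for i = 0..5, i.e. log-spaced from 1 to max_freq / 2. Any such log-spaced bank is independent of base, because base ** (t * log_base(m)) = m ** t. The standard-library sketch below shows this and, for comparison, the linearly spaced alternative; the function names are hypothetical.

```python
import math


def log_spaced_bands(max_freq: float, num_bands: int, base: float) -> list[float]:
    # Exponents run evenly from 0 to log_base(max_freq / 2), so the bands go from 1 to max_freq / 2.
    end_exp = math.log(max_freq / 2) / math.log(base)
    return [base ** (end_exp * i / (num_bands - 1)) for i in range(num_bands)]


def linear_bands(max_freq: float, num_bands: int) -> list[float]:
    # Equally spaced between 1 and max_freq / 2.
    step = (max_freq / 2 - 1) / (num_bands - 1)
    return [1.0 + step * i for i in range(num_bands)]


print([round(f, 4) for f in log_spaced_bands(10, 6, base=2)])   # [1.0, 1.3797, 1.9037, 2.6265, 3.6239, 5.0]
print([round(f, 4) for f in log_spaced_bands(10, 6, base=10)])  # [1.0, 1.3797, 1.9037, 2.6265, 3.6239, 5.0]
print([round(f, 4) for f in linear_bands(10, 6)])               # [1.0, 1.8, 2.6, 3.4, 4.2, 5.0]
```

So base only changes the intermediate exponents, never the resulting frequencies; choosing between log and linear spacing is what actually changes the bank.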
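For "Fourier encoding should be for position coordinates instead of byte array": a sketch of the encoding described in the quoted paper passage, applied to grid coordinates x_d in [-1, 1] rather than to the data values. It yields d(2K + 1) channels per position, which can then be concatenated with the input's own channels. The function name and the use of equally spaced frequencies between 1 and max_freq / 2 are assumptions; substitute a log-spaced bank to follow the paper's wording.

```python
import math

import torch


def fourier_position_encoding(shape, max_freq: float, num_bands: int) -> torch.Tensor:
    """Encode grid coordinates as [sin(f_k pi x_d), cos(f_k pi x_d), x_d].

    Returns a tensor of shape (*shape, d * (2 * num_bands + 1)), where d = len(shape).
    """
    # x_d in [-1, 1] along each axis, as in the paper excerpt
    axes = [torch.linspace(-1.0, 1.0, steps=n) for n in shape]
    pos = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)  # (*shape, d)

    # Assumption: K frequencies equally spaced from 1 to max_freq / 2
    freqs = torch.linspace(1.0, max_freq / 2, num_bands)  # (K,)
    scaled = pos.unsqueeze(-1) * freqs * math.pi  # (*shape, d, K)

    # sin and cos of every band, plus the raw coordinate: 2K + 1 values per dimension
    enc = torch.cat([scaled.sin(), scaled.cos(), pos.unsqueeze(-1)], dim=-1)  # (*shape, d, 2K + 1)
    return enc.flatten(start_dim=-2)  # (*shape, d * (2K + 1))


img = torch.randn(224, 224, 3)  # one RGB image, channels last
enc = fourier_position_encoding(img.shape[:2], max_freq=10.0, num_bands=6)
x = torch.cat([img, enc], dim=-1)  # data channels plus positional channels
```

For a 224 × 224 image with K = 6, the encoding adds 2 × 13 = 26 channels, so x has 29 channels per pixel, and the pixel values themselves are never passed through sin/cos.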
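For "Attention softmax is applied to incorrect dimension?": in standard scaled dot-product attention the softmax normalizes each query's row over the keys, not the whole n1 × n2 matrix. Because the rearrange folds the heads into the batch axis, each row of sim already belongs to a single head. The sketch below reproduces the quoted shapes with plain reshapes (instead of einops) and checks the result against an explicit loop over heads; all sizes are illustrative.

```python
import torch

torch.manual_seed(0)
b, h, n_q, n_k, d = 2, 4, 3, 5, 8

q = torch.randn(b, n_q, h * d)
k = torch.randn(b, n_k, h * d)


def split_heads(t: torch.Tensor) -> torch.Tensor:
    # Same as rearrange(t, 'b n (h d) -> (b h) n d', h = h)
    bsz, n, _ = t.shape
    return t.reshape(bsz, n, h, d).permute(0, 2, 1, 3).reshape(bsz * h, n, d)


sim = torch.einsum("bid,bjd->bij", split_heads(q), split_heads(k)) * d ** -0.5  # ((b h), n_q, n_k)
attn = sim.softmax(dim=-1)

# Each row (one query within one head) is a distribution over the keys.
print(attn.sum(dim=-1))  # every entry is 1, up to floating-point rounding

# Identical to computing the attention of each head separately.
q4, k4 = q.reshape(b, n_q, h, d), k.reshape(b, n_k, h, d)
per_head = torch.stack(
    [(q4[:, :, i] @ k4[:, :, i].transpose(1, 2) * d ** -0.5).softmax(dim=-1) for i in range(h)],
    dim=1,
)  # (b, h, n_q, n_k)
print(torch.allclose(attn.reshape(b, h, n_q, n_k), per_head, atol=1e-6))  # True
```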
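For "Audio + Text data?": one common approach (a sketch, not something this repository provides) is to project each modality to a shared width, tag each with a learned modality embedding, and concatenate them along the sequence axis so a single latent array can cross-attend to both. All class names, sizes and the 80-channel audio features below are hypothetical, and per-modality positional encodings are omitted for brevity.

```python
import torch
from torch import nn


class MultimodalInput(nn.Module):
    """Hypothetical preprocessing: map audio frames and text tokens into one input array."""

    def __init__(self, audio_channels: int, vocab_size: int, dim: int):
        super().__init__()
        self.audio_proj = nn.Linear(audio_channels, dim)  # e.g. per-frame spectrogram features
        self.text_embed = nn.Embedding(vocab_size, dim)   # token ids -> vectors
        self.modality_emb = nn.Parameter(torch.randn(2, dim) * 0.02)  # row 0 = audio, row 1 = text

    def forward(self, audio: torch.Tensor, text_ids: torch.Tensor) -> torch.Tensor:
        # audio: (batch, frames, audio_channels); text_ids: (batch, tokens)
        a = self.audio_proj(audio) + self.modality_emb[0]
        t = self.text_embed(text_ids) + self.modality_emb[1]
        return torch.cat([a, t], dim=1)  # (batch, frames + tokens, dim)


class LatentCrossAttention(nn.Module):
    """One Perceiver-style read: a learned latent array attends to the joint input."""

    def __init__(self, dim: int, num_latents: int, heads: int = 1):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q = self.latents.unsqueeze(0).expand(x.shape[0], -1, -1)  # share latents across the batch
        out, _ = self.attn(q, x, x)
        return out  # (batch, num_latents, dim)


joint = MultimodalInput(audio_channels=80, vocab_size=1000, dim=64)
x = joint(torch.randn(2, 50, 80), torch.randint(0, 1000, (2, 20)))  # (2, 70, 64)
latents = LatentCrossAttention(dim=64, num_latents=32, heads=4)(x)  # (2, 32, 64)
```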
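For "Question regarding queries dimensionality in Perceiver IO": this sketch does not assume anything about which query shapes the repository accepts. It only shows that a single learned (num_queries, queries_dim) tensor can be shared across a batch by adding a leading axis and expanding it; the variable names and sizes are illustrative.

```python
import torch
from torch import nn

batch, num_queries, queries_dim = 4, 128, 32

# One learned set of queries, stored without a batch axis: (num_queries, queries_dim).
queries = nn.Parameter(torch.randn(num_queries, queries_dim))

# Add a batch axis of size 1 and expand it. expand() returns a view (no copy), so every
# batch element reads, and backpropagates into, the same underlying tensor.
batched_queries = queries.unsqueeze(0).expand(batch, -1, -1)  # (batch, num_queries, queries_dim)
```

Since the expanded view shares storage, the gradient reaching queries is the sum of the contributions from all batch elements.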
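For "just a suggestion": a sketch of the requested pattern, with a hypothetical ToyBlock standing in for one of the repository's modules. The smoke test lives in a function so it can also be called from elsewhere, and the __main__ guard runs it only when the file is executed directly, not when it is imported.

```python
import torch
from torch import nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a module defined in the file."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.net = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.net(x)


def smoke_test() -> torch.Size:
    # Push a random batch through the module and report the output shape.
    model = ToyBlock(dim=16)
    out = model(torch.randn(2, 10, 16))
    return out.shape


if __name__ == "__main__":
    # Runs only when the file is executed directly, not on import.
    print(smoke_test())  # torch.Size([2, 10, 16])
```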
Releases: 0.8.6
Owner: Phil Wang (Working with Attention. It's all we need.)