🦩 Flamingo - Pytorch

Implementation of Flamingo, state-of-the-art few-shot visual question answering attention net out of DeepMind, in Pytorch. It will include the perceiver resampler (including the scheme where the learned queries contribute keys / values to be attended to, in addition to the media embeddings), the specialized masked cross attention blocks, and finally the tanh gating at the ends of the cross attention + corresponding feedforward blocks.
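The tanh gating can be pictured as a zero-initialized learned scalar gate on each newly inserted branch, so the pretrained language model is unchanged at initialization. A minimal sketch of the idea (not this library's exact code):

import torch
from torch import nn

class TanhGate(nn.Module):
    def __init__(self):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, branch_out):
        # tanh(0) == 0, so the gated branch contributes nothing at initialization
        return branch_out * torch.tanh(self.gate)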

Install

$ pip install flamingo-pytorch

Usage

import torch
from flamingo_pytorch import PerceiverResampler

perceive = PerceiverResampler(
    dim = 1024,
    depth = 2,
    dim_head = 64,
    heads = 8,
    num_latents = 64,    # the number of latents to shrink your media sequence to, perceiver style
    num_time_embeds = 4  # say you have 4 images maximum in your dialogue
)

medias = torch.randn(1, 2, 256, 1024) # (batch, time, sequence length, dimension)
perceived = perceive(medias) # (1, 2, 64, 1024) - (batch, time, num latents, dimension)

Then you insert the GatedCrossAttentionBlock at different intervals in your giant language model. Your text would then attend to the perceived media from above.

The recommended way to derive the media_locations boolean tensor is to allocate a special token id to the media, and then, at the start of your large language model, compute media_locations = text_id == media_token_id. For example:
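import torch

media_token_id = 3  # the special id you allocated for the <media> token (3 matches the PaLM example below)

text_id = torch.tensor([[17, 3, 482, 91, 3, 256]])  # token ids, with media tokens at positions 1 and 4
media_locations = text_id == media_token_id
# tensor([[False,  True, False, False,  True, False]])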

import torch
from flamingo_pytorch import GatedCrossAttentionBlock

cross_attn = GatedCrossAttentionBlock(
    dim = 1024,
    dim_head = 64,
    heads = 8
)

text = torch.randn(1, 512, 1024)
perceived = torch.randn(1, 2, 64, 1024)

media_locations = torch.randint(0, 2, (1, 512)).bool()

text = cross_attn(
    text,
    perceived,
    media_locations = media_locations
)
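Schematically, inserting the block into an existing decoder could look like the following wrapper. GatedCrossAttentionBlock is the real module from this library, but lm_layer is a hypothetical stand-in for one of your language model's own blocks:

from torch import nn
from flamingo_pytorch import GatedCrossAttentionBlock

class FlamingoWrapper(nn.Module):
    def __init__(self, lm_layer, dim):
        super().__init__()
        self.cross_attn = GatedCrossAttentionBlock(dim = dim, dim_head = 64, heads = 8)
        self.lm_layer = lm_layer

    def forward(self, text, perceived, media_locations):
        # gated cross attention to the perceived media, then the original language model block
        text = self.cross_attn(text, perceived, media_locations = media_locations)
        return self.lm_layer(text)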

That's it!

Attention is all you need.

Full working example with Flamingo + PaLM 🌴 🦩 🌴


First install vit-pytorch for the vision encoder

$ pip install vit-pytorch

Then

from vit_pytorch.vit import ViT
from vit_pytorch.extractor import Extractor

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

vit = Extractor(vit, return_embeddings_only = True)

# first take your trained image encoder and wrap it in an adapter that returns the image embeddings
# here we use the ViT from the vit-pytorch library

import torch
from flamingo_pytorch import FlamingoPaLM

# a PaLM language model, the 540 billion parameter model from google that shows signs of general intelligence

flamingo_palm = FlamingoPaLM(
    num_tokens = 20000,          # number of tokens
    dim = 1024,                  # dimensions
    depth = 12,                  # depth
    heads = 8,                   # attention heads
    dim_head = 64,               # dimension per attention head
    img_encoder = vit,           # plug in your image encoder (this can be optional if you pass in the image embeddings separately, but you probably want to train end to end given the perceiver resampler)
    media_token_id = 3,          # the token id representing the [media] or [image]
    cross_attn_every = 3,        # how often to cross attend
    perceiver_num_latents = 64,  # perceiver number of latents, should be smaller than the sequence length of the image tokens
    perceiver_depth = 2          # perceiver resampler depth
)

# train your PaLM as usual

text = torch.randint(0, 20000, (2, 512))

palm_logits = flamingo_palm(text)

# after much training off the regular PaLM logits
# now you are ready to train Flamingo + PaLM
# by passing in images, it automatically freezes everything but the perceiver and cross attention blocks, as in the paper

dialogue = torch.randint(0, 20000, (4, 512))
images = torch.randn(4, 2, 3, 256, 256)

flamingo_logits = flamingo_palm(dialogue, images)

# do your usual cross entropy loss
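# for example, standard next-token prediction (a sketch; shift logits and labels by one position)

import torch.nn.functional as F

loss = F.cross_entropy(
    flamingo_logits[:, :-1].reshape(-1, 20000),  # (batch * (seq - 1), num tokens)
    dialogue[:, 1:].reshape(-1)                  # (batch * (seq - 1),)
)
loss.backward()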

It is quite evident where this is all headed if you think beyond just images.

Inception

For factual correctness, just imagine where this system would stand if one were to use a state of the art retrieval language model as the base.

Citations

@article{Alayrac2022Flamingo,
    title   = {Flamingo: a Visual Language Model for Few-Shot Learning},
    author  = {Jean-Baptiste Alayrac et al},
    year    = {2022}
}
@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}
Comments
  • PerceiverResampler missing some LayerNorms?

    Hey, it feels like PerceiverResampler is missing some LayerNorms? It seems to me we should layer-norm x before sending it into the attention loop, and maybe add a layer-norm to ff(latents) + latents?

    opened by inspirit 7
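    For context, the pre-norm placement being asked about looks roughly like this (an illustrative sketch, not the repository's code):

    import torch
    from torch import nn

    class PreNorm(nn.Module):
        def __init__(self, dim, fn):
            super().__init__()
            self.norm = nn.LayerNorm(dim)
            self.fn = fn

        def forward(self, x, **kwargs):
            # layer-norm the input before the wrapped sublayer
            return self.fn(self.norm(x), **kwargs)

    # e.g. latents = prenorm_ff(latents) + latents, where prenorm_ff = PreNorm(dim, ff)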
  • Missing flatten op in PerceiverResampler?

    Hi, it seems that Flamingo does x_f = flatten(x_f) # [T, S, d] -> [T * S, d] (with batch size == 1) before feeding x_f to attention.

    So it should be:

    medias = torch.randn(1, 2, 256, 1024) # (batch, time, sequence length, dimension)
    perceived = perceive(medias)          # (1, 64, 1024) - (batch, num latents, dimension)

    opened by zengyan-97 6
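    For reference, the flatten the commenter describes amounts to merging the time and sequence axes, e.g. with einops:

    import torch
    from einops import rearrange

    medias = torch.randn(1, 2, 256, 1024)                  # (batch, time, seq, dim)
    flattened = rearrange(medias, 'b t s d -> b (t s) d')  # (1, 512, 1024)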
  • Wrong attention masks?

    https://github.com/lucidrains/flamingo-pytorch/blob/44920f4191ba3c280ff84c6ebc76025656d1dab5/flamingo_pytorch/flamingo_pytorch.py#L159

    In the flamingo paper, the language features in the gated cross-attention layers only attend to the visual features from the immediate preceding image. I believe your attention masks are created in such a way that they attend to the visual features from all preceding images. Can you confirm? If so, a fix would be to simply change the '>=' to '=='.

    opened by dhansmair 4
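    For context, a rough sketch of the two masking rules under discussion (illustrative, not the repository's exact code):

    import torch

    media_locations = torch.tensor([False, True, False, False, True, False])
    text_time = media_locations.cumsum(dim = -1)   # how many media tokens precede each text token
    media_time = torch.arange(2) + 1               # 1-based index of each of the 2 media

    attend_all_preceding = text_time[:, None] >= media_time[None, :]  # attends to every earlier image
    attend_only_latest   = text_time[:, None] == media_time[None, :]  # attends only to the immediately preceding image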
  • Zeroing out attention not working

    https://github.com/lucidrains/flamingo-pytorch/blob/749f8244794002371913d2fc4e7411afd5eddc67/flamingo_pytorch/flamingo_pytorch.py#L179

    You are not using the in-place version of the function: https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill_.html#torch.Tensor.masked_fill_

    So I think this line does not have an effect.

    Best, David

    opened by dhansmair 2
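    For context, the two variants behave differently:

    import torch

    sim = torch.randn(2, 3)
    mask = torch.tensor([[True, False, True], [False, True, False]])

    out = sim.masked_fill(mask, 0.)  # out-of-place: returns a new tensor, sim is unchanged
    sim.masked_fill_(mask, 0.)       # in-place: sim itself is modified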
  • Applying parallel attn with ff to existing pretrained model?

    Hi - awesome work! I am trying to understand this trick. I couldn't find a paper - only a reference to https://github.com/kingoflolz/mesh-transformer-jax. Is this right? Am I understanding correctly that it basically applies the attention (qkv) and feedforward operations at once? Is it possible to use this trick to modify an existing pretrained model?

    https://github.com/lucidrains/flamingo-pytorch/blob/749f8244794002371913d2fc4e7411afd5eddc67/flamingo_pytorch/flamingo_palm.py#L90

    Many thanks in advance!

    Huu

    opened by ontocord 1
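    For context, the parallel formulation (from mesh-transformer-jax, also used by PaLM) computes attention and feedforward from the same normalized input and adds both to the residual; a sketch:

    def parallel_block(x, norm, attn, ff):
        normed = norm(x)
        return x + attn(normed) + ff(normed)

    # versus the standard sequential form:
    # x = x + attn(norm1(x))
    # x = x + ff(norm2(x))

    Since the two forms compute different functions, a model pretrained with sequential blocks cannot simply be reinterpreted as parallel without further training.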
  • How to use Flamingo for VQA task?

    Hi, thanks for sharing this awesome implementation. I am very interested in using the Flamingo model for my use case. How can I use this implementation to run inference on my dataset for a VQA task? I have images of products and want to extract information from a product image by asking questions about it. How can I do that?

    Please help.

    thanks

    opened by karndeepsingh 0
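    For what it's worth, since the README does not show a generate helper, inference would be an autoregressive loop; a hypothetical greedy decoding sketch over the FlamingoPaLM example above:

    import torch

    prompt = torch.randint(0, 20000, (1, 16))  # stand-in for your tokenized question (should include the media token id)
    images = torch.randn(1, 1, 3, 256, 256)    # the product image

    for _ in range(32):
        logits = flamingo_palm(prompt, images)
        next_token = logits[:, -1].argmax(dim = -1, keepdim = True)  # greedy pick of the next token
        prompt = torch.cat((prompt, next_token), dim = -1)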
  • Fine-tuning of a model

    Hi, thank you for this great work. I want to ask how I can fine-tune this model on my dataset for a downstream task like image captioning or image classification? If possible, could you also please share the code?

    opened by ans92 0
  • Need a sample ipython notebook

    Hello, @lucidrains,

    Thank you for providing this.

    For demo purposes, could you please provide a sample demo as a Jupyter notebook? 🫠

    Best, LITDataScience

    opened by LITDataScience 0