Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch

Overview

CoCa - Pytorch

Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch. They were able to elegantly fit in contrastive learning to a conventional encoder / decoder (image to text) transformer, achieving SOTA 91.0% top-1 accuracy on ImageNet with a finetuned encoder.

This repository also chooses to adopt the specific transformer architecture from PaLM, for both the unimodal and multimodal transformers as well as the cross attention blocks (parallel SwiGLU feedforwards)

Yannic Kilcher presentation

Install

$ pip install coca-pytorch

Usage

First install the vit-pytorch for the image encoder, which needs to be pretrained

$ pip install vit-pytorch

Then

import torch

# import vision transformer

from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# do your vision transformer training

vit = Extractor(vit, return_embeddings_only = True)

# extractor will enable it so the vision transformer returns its embeddings

# import CoCa and instantiate it

from coca_pytorch.coca_pytorch import CoCa

coca = CoCa(
    dim = 512,                     # model dimension
    img_encoder = vit,             # vision transformer - image encoder, returning image embeddings as (batch, seq, dim)
    image_dim = 1024,              # image embedding dimension, if not the same as model dimensions
    num_tokens = 20000,            # number of text tokens
    unimodal_depth = 6,            # depth of the unimodal transformer
    multimodal_depth = 6,          # depth of the multimodal transformer
    dim_head = 64,                 # dimension per attention head
    heads = 8,                     # number of attention heads
    caption_loss_weight = 1.,      # weight on the autoregressive caption loss
    contrastive_loss_weight = 1.,  # weight on the contrastive loss between image and text CLS embeddings
).cuda()

# mock text and images

text = torch.randint(0, 20000, (4, 512)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train by giving CoCa your text and images with `return_loss = True`

loss = coca(
    text = text,
    images = images,
    return_loss = True  # set this to True to get the full caption + contrastive loss
)

loss.backward()

# do the above for as much text and images...
# then you can get the caption logits as so

logits = coca(
    text = text,
    images = images
) # (4, 512, 20000)

# and the CLIP-like text and image embeddings as

text_embeds, image_embeds = coca(
    text = text,
    images = images,
    return_embeddings = True
) # (4, 512), (4, 512)

Citations

@inproceedings{Yu2022CoCaCC,
  title   = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
  author  = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
  year    = {2022}
}
@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}
Comments
  • Contrastive loss should be applied to L2-normed embeddings instead of layer normed?

    Contrastive loss should be applied to L2-normed embeddings instead of layer normed?

    Hi @lucidrains, thank you for the implementation. Just wanted to confirm this with you, based on your code we're normalizing the img embedding and text embedding respectively using a learnable Layer Norm transformation before applying the contrastive loss. But based on my understanding, for contrastive loss we typically maximize the relative cosine similarity so the embeddings should be L2-normed instead of layernormed? Thank you.

    opened by fedshyvana 2
  • Extractor in vit_pytorch will detach the tensor.

    Extractor in vit_pytorch will detach the tensor.

    Thanks for your code! I think I may find a little bug. The cloned tensor in Extractor will be detached (https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/extractor.py#L39). So gradient may not propagate back to image encoder.

    opened by techkang 2
  • Maybe don't need this rearrange

    Maybe don't need this rearrange

    I think the logits before this line in shape (bsz, length, num_tokens) -> so I don't think here need one more rearrange https://github.com/lucidrains/CoCa-pytorch/blob/25de0b04326d8dc4c6f969e90b4466fc4894835e/coca_pytorch/coca_pytorch.py#L461

    opened by CiaoHe 2
  • How to train the model using my own dataset?

    How to train the model using my own dataset?

    Can someone tell me how to train the model using my own dataset? is it like below?But I have many images and texts...

    # train by giving CoCa your text and images with `return_loss = True`
    loss = coca(
        text = text,
        images = images,
        return_loss = True  # set this to True to get the full caption + contrastive loss
    )
    
    opened by keepcodeandsmile 1
  • why train VIT visual encoder first?

    why train VIT visual encoder first?

    Hi, thanks for sharing this repo. In the CoCA paper, both the visual encoder and text encoder are end-to end trained. But in this repo, the vit is first pretrained then fixed to train CoCa.

    opened by Flowerfan 1
  • attn_mask

    attn_mask

    cls_mask = rearrange(text!=self.pad_id, 'b j -> b 1 j')  
    attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
    
    attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')  
    sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
    

    Hello, I am confused of the implement of "attn_mask". I think this padding function only can mask the last row of "sim". Could you please explain it? Perhaps it's a very fool question. Thank you so much.

    opened by pldlgb 0
  • Reproducing the results in the paper

    Reproducing the results in the paper

    Thanks for this repo. Curious, is this an independent implementation of the CoCa paper? If yes, did you reproduce any result in the paper to ensure correctness of implementation?

    opened by GKIBMNY 0
  • Generating the caption of a given image

    Generating the caption of a given image

    Hello,

    Thank you for having implemented this model. Have you already implemented some code to generate the caption of a given image? If not, do you have an idea about how you would do it in this particular architecture?

    Thank you in advance.

    opened by claudiogreco 0
Releases(0.0.7)
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
This is an open source library implementing hyperbox-based machine learning algorithms

hyperbox-brain is a Python open source toolbox implementing hyperbox-based machine learning algorithms built on top of scikit-learn and is distributed

Complex Adaptive Systems (CAS) Lab - University of Technology Sydney 21 Dec 14, 2022
Second Order Optimization and Curvature Estimation with K-FAC in JAX.

KFAC-JAX - Second Order Optimization with Approximate Curvature in JAX Installation | Quickstart | Documentation | Examples | Citing KFAC-JAX KFAC-JAX

DeepMind 90 Dec 22, 2022
AI-Bot - 一个基于watermelon改造的OpenAI-GPT-2的智能机器人

AI-Bot 一个基于watermelon改造的OpenAI-GPT-2的智能机器人 在Binder上直接运行测试 目前有两种实现方式 TF2的GPT-2 TF

9 Nov 16, 2022
🚀 An end-to-end ML applications using PyTorch, W&B, FastAPI, Docker, Streamlit and Heroku

🚀 An end-to-end ML applications using PyTorch, W&B, FastAPI, Docker, Streamlit and Heroku

Made With ML 82 Jun 26, 2022
Source code for our paper "Molecular Mechanics-Driven Graph Neural Network with Multiplex Graph for Molecular Structures"

Molecular Mechanics-Driven Graph Neural Network with Multiplex Graph for Molecular Structures Code for the Multiplex Molecular Graph Neural Network (M

shzhang 59 Dec 10, 2022
PyTorch implementation of SIFT descriptor

This is an differentiable pytorch implementation of SIFT patch descriptor. It is very slow for describing one patch, but quite fast for batch. It can

Dmytro Mishkin 150 Dec 24, 2022
Nonnegative spatial factorization for multivariate count data

Nonnegative spatial factorization for multivariate count data This repository contains supporting code to facilitate reproducible analysis. For detail

Will Townes 24 Dec 19, 2022
Draw like Bob Ross using the power of Neural Networks (With PyTorch)!

Draw like Bob Ross using the power of Neural Networks! (+ Pytorch) Learning Process Visualization Getting started Install dependecies Requires python3

Kendrick Tan 116 Mar 07, 2022
Official Code Release for "TIP-Adapter: Training-free clIP-Adapter for Better Vision-Language Modeling"

Official Code Release for "TIP-Adapter: Training-free clIP-Adapter for Better Vision-Language Modeling" Pipeline of Tip-Adapter Tip-Adapter can provid

peng gao 187 Dec 28, 2022
TensorFlow2 Classification Model Zoo playing with TensorFlow2 on the CIFAR-10 dataset.

Training CIFAR-10 with TensorFlow2(TF2) TensorFlow2 Classification Model Zoo. I'm playing with TensorFlow2 on the CIFAR-10 dataset. Architectures LeNe

Chia-Hung Yuan 16 Sep 27, 2022
Generalized Data Weighting via Class-level Gradient Manipulation

Generalized Data Weighting via Class-level Gradient Manipulation This repository is the official implementation of Generalized Data Weighting via Clas

18 Nov 12, 2022
Data & Code for ACCENTOR Adding Chit-Chat to Enhance Task-Oriented Dialogues

ACCENTOR: Adding Chit-Chat to Enhance Task-Oriented Dialogues Overview ACCENTOR consists of the human-annotated chit-chat additions to the 23.8K dialo

Facebook Research 69 Dec 29, 2022
Code for "Learning Graph Cellular Automata"

Learning Graph Cellular Automata This code implements the experiments from the NeurIPS 2021 paper: "Learning Graph Cellular Automata" Daniele Grattaro

Daniele Grattarola 37 Oct 26, 2022
Vehicle detection using machine learning and computer vision techniques for Udacity's Self-Driving Car Engineer Nanodegree.

Vehicle Detection Video demo Overview Vehicle detection using these machine learning and computer vision techniques. Linear SVM HOG(Histogram of Orien

hata 1.1k Dec 18, 2022
Contextual Attention Localization for Offline Handwritten Text Recognition

CALText This repository contains the source code for CALText model introduced in "CALText: Contextual Attention Localization for Offline Handwritten T

0 Feb 17, 2022
The official implementation for "FQ-ViT: Fully Quantized Vision Transformer without Retraining".

FQ-ViT [arXiv] This repo contains the official implementation of "FQ-ViT: Fully Quantized Vision Transformer without Retraining". Table of Contents In

132 Jan 08, 2023
Mixed Neural Likelihood Estimation for models of decision-making

Mixed neural likelihood estimation for models of decision-making Mixed neural likelihood estimation (MNLE) enables Bayesian parameter inference for mo

mackelab 9 Dec 22, 2022
Translation-equivariant Image Quantizer for Bi-directional Image-Text Generation

Translation-equivariant Image Quantizer for Bi-directional Image-Text Generation Woncheol Shin1, Gyubok Lee1, Jiyoung Lee1, Joonseok Lee2,3, Edward Ch

Woncheol Shin 7 Sep 26, 2022
Riemann Noise Injection With PyTorch

Riemann Noise Injection - PyTorch A module for modeling GAN noise injection based on Riemann geometry, as described in Ruili Feng, Deli Zhao, and Zhen

2 May 27, 2022
STEM: An approach to Multi-source Domain Adaptation with Guarantees

STEM: An approach to Multi-source Domain Adaptation with Guarantees Introduction This is the official implementation of ``STEM: An approach to Multi-s

5 Dec 19, 2022