Implementation of CaiT models in TensorFlow and ImageNet-1k checkpoints. Includes code for inference and fine-tuning.

Overview

CaiT-TF (Going deeper with Image Transformers)

TensorFlow 2.8 HugginFace badge Models on TF-Hub

This repository provides TensorFlow / Keras implementations of different CaiT [1] variants from Touvron et al. It also provides the TensorFlow / Keras models that have been populated with the original CaiT pre-trained params available from [2]. These models are not blackbox SavedModels i.e., they can be fully expanded into tf.keras.Model objects and one can call all the utility functions on them (example: .summary()).

As of today, all the TensorFlow / Keras variants of the CaiT models listed here are available in this repository.

Refer to the "Using the models" section to get started.

Table of contents

Conversion

TensorFlow / Keras implementations are available in cait/models.py. Conversion utilities are in convert.py.

Models

Find the models on TF-Hub here: https://tfhub.dev/sayakpaul/collections/cait/1. You can fully inspect the architecture of the TF-Hub models like so:

import tensorflow as tf

model_gcs_path = "gs://tfhub-modules/sayakpaul/cait_xxs24_224/1/uncompressed"
model = tf.keras.models.load_model(model_gcs_path)

dummy_inputs = tf.ones((2, 224, 224, 3))
_ = model(dummy_inputs)
print(model.summary(expand_nested=True))

Results

Results are on ImageNet-1k validation set (top-1 and top-5 accuracies).

model_name top1_acc(%) top5_acc(%)
cait_s24_224 83.368 96.576
cait_xxs24_224 78.524 94.212
cait_xxs36_224 79.76 94.876
cait_xxs36_384 81.976 96.064
cait_xxs24_384 80.648 95.516
cait_xs24_384 83.738 96.756
cait_s24_384 84.944 97.212
cait_s36_384 85.192 97.372
cait_m36_384 85.924 97.598
cait_m48_448 86.066 97.590

Results can be verified with the code in i1k_eval. Results are in line with [1]. Slight differences in the results stemmed from the fact that I used a different set of augmentation transformations. Original transformations suggested by the authors can be found here.

Using the models

Pre-trained models:

These models also output attention weights from each of the Transformer blocks. Refer to this notebook for more details. Additionally, the notebook shows how to visualize the attention maps for a given image (following figures 6 and 7 of the original paper).

Original Image Class Attention Maps Class Saliency Map
cropped image cls attn saliency

For the best quality, refer to the assets directory. You can also generate these plots using the following interactive demos on Hugging Face Spaces:

Randomly initialized models:

from cait.model_configs import base_config
from cait.models import CaiT
import tensorflow as tf
 
config = base_config.get_config(
    model_name="cait_xxs24_224"
)
cait_xxs24_224 = CaiT(config)

dummy_inputs = tf.ones((2, 224, 224, 3))
_ = cait_xxs24_224(dummy_inputs)
print(cait_xxs24_224.summary(expand_nested=True))

To initialize a network with say, 5 classes, do:

config = base_config.get_config(
    model_name="cait_xxs24_224"
)
with config.unlocked():
    config.num_classes = 5
cait_xxs24_224 = CaiT(config)

To view different model configurations, refer to convert_all_models.py.

References

[1] CaiT paper: https://arxiv.org/abs/2103.17239

[2] Official CaiT code: https://github.com/facebookresearch/deit

Acknowledgements

Owner
Sayak Paul
ML Engineer at @carted | One PR at a time
Sayak Paul
Hcaptcha-challenger - Gracefully face hCaptcha challenge with Yolov5(ONNX) embedded solution

hCaptcha Challenger πŸš€ Gracefully face hCaptcha challenge with Yolov5(ONNX) embe

593 Jan 03, 2023
Fine-grained Control of Image Caption Generation with Abstract Scene Graphs

Faster R-CNN pretrained on VisualGenome This repository modifies maskrcnn-benchmark for object detection and attribute prediction on VisualGenome data

Shizhe Chen 7 Apr 20, 2021
DA2Lite is an automated model compression toolkit for PyTorch.

DA2Lite (Deep Architecture to Lite) is a toolkit to compress and accelerate deep network models. ⭐ Star us on GitHub β€” it helps!! Frameworks & Librari

Sinhan Kang 7 Mar 22, 2022
Generate images from texts. In Russian

ruDALL-E Generate images from texts pip install rudalle==1.1.0rc0 πŸ€— HF Models: ruDALL-E Malevich (XL) ruDALL-E Emojich (XL) (readme here) ruDALL-E S

AI Forever 1.6k Dec 31, 2022
TabNet for fastai

TabNet for fastai This is an adaptation of TabNet (Attention-based network for tabular data) for fastai (=2.0) library. The original paper https://ar

Mikhail Grankin 116 Oct 21, 2022
Pi-NAS: Improving Neural Architecture Search by Reducing Supernet Training Consistency Shift (ICCV 2021)

Ξ -NAS This repository provides the evaluation code of our submitted paper: Pi-NAS: Improving Neural Architecture Search by Reducing Supernet Training

Jiqi Zhang 18 Aug 18, 2022
Multi-task yolov5 with detection and segmentation based on yolov5

YOLOv5DS Multi-task yolov5 with detection and segmentation based on yolov5(branch v6.0) decoupled head anchor free segmentation head READMEδΈ­ζ–‡ Ablation

150 Dec 30, 2022
EMNLP 2021: Single-dataset Experts for Multi-dataset Question-Answering

MADE (Multi-Adapter Dataset Experts) This repository contains the implementation of MADE (Multi-adapter dataset experts), which is described in the pa

Princeton Natural Language Processing 68 Jul 18, 2022
The source code of the paper "Understanding Graph Neural Networks from Graph Signal Denoising Perspectives"

GSDN-F and GSDN-EF This repository provides a reference implementation of GSDN-F and GSDN-EF as described in the paper "Understanding Graph Neural Net

Guoji Fu 18 Nov 14, 2022
NL-Augmenter 🦎 β†’ 🐍 A Collaborative Repository of Natural Language Transformations

NL-Augmenter 🦎 β†’ 🐍 The NL-Augmenter is a collaborative effort intended to add transformations of datasets dealing with natural language. Transformat

684 Jan 09, 2023
PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

Yonglong Tian 2.2k Jan 08, 2023
Code for our paper "Graph Pre-training for AMR Parsing and Generation" in ACL2022

AMRBART An implementation for ACL2022 paper "Graph Pre-training for AMR Parsing and Generation". You may find our paper here (Arxiv). Requirements pyt

xfbai 60 Jan 03, 2023
Code for the paper "There is no Double-Descent in Random Forests"

Code for the paper "There is no Double-Descent in Random Forests" This repository contains the code to run the experiments for our paper called "There

2 Jan 14, 2022
This project provides the proof of the uniqueness of the equilibrium and the global asymptotic stability.

Delayed-cellular-neural-network This project provides the proof of the uniqueness of the equilibrium and the global asymptotic stability. There is als

4 Apr 28, 2022
pytorch implementation of "Contrastive Multiview Coding", "Momentum Contrast for Unsupervised Visual Representation Learning", and "Unsupervised Feature Learning via Non-Parametric Instance-level Discrimination"

Unofficial implementation: MoCo: Momentum Contrast for Unsupervised Visual Representation Learning (Paper) InsDis: Unsupervised Feature Learning via N

Zhiqiang Shen 16 Nov 04, 2020
You are AllSet: A Multiset Function Framework for Hypergraph Neural Networks.

AllSet This is the repo for our paper: You are AllSet: A Multiset Function Framework for Hypergraph Neural Networks. We prepared all codes and a subse

Jianhao 51 Dec 24, 2022
Code to produce syntactic representations that can be used to study syntax processing in the human brain

Can fMRI reveal the representation of syntactic structure in the brain? The code base for our paper on understanding syntactic representations in the

Aniketh Janardhan Reddy 4 Dec 18, 2022
Codebase for Inducing Causal Structure for Interpretable Neural Networks

Interchange Intervention Training (IIT) Codebase for Inducing Causal Structure for Interpretable Neural Networks Release Notes 12/01/2021: Code and Pa

Zen 6 Oct 10, 2022
Simple, but essential Bayesian optimization package

BayesO: A Bayesian optimization framework in Python Simple, but essential Bayesian optimization package. http://bayeso.org Online documentation Instal

Jungtaek Kim 74 Dec 05, 2022
DUE: End-to-End Document Understanding Benchmark

This is the repository that provide tools to download data, reproduce the baseline results and evaluation. What can you achieve with this guide Based

21 Dec 29, 2022