Implementation of CaiT models in TensorFlow and ImageNet-1k checkpoints. Includes code for inference and fine-tuning.

Overview

CaiT-TF (Going deeper with Image Transformers)

TensorFlow 2.8 HugginFace badge Models on TF-Hub

This repository provides TensorFlow / Keras implementations of different CaiT [1] variants from Touvron et al. It also provides the TensorFlow / Keras models that have been populated with the original CaiT pre-trained params available from [2]. These models are not blackbox SavedModels i.e., they can be fully expanded into tf.keras.Model objects and one can call all the utility functions on them (example: .summary()).

As of today, all the TensorFlow / Keras variants of the CaiT models listed here are available in this repository.

Refer to the "Using the models" section to get started.

Table of contents

Conversion

TensorFlow / Keras implementations are available in cait/models.py. Conversion utilities are in convert.py.

Models

Find the models on TF-Hub here: https://tfhub.dev/sayakpaul/collections/cait/1. You can fully inspect the architecture of the TF-Hub models like so:

import tensorflow as tf

model_gcs_path = "gs://tfhub-modules/sayakpaul/cait_xxs24_224/1/uncompressed"
model = tf.keras.models.load_model(model_gcs_path)

dummy_inputs = tf.ones((2, 224, 224, 3))
_ = model(dummy_inputs)
print(model.summary(expand_nested=True))

Results

Results are on ImageNet-1k validation set (top-1 and top-5 accuracies).

model_name top1_acc(%) top5_acc(%)
cait_s24_224 83.368 96.576
cait_xxs24_224 78.524 94.212
cait_xxs36_224 79.76 94.876
cait_xxs36_384 81.976 96.064
cait_xxs24_384 80.648 95.516
cait_xs24_384 83.738 96.756
cait_s24_384 84.944 97.212
cait_s36_384 85.192 97.372
cait_m36_384 85.924 97.598
cait_m48_448 86.066 97.590

Results can be verified with the code in i1k_eval. Results are in line with [1]. Slight differences in the results stemmed from the fact that I used a different set of augmentation transformations. Original transformations suggested by the authors can be found here.

Using the models

Pre-trained models:

These models also output attention weights from each of the Transformer blocks. Refer to this notebook for more details. Additionally, the notebook shows how to visualize the attention maps for a given image (following figures 6 and 7 of the original paper).

Original Image Class Attention Maps Class Saliency Map
cropped image cls attn saliency

For the best quality, refer to the assets directory. You can also generate these plots using the following interactive demos on Hugging Face Spaces:

Randomly initialized models:

from cait.model_configs import base_config
from cait.models import CaiT
import tensorflow as tf
 
config = base_config.get_config(
    model_name="cait_xxs24_224"
)
cait_xxs24_224 = CaiT(config)

dummy_inputs = tf.ones((2, 224, 224, 3))
_ = cait_xxs24_224(dummy_inputs)
print(cait_xxs24_224.summary(expand_nested=True))

To initialize a network with say, 5 classes, do:

config = base_config.get_config(
    model_name="cait_xxs24_224"
)
with config.unlocked():
    config.num_classes = 5
cait_xxs24_224 = CaiT(config)

To view different model configurations, refer to convert_all_models.py.

References

[1] CaiT paper: https://arxiv.org/abs/2103.17239

[2] Official CaiT code: https://github.com/facebookresearch/deit

Acknowledgements

Owner
Sayak Paul
ML Engineer at @carted | One PR at a time
Sayak Paul
Official code for our ICCV paper: "From Continuity to Editability: Inverting GANs with Consecutive Images"

GANInversion_with_ConsecutiveImgs Official code for our ICCV paper: "From Continuity to Editability: Inverting GANs with Consecutive Images" https://a

QingyangXu 38 Dec 07, 2022
A research toolkit for particle swarm optimization in Python

PySwarms is an extensible research toolkit for particle swarm optimization (PSO) in Python. It is intended for swarm intelligence researchers, practit

Lj Miranda 1k Dec 30, 2022
COIN the currently largest dataset for comprehensive instruction video analysis.

COIN Dataset COIN is the currently largest dataset for comprehensive instruction video analysis. It contains 11,827 videos of 180 different tasks (i.e

86 Dec 28, 2022
The NEOSSat is a dual-mission microsatellite designed to detect potentially hazardous Earth-orbit-crossing asteroids and track objects that reside in deep space

The NEOSSat is a dual-mission microsatellite designed to detect potentially hazardous Earth-orbit-crossing asteroids and track objects that reside in deep space

John Salib 2 Jan 30, 2022
Pytorch implementation of "Geometrically Adaptive Dictionary Attack on Face Recognition" (WACV 2022)

Geometrically Adaptive Dictionary Attack on Face Recognition This is the Pytorch code of our paper "Geometrically Adaptive Dictionary Attack on Face R

6 Nov 21, 2022
DECAF: Deep Extreme Classification with Label Features

DECAF DECAF: Deep Extreme Classification with Label Features @InProceedings{Mittal21, author = "Mittal, A. and Dahiya, K. and Agrawal, S. and Sain

46 Nov 06, 2022
Facebook AI Image Similarity Challenge: Descriptor Track

Facebook AI Image Similarity Challenge: Descriptor Track This repository contains the code for our solution to the Facebook AI Image Similarity Challe

Sergio MP 17 Dec 14, 2022
Speedy Implementation of Instance-based Learning (IBL) agents in Python

A Python library to create single or multi Instance-based Learning (IBL) agents that are built based on Instance Based Learning Theory (IBLT) 1 Instal

0 Nov 18, 2021
Implementation of "Distribution Alignment: A Unified Framework for Long-tail Visual Recognition"(CVPR 2021)

Implementation of "Distribution Alignment: A Unified Framework for Long-tail Visual Recognition"(CVPR 2021)

105 Nov 07, 2022
Learning hidden low dimensional dyanmics using a Generalized Onsager Principle and neural networks

OnsagerNet Learning hidden low dimensional dyanmics using a Generalized Onsager Principle and neural networks This is the original pyTorch implemenati

Haijun.Yu 3 Aug 24, 2022
Monocular 3D Object Detection: An Extrinsic Parameter Free Approach (CVPR2021)

Monocular 3D Object Detection: An Extrinsic Parameter Free Approach (CVPR2021) Yunsong Zhou, Yuan He, Hongzi Zhu, Cheng Wang, Hongyang Li, Qinhong Jia

Yunsong Zhou 51 Dec 14, 2022
An educational tool to introduce AI planning concepts using mobile manipulator robots.

JEDAI Explains Decision-Making AI Virtual Machine Image The recommended way of using JEDAI is to use pre-configured Virtual Machine image that is avai

Autonomous Agents and Intelligent Robots 13 Nov 15, 2022
SE3 Pose Interp - Interpolate camera pose or trajectory in SE3, pose interpolation, trajectory interpolation

SE3 Pose Interpolation Pose estimated from SLAM system are always discrete, and

Ran Cheng 4 Dec 15, 2022
A transformer model to predict pathogenic mutations

MutFormer MutFormer is an application of the BERT (Bidirectional Encoder Representations from Transformers) NLP (Natural Language Processing) model wi

Wang Genomics Lab 2 Nov 29, 2022
Official repository of the paper Learning to Regress 3D Face Shape and Expression from an Image without 3D Supervision

Official repository of the paper Learning to Regress 3D Face Shape and Expression from an Image without 3D Supervision

Soubhik Sanyal 689 Dec 25, 2022
This script runs neural style transfer against the provided content image.

Neural Style Transfer Content Style Output Description: This script runs neural style transfer against the provided content image. The content image m

Martynas Subonis 0 Nov 25, 2021
An AFL implementation with UnTracer (our coverage-guided tracer)

UnTracer-AFL This repository contains an implementation of our prototype coverage-guided tracing framework UnTracer in the popular coverage-guided fuz

113 Dec 17, 2022
Codes and Data Processing Files for our paper.

Code Scripts and Processing Files for EEG Sleep Staging Paper 1. Folder Tree ./src_preprocess (data preprocessing files for SHHS and Sleep EDF) sleepE

Chaoqi Yang 18 Dec 12, 2022
Fully Adaptive Bayesian Algorithm for Data Analysis (FABADA) is a new approach of noise reduction methods. In this repository is shown the package developed for this new method based on \citepaper.

Fully Adaptive Bayesian Algorithm for Data Analysis FABADA FABADA is a novel non-parametric noise reduction technique which arise from the point of vi

18 Oct 20, 2022
SciKit-Learn Laboratory (SKLL) makes it easy to run machine learning experiments.

SciKit-Learn Laboratory This Python package provides command-line utilities to make it easier to run machine learning experiments with scikit-learn. O

ETS 528 Nov 25, 2022