Implementation of CaiT models in TensorFlow and ImageNet-1k checkpoints. Includes code for inference and fine-tuning.

Overview

CaiT-TF (Going deeper with Image Transformers)

TensorFlow 2.8 HugginFace badge Models on TF-Hub

This repository provides TensorFlow / Keras implementations of different CaiT [1] variants from Touvron et al. It also provides the TensorFlow / Keras models that have been populated with the original CaiT pre-trained params available from [2]. These models are not blackbox SavedModels i.e., they can be fully expanded into tf.keras.Model objects and one can call all the utility functions on them (example: .summary()).

As of today, all the TensorFlow / Keras variants of the CaiT models listed here are available in this repository.

Refer to the "Using the models" section to get started.

Table of contents

Conversion

TensorFlow / Keras implementations are available in cait/models.py. Conversion utilities are in convert.py.

Models

Find the models on TF-Hub here: https://tfhub.dev/sayakpaul/collections/cait/1. You can fully inspect the architecture of the TF-Hub models like so:

import tensorflow as tf

model_gcs_path = "gs://tfhub-modules/sayakpaul/cait_xxs24_224/1/uncompressed"
model = tf.keras.models.load_model(model_gcs_path)

dummy_inputs = tf.ones((2, 224, 224, 3))
_ = model(dummy_inputs)
print(model.summary(expand_nested=True))

Results

Results are on ImageNet-1k validation set (top-1 and top-5 accuracies).

model_name top1_acc(%) top5_acc(%)
cait_s24_224 83.368 96.576
cait_xxs24_224 78.524 94.212
cait_xxs36_224 79.76 94.876
cait_xxs36_384 81.976 96.064
cait_xxs24_384 80.648 95.516
cait_xs24_384 83.738 96.756
cait_s24_384 84.944 97.212
cait_s36_384 85.192 97.372
cait_m36_384 85.924 97.598
cait_m48_448 86.066 97.590

Results can be verified with the code in i1k_eval. Results are in line with [1]. Slight differences in the results stemmed from the fact that I used a different set of augmentation transformations. Original transformations suggested by the authors can be found here.

Using the models

Pre-trained models:

These models also output attention weights from each of the Transformer blocks. Refer to this notebook for more details. Additionally, the notebook shows how to visualize the attention maps for a given image (following figures 6 and 7 of the original paper).

Original Image Class Attention Maps Class Saliency Map
cropped image cls attn saliency

For the best quality, refer to the assets directory. You can also generate these plots using the following interactive demos on Hugging Face Spaces:

Randomly initialized models:

from cait.model_configs import base_config
from cait.models import CaiT
import tensorflow as tf
 
config = base_config.get_config(
    model_name="cait_xxs24_224"
)
cait_xxs24_224 = CaiT(config)

dummy_inputs = tf.ones((2, 224, 224, 3))
_ = cait_xxs24_224(dummy_inputs)
print(cait_xxs24_224.summary(expand_nested=True))

To initialize a network with say, 5 classes, do:

config = base_config.get_config(
    model_name="cait_xxs24_224"
)
with config.unlocked():
    config.num_classes = 5
cait_xxs24_224 = CaiT(config)

To view different model configurations, refer to convert_all_models.py.

References

[1] CaiT paper: https://arxiv.org/abs/2103.17239

[2] Official CaiT code: https://github.com/facebookresearch/deit

Acknowledgements

Owner
Sayak Paul
ML Engineer at @carted | One PR at a time
Sayak Paul
ColossalAI-Benchmark - Performance benchmarking with ColossalAI

Benchmark for Tuning Accuracy and Efficiency Overview The benchmark includes our

HPC-AI Tech 31 Oct 07, 2022
Image segmentation with private İstanbul Dataset

Image Segmentation This repo was created for academic research and test result. Repo will update after academic article online. This repo contains wei

İrem KÖMÜRCÜ 9 Dec 11, 2022
N-Omniglot is a large neuromorphic few-shot learning dataset

N-Omniglot [Paper] || [Dataset] N-Omniglot is a large neuromorphic few-shot learning dataset. It reconstructs strokes of Omniglot as videos and uses D

11 Dec 05, 2022
This repository contains an implementation of ConvMixer for the ICLR 2022 submission "Patches Are All You Need?".

Patches Are All You Need? 🤷 This repository contains an implementation of ConvMixer for the ICLR 2022 submission "Patches Are All You Need?". Code ov

ICLR 2022 Author 934 Dec 30, 2022
Official implementation of "Variable-Rate Deep Image Compression through Spatially-Adaptive Feature Transform", ICCV 2021

Variable-Rate Deep Image Compression through Spatially-Adaptive Feature Transform This repository is the implementation of "Variable-Rate Deep Image C

Myungseo Song 47 Dec 13, 2022
An End-to-End Machine Learning Library to Optimize AUC (AUROC, AUPRC).

Logo by Zhuoning Yuan LibAUC: A Machine Learning Library for AUC Optimization Website | Updates | Installation | Tutorial | Research | Github LibAUC a

Optimization for AI 176 Jan 07, 2023
N-gram models- Unsmoothed, Laplace, Deleted Interpolation

N-gram models- Unsmoothed, Laplace, Deleted Interpolation

Ravika Nagpal 1 Jan 04, 2022
An Efficient Implementation of Analytic Mesh Algorithm for 3D Iso-surface Extraction from Neural Networks

AnalyticMesh Analytic Marching is an exact meshing solution from neural networks. Compared to standard methods, it completely avoids geometric and top

Karbo 45 Dec 21, 2022
Official implementation of Protected Attribute Suppression System, ICCV 2021

Official implementation of Protected Attribute Suppression System, ICCV 2021

Prithviraj Dhar 6 Jan 01, 2023
A bunch of random PyTorch models using PyTorch's C++ frontend

PyTorch Deep Learning Models using the C++ frontend Gettting started Clone the repo 1. https://github.com/mrdvince/pytorchcpp 2. cd fashionmnist or

Vince 0 Jul 13, 2021
Official repository of OFA. Paper: Unifying Architectures, Tasks, and Modalities Through a Simple Sequence-to-Sequence Learning Framework

Paper | Blog OFA is a unified multimodal pretrained model that unifies modalities (i.e., cross-modality, vision, language) and tasks (e.g., image gene

OFA Sys 1.4k Jan 08, 2023
Code for a real-time distributed cooperative slam(RDC-SLAM) system for ROS compatible platforms.

RDC-SLAM This repository contains code for a real-time distributed cooperative slam(RDC-SLAM) system for ROS compatible platforms. The system takes in

40 Nov 19, 2022
In the AI for TSP competition we try to solve optimization problems using machine learning.

AI for TSP Competition Goal In the AI for TSP competition we try to solve optimization problems using machine learning. The competition will be hosted

Paulo da Costa 11 Nov 27, 2022
The openspoor package is intended to allow easy transformation between different geographical and topological systems commonly used in Dutch Railway

Openspoor The openspoor package is intended to allow easy transformation between different geographical and topological systems commonly used in Dutch

7 Aug 22, 2022
Pytorch implementation of our paper under review -- 1xN Pattern for Pruning Convolutional Neural Networks

1xN Pattern for Pruning Convolutional Neural Networks (paper) . This is Pytorch re-implementation of "1xN Pattern for Pruning Convolutional Neural Net

Mingbao Lin (林明宝) 29 Nov 29, 2022
Semantically Contrastive Learning for Low-light Image Enhancement

Semantically Contrastive Learning for Low-light Image Enhancement Here, we propose an effective semantically contrastive learning paradigm for Low-lig

48 Dec 16, 2022
The audio-video synchronization of MKV Container Format is exploited to achieve data hiding

The audio-video synchronization of MKV Container Format is exploited to achieve data hiding, where the hidden data can be utilized for various management purposes, including hyper-linking, annotation

Maxim Zaika 1 Nov 17, 2021
Improving Object Detection by Label Assignment Distillation

Improving Object Detection by Label Assignment Distillation This is the official implementation of the WACV 2022 paper Improving Object Detection by L

Cybercore Co. Ltd 51 Dec 08, 2022
Frigate - NVR With Realtime Object Detection for IP Cameras

A complete and local NVR designed for HomeAssistant with AI object detection. Uses OpenCV and Tensorflow to perform realtime object detection locally for IP cameras.

Blake Blackshear 6.4k Dec 31, 2022
AI-based, context-driven network device ranking

Batea A batea is a large shallow pan of wood or iron traditionally used by gold prospectors for washing sand and gravel to recover gold nuggets. Batea

Secureworks Taegis VDR 269 Nov 26, 2022