Official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

Overview

CrossViT

This repository is the official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification. ArXiv

If you use the codes and models from this repo, please cite our work. Thanks!

@inproceedings{
    chen2021crossvit,
    title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}},
    author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda},
    booktitle={International Conference on Computer Vision (ICCV)},
    year={2021}
}

Installation

To install requirements:

pip install -r requirements.txt

With conda:

conda create -n crossvit python=3.8
conda activate crossvit
conda install pytorch=1.7.1 torchvision  cudatoolkit=11.0 -c pytorch -c nvidia
pip install -r requirements.txt

Data preparation

Download and extract ImageNet train and val images from http://image-net.org/. The directory structure is the standard layout for the torchvision datasets.ImageFolder, and the training and validation data is expected to be in the train/ folder and val folder respectively:

/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class/2
      img4.jpeg

Pretrained models

We provide models trained on ImageNet1K. You can find models here. And you can load pretrained weights into models by add --pretrained flag.

Training

To train crossvit_9_dagger_224 on ImageNet on a single node with 8 gpus for 300 epochs run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet

Other model names can be found at models/crossvit.py.

Multinode training

Distributed training is available via Slurm and submitit:

To train a crossvit_9_dagger_224 model on ImageNet on 4 nodes with 8 gpus each for 300 epochs:

python run_with_submitit.py --nodes 4 --model crossvit_9_dagger_224 --data-path /path/to/imagenet --batch-size 128 --warmup-epochs 30

Or you can start process on each machine maunally. E.g. 2 nodes, each with 8 gpus.

Machine 0:

python -m torch.distributed.launch --nproc_per_node=8 --master_addr=MACHINE_0_IP --master_port=AVAILABLE_PORT --nnodes=2 --node_rank=0 main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet

Machine 1:

python -m torch.distributed.launch --nproc_per_node=8 --master_addr=MACHINE_0_IP --master_port=AVAILABLE_PORT --nnodes=2 --node_rank=1 main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet

Note that: some slurm configurations might need to be changed based on your cluster.

Evaluation

To evaluate a pretrained model on crossvit_9_dagger_224:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model crossvit_9_dagger_224 --batch-size 128 --data-path /path/to/imagenet --eval --pretrained
You might also like...
This repository builds a basic vision transformer from scratch so that one beginner can understand the theory of vision transformer.

vision-transformer-from-scratch This repository includes several kinds of vision transformers from scratch so that one beginner can understand the the

Code of U2Fusion: a unified unsupervised image fusion network for multiple image fusion tasks, including multi-modal, multi-exposure and multi-focus image fusion.

U2Fusion Code of U2Fusion: a unified unsupervised image fusion network for multiple image fusion tasks, including multi-modal (VIS-IR, medical), multi

This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" on Object Detection and Instance Segmentation.

Swin Transformer for Object Detection This repo contains the supported code and configuration files to reproduce object detection results of Swin Tran

Unofficial implementation of MUSIQ (Multi-Scale Image Quality Transformer)
Unofficial implementation of MUSIQ (Multi-Scale Image Quality Transformer)

MUSIQ: Multi-Scale Image Quality Transformer Unofficial pytorch implementation of the paper "MUSIQ: Multi-Scale Image Quality Transformer" (paper link

Official implementation of
Official implementation of "Towards Good Practices for Efficiently Annotating Large-Scale Image Classification Datasets" (CVPR2021)

Towards Good Practices for Efficiently Annotating Large-Scale Image Classification Datasets This is the official implementation of "Towards Good Pract

Official implementation of paper
Official implementation of paper "Query2Label: A Simple Transformer Way to Multi-Label Classification".

Introdunction This is the official implementation of the paper "Query2Label: A Simple Transformer Way to Multi-Label Classification". Abstract This pa

Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification
Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification

STAM - Pytorch Implementation of STAM (Space Time Attention Model), yet another pure and simple SOTA attention model that bests all previous models in

Pytorch implementation of ICASSP 2022 paper Attention Probe: Vision Transformer Distillation in the Wild
Pytorch implementation of ICASSP 2022 paper Attention Probe: Vision Transformer Distillation in the Wild

Attention Probe: Vision Transformer Distillation in the Wild Jiahao Wang, Mingdeng Cao, Shuwei Shi, Baoyuan Wu, Yujiu Yang In ICASSP 2022 This code is

Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"

Memory Efficient Attention Pytorch Implementation of a memory efficient multi-head attention as proposed in the paper, Self-attention Does Not Need O(

Comments
  • Multilabel classification

    Multilabel classification

    @jjasghar @krook @chunfuchen thanks a lot for sharing the code , i have a problem statement which i want to train crossvit on please find the details below

    1. input is of image shape 256x128 having label vector as gt for multilabel classification
    2. can i use crossvit to train the model , what all modification has to be done with the code base ??

    THnaks for the support

    opened by abhigoku10 4
  • Honoring distributed flag + fixing reset_classifier

    Honoring distributed flag + fixing reset_classifier

    1. Honoring the args.distributed flag in calls to evaluate().
    2. A couple of changes to make the reset_classifier() method work:
    • Initializing the embed_dim instance variable in VisionTransformer.
    • Reinitializing the classification head for all branches.
    opened by abhrac 1
  • Parameter setting

    Parameter setting

    Hello, thank you for your excellent work, I would like to know how you set the parameters on the CIFAR10 dataset, mainly the size of the patch,Looking forward to your reply

    opened by happy20200 1
Owner
International Business Machines
International Business Machines
Pytorch Implementation of Auto-Compressing Subset Pruning for Semantic Image Segmentation

Pytorch Implementation of Auto-Compressing Subset Pruning for Semantic Image Segmentation Introduction ACoSP is an online pruning algorithm that compr

Merantix 8 Dec 07, 2022
Instance Segmentation by Jointly Optimizing Spatial Embeddings and Clustering Bandwidth

Instance segmentation by jointly optimizing spatial embeddings and clustering bandwidth This codebase implements the loss function described in: Insta

209 Dec 07, 2022
Code release for "Detecting Twenty-thousand Classes using Image-level Supervision".

Detecting Twenty-thousand Classes using Image-level Supervision Detic: A Detector with image classes that can use image-level labels to easily train d

Meta Research 1.3k Jan 04, 2023
Codes for the paper Contrast and Mix: Temporal Contrastive Video Domain Adaptation with Background Mixing

Contrast and Mix (CoMix) The repository contains the codes for the paper Contrast and Mix: Temporal Contrastive Video Domain Adaptation with Backgroun

Computer Vision and Intelligence Research (CVIR) 13 Dec 10, 2022
LBK 35 Dec 26, 2022
PyTorch implementation of DCT fast weight RNNs

DCT based fast weights This repository contains the official code for the paper: Training and Generating Neural Networks in Compressed Weight Space. T

Kazuki Irie 4 Dec 24, 2022
A pure PyTorch batched computation implementation of "CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition"

A pure PyTorch batched computation implementation of "CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition"

張致強 14 Dec 02, 2022
CROSS-LINGUAL ABILITY OF MULTILINGUAL BERT: AN EMPIRICAL STUDY

M-BERT-Study CROSS-LINGUAL ABILITY OF MULTILINGUAL BERT: AN EMPIRICAL STUDY Motivation Multilingual BERT (M-BERT) has shown surprising cross lingual a

CogComp 1 Feb 28, 2022
J.A.R.V.I.S is an AI virtual assistant made in python.

J.A.R.V.I.S is an AI virtual assistant made in python. Running JARVIS Without Python To run JARVIS without python: 1. Head over to our installation pa

somePythonProgrammer 16 Dec 29, 2022
A universal framework for learning timestamp-level representations of time series

TS2Vec This repository contains the official implementation for the paper Learning Timestamp-Level Representations for Time Series with Hierarchical C

Zhihan Yue 284 Dec 30, 2022
Speech Separation Using an Asynchronous Fully Recurrent Convolutional Neural Network

Speech Separation Using an Asynchronous Fully Recurrent Convolutional Neural Network This repository is the official implementation of Speech Separati

Kai Li (李凯) 116 Nov 09, 2022
Crab is a flexible, fast recommender engine for Python that integrates classic information filtering recommendation algorithms in the world of scientific Python packages (numpy, scipy, matplotlib).

Crab - A Recommendation Engine library for Python Crab is a flexible, fast recommender engine for Python that integrates classic information filtering r

python-recsys 1.2k Dec 21, 2022
Implementation of Enformer, Deepmind's attention network for predicting gene expression, in Pytorch

Enformer - Pytorch (wip) Implementation of Enformer, Deepmind's attention network for predicting gene expression, in Pytorch. The original tensorflow

Phil Wang 235 Dec 27, 2022
Cancer-and-Tumor-Detection-Using-Inception-model - In this repo i am gonna show you how i did cancer/tumor detection in lungs using deep neural networks, specifically here the Inception model by google.

Cancer-and-Tumor-Detection-Using-Inception-model In this repo i am gonna show you how i did cancer/tumor detection in lungs using deep neural networks

Deepak Nandwani 1 Jan 01, 2022
Code, environments, and scripts for the paper: "How Private Is Your RL Policy? An Inverse RL Based Analysis Framework"

Privacy-Aware Inverse RL (PRIL) Analysis Framework Code, environments, and scripts for the paper: "How Private Is Your RL Policy? An Inverse RL Based

1 Dec 06, 2021
Deep Learning Datasets Maker is a QGIS plugin to make datasets creation easier for raster and vector data.

Deep Learning Dataset Maker Deep Learning Datasets Maker is a QGIS plugin to make datasets creation easier for raster and vector data. How to use Down

deepbands 25 Dec 15, 2022
DeepMReye: magnetic resonance-based eye tracking using deep neural networks

DeepMReye: magnetic resonance-based eye tracking using deep neural networks

73 Dec 21, 2022
FAMIE is a comprehensive and efficient active learning (AL) toolkit for multilingual information extraction (IE)

FAMIE: A Fast Active Learning Framework for Multilingual Information Extraction

18 Sep 01, 2022
Tensorflow solution of NER task Using BiLSTM-CRF model with Google BERT Fine-tuning And private Server services

Tensorflow solution of NER task Using BiLSTM-CRF model with Google BERT Fine-tuning

MaCan 4.2k Dec 29, 2022
Unified file system operation experience for different backend

megfile - Megvii FILE library Docs: http://megvii-research.github.io/megfile megfile provides a silky operation experience with different backends (cu

MEGVII Research 76 Dec 14, 2022