Official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

Overview


This repository is the official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification (arXiv).

If you use the code or models from this repo, please cite our work. Thanks!

@inproceedings{
    chen2021crossvit,
    title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}},
    author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda},
    booktitle={International Conference on Computer Vision (ICCV)},
    year={2021}
}

Installation

To install requirements:

pip install -r requirements.txt

With conda:

conda create -n crossvit python=3.8
conda activate crossvit
conda install pytorch=1.7.1 torchvision cudatoolkit=11.0 -c pytorch -c nvidia
pip install -r requirements.txt
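
After installation, an optional sanity check confirms that PyTorch sees your GPUs:

import torch

print(torch.__version__)          # expect 1.7.1 with the conda setup above
print(torch.cuda.is_available())  # True if the CUDA toolkit is set up correctly
print(torch.cuda.device_count())  # number of visible GPUs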

Data preparation

Download and extract the ImageNet train and val images from http://image-net.org/. The directory structure is the standard layout expected by torchvision's datasets.ImageFolder: the training and validation data are expected to be in the train/ and val/ folders, respectively:

/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
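
With this layout, the data loads directly with torchvision. A minimal sketch (the actual training script applies its own transforms and augmentation):

from torchvision import datasets, transforms as T

# Basic 224x224 preprocessing for illustration only.
transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
])

train_set = datasets.ImageFolder('/path/to/imagenet/train', transform=transform)
val_set = datasets.ImageFolder('/path/to/imagenet/val', transform=transform)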

Pretrained models

We provide models trained on ImageNet-1K; you can find them here. You can load pretrained weights into a model by adding the --pretrained flag.
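
To load a pretrained model in Python rather than through main.py, something like the following should work (a sketch, assuming the factory functions in models/crossvit.py follow the timm convention of accepting a pretrained flag):

from models.crossvit import crossvit_9_dagger_224

# pretrained=True loads the ImageNet-1K weights, mirroring the --pretrained flag.
model = crossvit_9_dagger_224(pretrained=True)
model.eval()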

Training

To train crossvit_9_dagger_224 on ImageNet on a single node with 8 GPUs for 300 epochs, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet

Other model names can be found in models/crossvit.py.
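
If the model variants are registered with timm (common in DeiT-style codebases, and an assumption here rather than a documented interface of this repo), you can also list the names programmatically:

import timm
import models.crossvit  # assumed to register the CrossViT variants with timm

print(timm.list_models('crossvit*'))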

Multinode training

Distributed training is available via Slurm and submitit:

To train a crossvit_9_dagger_224 model on ImageNet on 4 nodes with 8 GPUs each for 300 epochs:

python run_with_submitit.py --nodes 4 --model crossvit_9_dagger_224 --data-path /path/to/imagenet --batch-size 128 --warmup-epochs 30

Alternatively, you can start the processes on each machine manually, e.g. with 2 nodes, each with 8 GPUs.

Machine 0:

python -m torch.distributed.launch --nproc_per_node=8 --master_addr=MACHINE_0_IP --master_port=AVAILABLE_PORT --nnodes=2 --node_rank=0 main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet

Machine 1:

python -m torch.distributed.launch --nproc_per_node=8 --master_addr=MACHINE_0_IP --master_port=AVAILABLE_PORT --nnodes=2 --node_rank=1 main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet
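
For reference, torch.distributed.launch spawns one worker per GPU and passes the rendezvous information through environment variables; each worker then typically initializes its process group like this (a generic sketch of the mechanism, not a copy of main.py):

import os
import torch
import torch.distributed as dist

# With --use_env the launcher exports LOCAL_RANK, RANK and WORLD_SIZE;
# MASTER_ADDR and MASTER_PORT come from the launch arguments.
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

# init_method='env://' reads the variables above to form the process group.
dist.init_process_group(backend='nccl', init_method='env://')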

Note that some Slurm configurations might need to be changed based on your cluster.

Evaluation

To evaluate a pretrained crossvit_9_dagger_224 model:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model crossvit_9_dagger_224 --batch-size 128 --data-path /path/to/imagenet --eval --pretrained
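
For a quick single-GPU check without the distributed launcher, a bare-bones top-1 accuracy loop looks like this (a minimal sketch; the model constructor and data path are assumptions carried over from above):

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms as T
from models.crossvit import crossvit_9_dagger_224  # assumed factory function

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = crossvit_9_dagger_224(pretrained=True).to(device).eval()

# Standard ImageNet preprocessing for 224x224 evaluation.
transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
val_set = datasets.ImageFolder('/path/to/imagenet/val', transform=transform)
loader = DataLoader(val_set, batch_size=128, num_workers=8)

correct = total = 0
with torch.no_grad():
    for images, targets in loader:
        logits = model(images.to(device))
        correct += (logits.argmax(dim=1).cpu() == targets).sum().item()
        total += targets.size(0)
print(f'top-1 accuracy: {correct / total:.4f}')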
