CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped

Overview

CSWin-Transformer

PWC PWC

This repo is the official implementation of "CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows". The code and models for downstream tasks are coming soon.

Introduction

CSWin Transformer (the name CSWin stands for Cross-Shaped Window) is introduced in arxiv, which is a new general-purpose backbone for computer vision. It is a hierarchical Transformer and replaces the traditional full attention with our newly proposed cross-shaped window self-attention. The cross-shaped window self-attention mechanism computes self-attention in the horizontal and vertical stripes in parallel that from a cross-shaped window, with each stripe obtained by splitting the input feature into stripes of equal width. With CSWin, we could realize global attention with a limited computation cost.

CSWin Transformer achieves strong performance on ImageNet classification (87.5 on val with only 97G flops) and ADE20K semantic segmentation (55.7 mIoU on val), surpassing previous models by a large margin.

teaser

Main Results on ImageNet

model pretrain resolution [email protected] #params FLOPs 22K model 1K model
CSWin-T ImageNet-1K 224x224 82.8 23M 4.3G - model
CSWin-S ImageNet-1k 224x224 83.6 35M 6.9G - model
CSWin-B ImageNet-1k 224x224 84.2 78M 15.0G - model
CSWin-B ImageNet-1k 384x384 85.5 78M 47.0G - model
CSWin-L ImageNet-22k 224x224 86.5 173M 31.5G model model
CSWin-L ImageNet-22k 384x384 87.5 173M 96.8G - model

Main Results on Downstream Tasks

COCO Object Detection

backbone Method pretrain lr Schd box mAP mask mAP #params FLOPS
CSwin-T Mask R-CNN ImageNet-1K 3x 49.0 43.6 42M 279G
CSwin-S Mask R-CNN ImageNet-1K 3x 50.0 44.5 54M 342G
CSwin-B Mask R-CNN ImageNet-1K 3x 50.8 44.9 97M 526G
CSwin-T Cascade Mask R-CNN ImageNet-1K 3x 52.5 45.3 80M 757G
CSwin-S Cascade Mask R-CNN ImageNet-1K 3x 53.7 46.4 92M 820G
CSwin-B Cascade Mask R-CNN ImageNet-1K 3x 53.9 46.4 135M 1004G

ADE20K Semantic Segmentation (val)

Backbone Method pretrain Crop Size Lr Schd mIoU mIoU (ms+flip) #params FLOPs
CSwin-T Semantic FPN ImageNet-1K 512x512 80K 48.2 - 26M 202G
CSwin-S Semantic FPN ImageNet-1K 512x512 80K 49.2 - 39M 271G
CSwin-B Semantic FPN ImageNet-1K 512x512 80K 49.9 - 81M 464G
CSwin-T UPerNet ImageNet-1K 512x512 160K 49.3 50.4 60M 959G
CSwin-S UperNet ImageNet-1K 512x512 160K 50.0 50.8 65M 1027G
CSwin-B UperNet ImageNet-1K 512x512 160K 50.8 51.7 109M 1222G
CSwin-B UPerNet ImageNet-22K 640x640 160K 51.8 52.6 109M 1941G
CSwin-L UperNet ImageNet-22K 640x640 160K 53.4 55.7 208M 2745G

Requirements

timm==0.3.4, pytorch>=1.4, opencv, ... , run:

bash install_req.sh

Apex for mixed precision training is used for finetuning. To install apex, run:

git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Data prepare: ImageNet with the following folder structure, you can extract imagenet by this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......

Train

Train the three lite variants: CSWin-Tiny, CSWin-Small and CSWin-Base:

bash train.sh 8 --data <data path> --model CSWin_64_12211_tiny_224 -b 256 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99984 --drop-path 0.2
bash train.sh 8 --data <data path> --model CSWin_64_24322_small_224 -b 256 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99984 --drop-path 0.4
bash train.sh 8 --data <data path> --model CSWin_96_24322_base_224 -b 128 --lr 1e-3 --weight-decay .1 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99992 --drop-path 0.5

If you want to train our CSWin on images with 384x384 resolution, please use '--img-size 384'.

If the GPU memory is not enough, please use '-b 128 --lr 1e-3 --model-ema-decay 0.99992' or use checkpoint '--use-chk'.

Finetune

Finetune CSWin-Base with 384x384 resolution:

bash finetune.sh 8 --data <data path> --model CSWin_96_24322_base_384 -b 32 --lr 5e-6 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 384 --warmup-epochs 0 --model-ema-decay 0.9998 --finetune <pretrained 224 model> --epochs 20 --mixup 0.1 --cooldown-epochs 10 --drop-path 0.7 --ema-finetune --lr-scale 1 --cutmix 0.1

Finetune ImageNet-22K pretrained CSWin-Large with 224x224 resolution:

bash finetune.sh 8 --data <data path> --model CSWin_144_24322_large_224 -b 64 --lr 2.5e-4 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9996 --finetune <22k-pretrained model> --epochs 30 --mixup 0.01 --cooldown-epochs 10 --interpolation bicubic  --lr-scale 0.05 --drop-path 0.2 --cutmix 0.3 --use-chk --fine-22k --ema-finetune

If the GPU memory is not enough, please use checkpoint '--use-chk'.

Cite CSWin Transformer

@misc{dong2021cswin,
      title={CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows}, 
        author={Xiaoyi Dong and Jianmin Bao and Dongdong Chen and Weiming Zhang and Nenghai Yu and Lu Yuan and Dong Chen and Baining Guo},
        year={2021},
        eprint={2107.00652},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
}

Acknowledgement

This repository is built using the timm library and the DeiT repository.

License

This project is licensed under the license found in the LICENSE file in the root directory of this source tree.

Microsoft Open Source Code of Conduct

Contact Information

For help or issues using CSWin Transformer, please submit a GitHub issue.

For other communications related to CSWin Transformer, please contact Jianmin Bao ([email protected]), Dong Chen ([email protected]).

Comments
  • About the patches_resolution of the segmentation model

    About the patches_resolution of the segmentation model

    Hello, this work is interesting but I have some questions about the 'patches_resolution' of the segmentation model. I notice that the long side of the cross-shaped windows is the 'patches_resolution' rather than the real feature resoulution. For example, in the stage-3, the long side is 224 / 16 = 14. Do I understand it correctly? Does that make it impossible to exchannge information outside the 'patches_resolution' ?

    opened by danczs 2
  • The results of downstream task by my realization are poor

    The results of downstream task by my realization are poor

    There are some questions:

    1. the split size is still [1 2 7 7]?
    2. last stage branch_num is 2 or 1 ? The downstream task image resolution in last stages cannot equal to 7(split size). If not 1, the pretrained weights size is not matched
    3. pading is right in my realization ? pad_l = pad_t = 0 pad_r = (W_sp - W % W_sp) % W_sp pad_b = (H_sp - H % H_sp) % H_sp q = q.transpose(-2,-1).contiguous().view(B, H, W, C) k = q.transpose(-2,-1).contiguous().view(B, H, W, C) v = q.transpose(-2,-1).contiguous().view(B, H, W, C) if pad_r > 0 or pad_b > 0: q = F.pad(q, (0, 0, pad_l, pad_r, pad_t, pad_b)) k = F.pad(k, (0, 0, pad_l, pad_r, pad_t, pad_b)) v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b)) _, Hp, Wp, _ = q.shape
    opened by Sunting78 2
  • Experiment setting for semantic segmentation

    Experiment setting for semantic segmentation

    Hi, thank you for the code. I implemented CSwin-T with FPN for semantic segmentation in ADE20K but couldn't reach the mIoU value of 48.2 as mentioned by you in the table. The maximum I could get was 39.9 mIoU, it will be great if you could share the exact experiment settings you used? Thanks

    opened by AnukritiSinghh 2
  • CSWin significantly slower than Swin?

    CSWin significantly slower than Swin?

    Greetings,

    From my benchmarks I have noticed that CSwin seems to be significantly slower than Swin when it comes to inference times, is this the expected behavior? While I can get predictions as fast as 20 miliseconds on Swin Large 384 it takes above 900 milisecond on CSWin_144_24322_large_384.

    I performed tests using FP16, torchscript, optimize_for_inference and torch.inference_mode

    opened by ErenBalatkan 2
  • Pretrained settings for object detection

    Pretrained settings for object detection

    Hi, I'm impressed by your excellent work.

    I have a question.

    I wonder which type of the pre-trained weights (224x224 or 384x384 finetuned) is used for object detection.

    I know both 224x224 and 384x384 are pre-trained on ImageNet-1k.

    opened by youngwanLEE 2
  • Error about building 384 models

    Error about building 384 models

    The code: https://github.com/microsoft/CSWin-Transformer/blob/d8be74a7833898f7bd9c77eb8c051d1b8bd5d753/models/cswin.py#L407 shoud be:

    model = CSWinTransformer(img_size=384, patch_size=4, embed_dim=96, depth=[2,4,32,2],
    

    And as the same: https://github.com/microsoft/CSWin-Transformer/blob/d8be74a7833898f7bd9c77eb8c051d1b8bd5d753/models/cswin.py#L414

    opened by TingquanGao 2
  • The problem is shown in the figure

    The problem is shown in the figure

    model = CSWinTransformer(patch_size=4, embed_dim=96, depth=[2,4,32,2], split_size=[1,2,12,12], num_heads=[4,8,16,32], mlp_ratio=4.).cuda().eval() inp = torch.rand(1, 3, 224, 224).cuda() outs = model(inp) for out in outs: print(out.shape)

    RuntimeError: shape '[1, 192, 1, 14, 1, 12]' is invalid for input of size 37632

    why? image

    opened by rui-cf 2
  • about the test result on imagenet-2012

    about the test result on imagenet-2012

    Hi! I've test the CSwin-Tiny-224 released pretrained weight, this is my data transforms during testing:

    DEFAULT_CROP_SIZE = 0.9
    scale_size = int(math.floor(image_size / DEFAULT_CROP_SIZE))
    transform = transforms.Compose(
            [
                transforms.Resize(scale_size, interpolation=3)  # 3: bibubic
                if image_size == 224
                else transforms.Resize(image_size, interpolation=3),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
            ]
        )
    

    I can only get 80.5% on imagenet2012 dataset which is inconsistent with the results as you mentioned in this repo, did I miss some details about the data-augmentation during testing?

    opened by rentainhe 1
  • CSwin Code for Segmentation with MMSegmentation

    CSwin Code for Segmentation with MMSegmentation

    Hi, Thank you for your work! I wonder if you plan to release the mmsegmentation code you used for the downstream segmentation task, just like in Swin-Transformer repo.

    Best,

    opened by WalBouss 1
  • How do you produce Table 9  (ablation on different attention mecahnisms) in the  paper?

    How do you produce Table 9 (ablation on different attention mecahnisms) in the paper?

    Hi, thanks for your nice work. I'm doing some comparison on different attention mechanisms, and want to follow your experimental settings. I meet two problems:

    1. Why the reported mIoU is 41.9 for Swin-T in Table 9, while it is 46.1 in Swin Paper?
    2. Can you provide detailed experimental settings for semantic segmentation and object detection in table 9 ?
    opened by rayleizhu 0
  • about the setting of --use_chk

    about the setting of --use_chk

    give the parameter of --use_chk can launch the torch.utils.checkpoint to save the GPU memory, and I wonder if this could hurt the final performance, thanks a lot!

    opened by go-ahead-maker 0
  • Input image normalization parameters for semantic segmentation.

    Input image normalization parameters for semantic segmentation.

    Hey, guys, cool work!

    Unfortunately, for me it is not quite obvious, which normalization parameters (mean, std) did you use, when trained semantic segmentation model on ADE20K. Are they still IMAGENET_DEFAULT_MEAN or you used another values?

    opened by NikAleksFed 0
  • Using transfer to train over food101

    Using transfer to train over food101

    Hi all! I'm trying to train a model for food101 using the using the CSWin_64_12211_tiny_224 model with its pretrained values. The thing is, during execution it looks like its training from 0 rather than reusing the pretrained weights. By this I mean the initial top5 accuracy is around 5% but my initial thoughts is that it should be higher than this.

    For this I loaded the pretrained model and changed it's classification layer in a separate script and saved it for use as follows

    model = create_model( 'CSWin_64_12211_tiny_224', pretrained=True, num_classes=1000, drop_rate=0.0, drop_connect_rate=None, # DEPRECATED, use drop_path drop_path_rate=0.2, drop_block_rate=None, global_pool=None, bn_tf=False, bn_momentum=None, bn_eps=None, checkpoint_path='', img_size=224, use_chk=True) chk_path = './pretrained/cswin_tiny_224.pth' load_checkpoint(model, chk_path) model.reset_classifier(101, 'max')

    These are some of the runs I tried ` bash finetune.sh 1 --data ../food-101 --model CSWin_64_12211_tiny_224 -b 32 --lr 5e-6 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9998 --epochs 20 --mixup 0.1 --cooldown-epochs 10 --drop-path 0.7 --ema-finetune --lr-scale 1 --cutmix 0.1 --use-chk --num-classes 101 --pretrained --finetune ./pretrained/CSWin_64_12211_tiny_224101.pth

    bash finetune.sh 1 --data ../food-101 --model CSWin_64_12211_tiny_224 -b 32 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9998 --epochs 20 --cooldown-epochs 10 --drop-path 0.2 --ema-finetune --cutmix 0.1 --use-chk --num-classes 101 --initial-checkpoint ./pretrained/CSWin_64_12211_tiny_224101.pth --lr-scale 1.0 --output ./full_base `

    Is there something I'm missing or a proper way I should try this?

    Thanks in advance for any help! :)

    opened by andreynz691 0
Owner
Microsoft
Open source projects and samples from Microsoft
Microsoft
Mini-hmc-jax - A simple implementation of Hamiltonian Monte Carlo in JAX

mini-hmc-jax This is a simple implementation of Hamiltonian Monte Carlo in JAX t

Martin Marek 6 Mar 03, 2022
Asterisk is a framework to generate high-quality training datasets at scale

Asterisk is a framework to generate high-quality training datasets at scale

Mona Nashaat 44 Apr 25, 2022
Repository for scripts and notebooks from the book: Programming PyTorch for Deep Learning

Repository for scripts and notebooks from the book: Programming PyTorch for Deep Learning

Ian Pointer 368 Dec 17, 2022
ViSER: Video-Specific Surface Embeddings for Articulated 3D Shape Reconstruction

ViSER: Video-Specific Surface Embeddings for Articulated 3D Shape Reconstruction. NeurIPS 2021.

Gengshan Yang 59 Nov 25, 2022
A Python implementation of global optimization with gaussian processes.

Bayesian Optimization Pure Python implementation of bayesian global optimization with gaussian processes. PyPI (pip): $ pip install bayesian-optimizat

fernando 6.5k Jan 02, 2023
E2e music remastering system - End-to-end Music Remastering System Using Self-supervised and Adversarial Training

End-to-end Music Remastering System This repository includes source code and pre

Junghyun (Tony) Koo 37 Dec 15, 2022
Implement slightly different caffe-segnet in tensorflow

Tensorflow-SegNet Implement slightly different (see below for detail) SegNet in tensorflow, successfully trained segnet-basic in CamVid dataset. Due t

Tseng Kuan Lun 364 Oct 27, 2022
A trashy useless Latin programming language written in python.

Codigum! The first programming langage in latin! (please keep your eyes closed when if you read the source code) It is pretty useless though. Document

Bic 2 Oct 25, 2021
MPViT:Multi-Path Vision Transformer for Dense Prediction

MPViT : Multi-Path Vision Transformer for Dense Prediction This repository inlcu

Youngwan Lee 272 Dec 20, 2022
Fortuitous Forgetting in Connectionist Networks

Fortuitous Forgetting in Connectionist Networks Introduction This repository includes reference code for the paper Fortuitous Forgetting in Connection

Hattie Zhou 14 Nov 26, 2022
A complete end-to-end demonstration in which we collect training data in Unity and use that data to train a deep neural network to predict the pose of a cube. This model is then deployed in a simulated robotic pick-and-place task.

Object Pose Estimation Demo This tutorial will go through the steps necessary to perform pose estimation with a UR3 robotic arm in Unity. You’ll gain

Unity Technologies 187 Dec 24, 2022
TorchDistiller - a collection of the open source pytorch code for knowledge distillation, especially for the perception tasks, including semantic segmentation, depth estimation, object detection and instance segmentation.

This project is a collection of the open source pytorch code for knowledge distillation, especially for the perception tasks, including semantic segmentation, depth estimation, object detection and i

yifan liu 147 Dec 03, 2022
Tutorial repo for an end-to-end Data Science project

End-to-end Data Science project This is the repo with the notebooks, code, and additional material used in the ITI's workshop. The goal of the session

Deena Gergis 127 Dec 30, 2022
This is Unofficial Repo. Lips Don't Lie: A Generalisable and Robust Approach to Face Forgery Detection (CVPR 2021)

Lips Don't Lie: A Generalisable and Robust Approach to Face Forgery Detection This is a PyTorch implementation of the LipForensics paper. This is an U

Minha Kim 2 May 11, 2022
Code for the paper: On Pathologies in KL-Regularized Reinforcement Learning from Expert Demonstrations

Non-Parametric Prior Actor-Critic (N-PPAC) This repository contains the code for On Pathologies in KL-Regularized Reinforcement Learning from Expert D

Cong Lu 5 May 13, 2022
Near-Optimal Sparse Allreduce for Distributed Deep Learning (published in PPoPP'22)

Near-Optimal Sparse Allreduce for Distributed Deep Learning (published in PPoPP'22) Ok-Topk is a scheme for distributed training with sparse gradients

Shigang Li 9 Oct 29, 2022
Implementation of the state of the art beat-detection, downbeat-detection and tempo-estimation model

The ISMIR 2020 Beat Detection, Downbeat Detection and Tempo Estimation Model Implementation. This is an implementation in TensorFlow to implement the

Koen van den Brink 1 Nov 12, 2021
[PNAS2021] The neural architecture of language: Integrative modeling converges on predictive processing

The neural architecture of language: Integrative modeling converges on predictive processing Code accompanying the paper The neural architecture of la

Martin Schrimpf 36 Dec 01, 2022
Code for NeurIPS 2021 paper: Invariant Causal Imitation Learning for Generalizable Policies

Invariant Causal Imitation Learning for Generalizable Policies Ioana Bica, Daniel Jarrett, Mihaela van der Schaar Neural Information Processing System

Ioana Bica 17 Dec 01, 2022
DilatedNet in Keras for image segmentation

Keras implementation of DilatedNet for semantic segmentation A native Keras implementation of semantic segmentation according to Multi-Scale Context A

303 Mar 15, 2022