The official implementation of ELSA: Enhanced Local Self-Attention for Vision Transformer


ELSA: Enhanced Local Self-Attention for Vision Transformer

By Jingkai Zhou, Pichao Wang*, Fan Wang, Qiong Liu, Hao Li, Rong Jin

This repo is the official implementation of "ELSA: Enhanced Local Self-Attention for Vision Transformer".

Introduction

Self-attention is powerful in modeling long-range dependencies, but it is weak in local finer-level feature learning. As shown in Figure 1, the performance of local self-attention (LSA) is merely on par with convolution and inferior to dynamic filters, which leaves researchers unsure whether to use LSA or its counterparts, which one is better, and what makes LSA mediocre. In this work, we comprehensively investigate LSA and its counterparts. We find that the devil lies in the generation and application of spatial attention.

Based on these findings, we propose the enhanced local self-attention (ELSA) with Hadamard attention and the ghost head, as illustrated in Figure 2. Experiments demonstrate the effectiveness of ELSA. Without any architecture or hyperparameter modification, using ELSA as a drop-in replacement consistently boosts baseline methods in both upstream and downstream tasks.

Please refer to our paper for more details.

Model zoo

ImageNet Classification

Model        #Params  Pretrain     Resolution  Top-1 Acc  Download
ELSA-Swin-T  28M      ImageNet-1K  224         82.7       google / baidu
ELSA-Swin-S  53M      ImageNet-1K  224         83.5       google / baidu
ELSA-Swin-B  93M      ImageNet-1K  224         84.0       google / baidu

COCO Object Detection

Backbone     Method              Pretrain     Lr Schd  Box mAP  Mask mAP  #Params  Download
ELSA-Swin-T  Mask R-CNN          ImageNet-1K  1x       45.7     41.1      49M      google / baidu
ELSA-Swin-T  Mask R-CNN          ImageNet-1K  3x       47.5     42.7      49M      google / baidu
ELSA-Swin-S  Mask R-CNN          ImageNet-1K  1x       48.3     43.0      72M      google / baidu
ELSA-Swin-S  Mask R-CNN          ImageNet-1K  3x       49.2     43.6      72M      google / baidu
ELSA-Swin-T  Cascade Mask R-CNN  ImageNet-1K  1x       49.8     43.0      86M      google / baidu
ELSA-Swin-T  Cascade Mask R-CNN  ImageNet-1K  3x       51.0     44.2      86M      google / baidu
ELSA-Swin-S  Cascade Mask R-CNN  ImageNet-1K  1x       51.6     44.4      110M     google / baidu
ELSA-Swin-S  Cascade Mask R-CNN  ImageNet-1K  3x       52.3     45.2      110M     google / baidu

ADE20K Semantic Segmentation

Backbone     Method   Pretrain     Crop Size  Lr Schd  mIoU (ms+flip)  #Params  Download
ELSA-Swin-T  UPerNet  ImageNet-1K  512x512    160K     47.9            61M      google / baidu
ELSA-Swin-S  UPerNet  ImageNet-1K  512x512    160K     50.4            85M      google / baidu

Install

  • Clone this repo:
git clone https://github.com/damo-cv/ELSA.git elsa
cd elsa
  • Create a conda virtual environment and activate it:
conda create -n elsa python=3.7 -y
conda activate elsa
  • Install PyTorch==1.8.0 and torchvision==0.9.0 with CUDA==10.1:
conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.1 -c pytorch
  • Install Apex:
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ../
  • Install mmcv-full==1.3.0:
pip install mmcv-full==1.3.0 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.0/index.html
  • Install other requirements:
pip install -r requirements.txt
  • Install mmdet and mmseg:
cd ./det
pip install -v -e .
cd ../seg
pip install -v -e .
cd ../
  • Build the elsa operation:
cd ./cls/models/elsa
python setup.py install
mv build/lib*/* .
cp *.so ../../../det/mmdet/models/backbones/elsa/
cp *.so ../../../seg/mmseg/models/backbones/elsa/
cd ../../../
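
As a quick sanity check after the build step, the sketch below looks for compiled extension files (*.so) in the three directories named above. The helper name find_missing_extensions is hypothetical and not part of this repo, and the script assumes it is run from the repository root.

```python
from pathlib import Path

# Directories from the build step above that should each hold the compiled *.so files.
ELSA_DIRS = [
    "cls/models/elsa",
    "det/mmdet/models/backbones/elsa",
    "seg/mmseg/models/backbones/elsa",
]


def find_missing_extensions(repo_root, dirs=ELSA_DIRS):
    """Return the directories (relative to repo_root) that contain no *.so file.

    Hypothetical helper for illustration; not part of the ELSA codebase.
    """
    root = Path(repo_root)
    missing = []
    for rel in dirs:
        target = root / rel
        # Treat a directory that does not exist the same as one without a build.
        if not target.is_dir() or not any(target.glob("*.so")):
            missing.append(rel)
    return missing


if __name__ == "__main__":
    missing = find_missing_extensions(".")
    if missing:
        print("No compiled extension found in:", ", ".join(missing))
    else:
        print("Compiled extension present in all three locations.")
```

If any directory is reported, repeating the corresponding cp command from the build step is the likely fix.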

Data preparation

We use the standard ImageNet dataset, which you can download from http://image-net.org/. Please arrange it in the following file structure:

$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
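
The layout above is the usual one-folder-per-class arrangement. A minimal sketch for checking a prepared dataset directory follows; summarize_split and check_layout are hypothetical helper names, and the set of accepted image extensions is an assumption.

```python
from pathlib import Path

# Assumed image extensions; adjust to match your copy of the dataset.
IMAGE_EXTS = {".jpeg", ".jpg", ".png"}


def summarize_split(split_dir):
    """Map each class folder name to the number of image files it contains."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTS
            )
    return counts


def check_layout(root):
    """Return a list of problems found; an empty list means the layout looks fine."""
    root = Path(root)
    problems = []
    summaries = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing split directory: {split}")
            continue
        summaries[split] = summarize_split(split_dir)
        for name, n in summaries[split].items():
            if n == 0:
                problems.append(f"{split}/{name} contains no images")
    if len(summaries) == 2:
        # Class folders should match between train and val.
        only_one = set(summaries["train"]) ^ set(summaries["val"])
        if only_one:
            problems.append(f"classes not present in both splits: {sorted(only_one)}")
    return problems
```

Calling check_layout on the imagenet directory should return an empty list for a correctly prepared dataset; for ImageNet-1K, each split is expected to hold 1000 class folders.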

Also prepare the COCO and ADE20K datasets as described on their official pages, then link them to det/data and seg/data, respectively.

Evaluation

ImageNet Classification

Run the following scripts to evaluate pre-trained models on ImageNet-1K:

cd cls

python validate.py <PATH_TO_IMAGENET> --model elsa_swin_tiny --checkpoint <CHECKPOINT_FILE> \
  --no-test-pool --apex-amp --img-size 224 -b 128

python validate.py <PATH_TO_IMAGENET> --model elsa_swin_small --checkpoint <CHECKPOINT_FILE> \
  --no-test-pool --apex-amp --img-size 224 -b 128

python validate.py <PATH_TO_IMAGENET> --model elsa_swin_base --checkpoint <CHECKPOINT_FILE> \
  --no-test-pool --apex-amp --img-size 224 -b 128 --use-ema
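
The three commands above differ only in the model name and in the extra --use-ema flag for the base model. The sketch below assembles them programmatically, which may be convenient when evaluating all checkpoints in one go. The helper build_validate_cmd is hypothetical; the flags are copied from the commands above, and the script is assumed to run from the cls directory.

```python
import shlex

# Model names from the commands above; only the base model is evaluated with --use-ema.
MODELS = {
    "elsa_swin_tiny": False,
    "elsa_swin_small": False,
    "elsa_swin_base": True,
}


def build_validate_cmd(data_dir, model, checkpoint, batch_size=128, img_size=224):
    """Return the validate.py argument list for one model (hypothetical helper)."""
    cmd = [
        "python", "validate.py", data_dir,
        "--model", model,
        "--checkpoint", checkpoint,
        "--no-test-pool", "--apex-amp",
        "--img-size", str(img_size),
        "-b", str(batch_size),
    ]
    if MODELS[model]:
        cmd.append("--use-ema")
    return cmd


if __name__ == "__main__":
    # Placeholder paths; replace with your dataset and checkpoint locations.
    for name in MODELS:
        cmd = build_validate_cmd("<PATH_TO_IMAGENET>", name, "<CHECKPOINT_FILE>")
        print(" ".join(shlex.quote(part) for part in cmd))
```

Each returned list can be passed directly to subprocess.run(cmd, check=True) once real paths are filled in.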

COCO Detection

Run the following scripts to evaluate a detector on COCO:

cd det

# single-gpu testing
python tools/test.py <CONFIG_FILE> <DET_CHECKPOINT_FILE> --eval bbox segm

# multi-gpu testing
tools/dist_test.sh <CONFIG_FILE> <DET_CHECKPOINT_FILE> <GPU_NUM> --eval bbox segm

ADE20K Semantic Segmentation

Run the following scripts to evaluate a model on ADE20K:

cd seg

# single-gpu testing
python tools/test.py <CONFIG_FILE> <SEG_CHECKPOINT_FILE> --aug-test --eval mIoU

# multi-gpu testing
tools/dist_test.sh <CONFIG_FILE> <SEG_CHECKPOINT_FILE> <GPU_NUM> --aug-test --eval mIoU

Training from scratch

Due to randomness, retraining results may differ from the numbers in the paper by about 0.1~0.2%.

ImageNet Classification

Run the following scripts to train classifiers on ImageNet-1K:

cd cls

bash ./distributed_train.sh 8 <PATH_TO_IMAGENET> --model elsa_swin_tiny \
  --epochs 300 -b 128 -j 8 --opt adamw --lr 1e-3 --sched cosine --weight-decay 5e-2 \
  --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5 --drop-path 0.1 --aa rand-m9-mstd0.5-inc1 \
  --mixup 0.8 --cutmix 1. --remode pixel --reprob 0.25 --clip-grad 5. --amp

bash ./distributed_train.sh 8 <PATH_TO_IMAGENET> --model elsa_swin_small \
  --epochs 300 -b 128 -j 8 --opt adamw --lr 1e-3 --sched cosine --weight-decay 5e-2 \
  --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5 --drop-path 0.3 --aa rand-m9-mstd0.5-inc1 \
  --mixup 0.8 --cutmix 1. --remode pixel --reprob 0.25 --clip-grad 5. --amp

bash ./distributed_train.sh 8 <PATH_TO_IMAGENET> --model elsa_swin_base \
  --epochs 300 -b 128 -j 8 --opt adamw --lr 1e-3 --sched cosine --weight-decay 5e-2 \
  --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5 --drop-path 0.5 --aa rand-m9-mstd0.5-inc1 \
  --mixup 0.8 --cutmix 1. --remode pixel --reprob 0.25 --clip-grad 5. --amp --model-ema

If GPU memory is not enough when training elsa_swin_base, you can use two nodes (2 * 8 GPUs), each with a batch size of 64 images/GPU.
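
The two-node suggestion keeps the global batch size unchanged. The arithmetic can be sketched as follows, assuming that -b is the per-GPU batch size (as the "images/GPU" wording suggests) and that no gradient accumulation is used; global_batch_size is a hypothetical helper.

```python
def global_batch_size(num_nodes, gpus_per_node, per_gpu_batch):
    """Total images per optimizer step across all GPUs (no gradient accumulation assumed)."""
    return num_nodes * gpus_per_node * per_gpu_batch


# Single node, as in the commands above: 8 GPUs with -b 128 (assumed per-GPU).
single_node = global_batch_size(1, 8, 128)

# Two nodes with 8 GPUs each and 64 images per GPU.
two_nodes = global_batch_size(2, 8, 64)

print(single_node, two_nodes)  # 1024 1024
```

Since both settings give 1024 images per step under these assumptions, the learning rate and schedule from the single-node command should carry over unchanged.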

COCO Detection / ADE20K Semantic Segmentation

Run the following scripts to train models on COCO / ADE20K:

cd det 
# (or cd seg)

# multi-gpu training
tools/dist_train.sh <CONFIG_FILE> <GPU_NUM> --cfg-options model.pretrained=<PRETRAIN_MODEL> [model.backbone.use_checkpoint=True] [other optional arguments] 

Acknowledgement

This work was supported by Alibaba Group through the Alibaba Research Intern Program and by the National Natural Science Foundation of China (No. 61976094).

The codebase builds on pytorch-image-models, ddfnet, VOLO, Swin-Transformer, Swin-Transformer-Detection, and Swin-Transformer-Semantic-Segmentation.

Citing ELSA

@article{zhou2021ELSA,
  title={ELSA: Enhanced Local Self-Attention for Vision Transformer},
  author={Zhou, Jingkai and Wang, Pichao and Wang, Fan and Liu, Qiong and Li, Hao and Jin, Rong},
  journal={arXiv preprint arXiv:2112.12786},
  year={2021}
}