The official implementation of ELSA: Enhanced Local Self-Attention for Vision Transformer

Related tags

Deep LearningELSA
Overview

ELSA: Enhanced Local Self-Attention for Vision Transformer

By Jingkai Zhou, Pichao Wang*, Fan Wang, Qiong Liu, Hao Li, Rong Jin

This repo is the official implementation of "ELSA: Enhanced Local Self-Attention for Vision Transformer".

Introduction

Self-attention is powerful in modeling long-range dependencies, but it is weak in local finer-level feature learning. As shown in Figure 1, the performance of local self-attention (LSA) is just on par with convolution and inferior to dynamic filters, which puzzles researchers on whether to use LSA or its counterparts, which one is better, and what makes LSA mediocre. In this work, we comprehensively investigate LSA and its counterparts. We find that the devil lies in the generation and application of spatial attention.

Based on these findings, we propose the enhanced local self-attention (ELSA) with Hadamard attention and the ghost head, as illustrated in Figure 2. Experiments demonstrate the effectiveness of ELSA. Without architecture / hyperparameter modification, The use of ELSA in drop-in replacement boosts baseline methods consistently in both upstream and downstream tasks.

Please refer to our paper for more details.

Model zoo

ImageNet Classification

Model #Params Pretrain Resolution Top1 Acc Download
ELSA-Swin-T 28M ImageNet 1K 224 82.7 google / baidu
ELSA-Swin-S 53M ImageNet 1K 224 83.5 google / baidu
ELSA-Swin-B 93M ImageNet 1K 224 84.0 google / baidu

COCO Object Detection

Backbone Method Pretrain Lr Schd Box mAP Mask mAP #Params Download
ELSA-Swin-T Mask R-CNN ImageNet-1K 1x 45.7 41.1 49M google / baidu
ELSA-Swin-T Mask R-CNN ImageNet-1K 3x 47.5 42.7 49M google / baidu
ELSA-Swin-S Mask R-CNN ImageNet-1K 1x 48.3 43.0 72M google / baidu
ELSA-Swin-S Mask R-CNN ImageNet-1K 3x 49.2 43.6 72M google / baidu
ELSA-Swin-T Cascade Mask R-CNN ImageNet-1K 1x 49.8 43.0 86M google / baidu
ELSA-Swin-T Cascade Mask R-CNN ImageNet-1K 3x 51.0 44.2 86M google / baidu
ELSA-Swin-S Cascade Mask R-CNN ImageNet-1K 1x 51.6 44.4 110M google / baidu
ELSA-Swin-S Cascade Mask R-CNN ImageNet-1K 3x 52.3 45.2 110M google / baidu

ADE20K Semantic Segmentation

Backbone Method Pretrain Crop Size Lr Schd mIoU (ms+flip) #Params Download
ELSA-Swin-T UPerNet ImageNet-1K 512x512 160K 47.9 61M google / baidu
ELSA-Swin-S UperNet ImageNet-1K 512x512 160K 50.4 85M google / baidu

Install

  • Clone this repo:
git clone https://github.com/damo-cv/ELSA.git elsa
cd elsa
  • Create a conda virtual environment and activate it:
conda create -n elsa python=3.7 -y
conda activate elsa
  • Install PyTorch==1.8.0 and torchvision==0.9.0 with CUDA==10.1:
conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.1 -c pytorch
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ../
  • Install mmcv-full==1.3.0
pip install mmcv-full==1.3.0 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.0/index.html
  • Install other requirements:
pip install -r requirements.txt
  • Install mmdet and mmseg:
cd ./det
pip install -v -e .
cd ../seg
pip install -v -e .
cd ../
  • Build the elsa operation:
cd ./cls/models/elsa
python setup.py install
mv build/lib*/* .
cp *.so ../../../det/mmdet/models/backbones/elsa/
cp *.so ../../../seg/mmseg/models/backbones/elsa/
cd ../../../

Data preparation

We use standard ImageNet dataset, you can download it from http://image-net.org/. Please prepare it under the following file structure:

$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...

Also, please prepare the COCO and ADE20K datasets following their links. Then, please link them to det/data and seg/data.

Evaluation

ImageNet Classification

Run following scripts to evaluate pre-trained models on the ImageNet-1K:

cd cls

python validate.py <PATH_TO_IMAGENET> --model elsa_swin_tiny --checkpoint <CHECKPOINT_FILE> \
  --no-test-pool --apex-amp --img-size 224 -b 128

python validate.py <PATH_TO_IMAGENET> --model elsa_swin_small --checkpoint <CHECKPOINT_FILE> \
  --no-test-pool --apex-amp --img-size 224 -b 128

python validate.py <PATH_TO_IMAGENET> --model elsa_swin_base --checkpoint <CHECKPOINT_FILE> \
  --no-test-pool --apex-amp --img-size 224 -b 128 --use-ema

COCO Detection

Run following scripts to evaluate a detector on the COCO:

cd det

# single-gpu testing
python tools/test.py <CONFIG_FILE> <DET_CHECKPOINT_FILE> --eval bbox segm

# multi-gpu testing
tools/dist_test.sh <CONFIG_FILE> <DET_CHECKPOINT_FILE> <GPU_NUM> --eval bbox segm

ADE20K Semantic Segmentation

Run following scripts to evaluate a model on the ADE20K:

cd seg

# single-gpu testing
python tools/test.py <CONFIG_FILE> <SEG_CHECKPOINT_FILE> --aug-test --eval mIoU

# multi-gpu testing
tools/dist_test.sh <CONFIG_FILE> <SEG_CHECKPOINT_FILE> <GPU_NUM> --aug-test --eval mIoU

Training from scratch

Due to randomness, the re-training results may have a gap of about 0.1~0.2% with the numbers in the paper.

ImageNet Classification

Run following scripts to train classifiers on the ImageNet-1K:

cd cls

bash ./distributed_train.sh 8 <PATH_TO_IMAGENET> --model elsa_swin_tiny \
  --epochs 300 -b 128 -j 8 --opt adamw --lr 1e-3 --sched cosine --weight-decay 5e-2 \
  --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5 --drop-path 0.1 --aa rand-m9-mstd0.5-inc1 \
  --mixup 0.8 --cutmix 1. --remode pixel --reprob 0.25 --clip-grad 5. --amp

bash ./distributed_train.sh 8 <PATH_TO_IMAGENET> --model elsa_swin_small \
  --epochs 300 -b 128 -j 8 --opt adamw --lr 1e-3 --sched cosine --weight-decay 5e-2 \
  --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5 --drop-path 0.3 --aa rand-m9-mstd0.5-inc1 \
  --mixup 0.8 --cutmix 1. --remode pixel --reprob 0.25 --clip-grad 5. --amp

bash ./distributed_train.sh 8 <PATH_TO_IMAGENET> --model elsa_swin_base \
  --epochs 300 -b 128 -j 8 --opt adamw --lr 1e-3 --sched cosine --weight-decay 5e-2 \
  --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5 --drop-path 0.5 --aa rand-m9-mstd0.5-inc1 \
  --mixup 0.8 --cutmix 1. --remode pixel --reprob 0.25 --clip-grad 5. --amp --model-ema

If GPU memory is not enough when training elsa_swin_base, you can use two nodes (2 * 8 GPUs), each with a batch size of 64 images/GPU.

COCO Detection / ADE20K Semantic Segmentation

Run following scripts to train models on the COCO / ADE20K:

cd det 
# (or cd seg)

# multi-gpu training
tools/dist_train.sh <CONFIG_FILE> <GPU_NUM> --cfg-options model.pretrained=<PRETRAIN_MODEL> [model.backbone.use_checkpoint=True] [other optional arguments] 

Acknowledgement

This work was supported by Alibaba Group through Alibaba Research Intern Program and the National Natural Science Foundation of China (No.61976094).

Codebase from pytorch-image-models, ddfnet, VOLO, Swin-Transformer, Swin-Transformer-Detection, and Swin-Transformer-Semantic-Segmentation

Citing ELSA

@article{zhou2021ELSA,
  title={ELSA: Enhanced Local Self-Attention for Vision Transformer},
  author={Zhou, Jingkai and Wang, Pichao and Wang, Fan and Liu, Qiong and Li, Hao and Jin, Rong},
  journal={arXiv preprint arXiv:2112.12786},
  year={2021}
}
Owner
DamoCV
CV team of DAMO academy
DamoCV
Pytorch implementation of few-shot semantic image synthesis

Few-shot Semantic Image Synthesis Using StyleGAN Prior Our method can synthesize photorealistic images from dense or sparse semantic annotations using

40 Sep 26, 2022
Fast Learning of MNL Model From General Partial Rankings with Application to Network Formation Modeling

Fast-Partial-Ranking-MNL This repo provides a PyTorch implementation for the CopulaGNN models as described in the following paper: Fast Learning of MN

Xingjian Zhang 3 Aug 19, 2022
The tl;dr on a few notable transformer/language model papers + other papers (alignment, memorization, etc).

The tl;dr on a few notable transformer/language model papers + other papers (alignment, memorization, etc).

Will Thompson 166 Jan 04, 2023
Code from the paper "High-Performance Brain-to-Text Communication via Handwriting"

High-Performance Brain-to-Text Communication via Handwriting Overview This repo is associated with this manuscript, preprint and dataset. The code can

Francis R. Willett 306 Jan 03, 2023
DiffStride: Learning strides in convolutional neural networks

DiffStride is a pooling layer with learnable strides. Unlike strided convolutions, average pooling or max-pooling that require cross-validating stride values at each layer, DiffStride can be initiali

Google Research 113 Dec 13, 2022
Official Pytorch implementation for AAAI2021 paper (RSPNet: Relative Speed Perception for Unsupervised Video Representation Learning)

RSPNet Official Pytorch implementation for AAAI2021 paper "RSPNet: Relative Speed Perception for Unsupervised Video Representation Learning" [Suppleme

35 Jun 24, 2022
Neural Logic Inductive Learning

Neural Logic Inductive Learning This is the implementation of the Neural Logic Inductive Learning model (NLIL) proposed in the ICLR 2020 paper: Learn

36 Nov 28, 2022
ADSPM: Attribute-Driven Spontaneous Motion in Unpaired Image Translation

ADSPM: Attribute-Driven Spontaneous Motion in Unpaired Image Translation This repository provides a PyTorch implementation of ADSPM. Requirements Pyth

24 Jul 24, 2022
PyTorch implementation of normalizing flow models

PyTorch implementation of normalizing flow models

Vincent Stimper 242 Jan 02, 2023
Like a cowsay but without cows!

Foxsay This is a simple program that generates pictures of a cute fox with a message. It is like a cowsay but without cows! Fox girls are better! Usag

Anastasia Kim 28 Feb 20, 2022
Neighborhood Reconstructing Autoencoders

Neighborhood Reconstructing Autoencoders The official repository for Neighborhood Reconstructing Autoencoders (Lee, Kwon, and Park, NeurIPS 2021). T

Yonghyeon Lee 24 Dec 14, 2022
[CVPR 2022] Unsupervised Image-to-Image Translation with Generative Prior

GP-UNIT - Official PyTorch Implementation This repository provides the official PyTorch implementation for the following paper: Unsupervised Image-to-

Shuai Yang 125 Jan 03, 2023
Visual Tracking by TridenAlign and Context Embedding

Visual Tracking by TridentAlign and Context Embedding (TACT) Test code for "Visual Tracking by TridentAlign and Context Embedding" Janghoon Choi, Juns

Janghoon Choi 32 Aug 25, 2021
Simple and ready-to-use tutorials for TensorFlow

TensorFlow World To support maintaining and upgrading this project, please kindly consider Sponsoring the project developer. Any level of support is a

Amirsina Torfi 4.5k Dec 23, 2022
NHS AI Lab Skunkworks project: Long Stayer Risk Stratification

NHS AI Lab Skunkworks project: Long Stayer Risk Stratification A pilot project for the NHS AI Lab Skunkworks team, Long Stayer Risk Stratification use

NHSX 21 Nov 14, 2022
《LXMERT: Learning Cross-Modality Encoder Representations from Transformers》(EMNLP 2020)

The Most Important Thing. Our code is developed based on: LXMERT: Learning Cross-Modality Encoder Representations from Transformers

53 Dec 16, 2022
Summary Explorer is a tool to visually explore the state-of-the-art in text summarization.

Summary Explorer Summary Explorer is a tool to visually inspect the summaries from several state-of-the-art neural summarization models across multipl

Webis 42 Aug 14, 2022
An implementation of the efficient attention module.

Efficient Attention An implementation of the efficient attention module. Description Efficient attention is an attention mechanism that substantially

Shen Zhuoran 194 Dec 15, 2022
Pytorch implementation of "ARM: Any-Time Super-Resolution Method"

ARM-Net Dependencies Python 3.6 Pytorch 1.7 Results Train Data preprocessing cd data_scripts python extract_subimages_test.py python data_augmentation

Bohong Chen 55 Nov 24, 2022
MANO hand model porting for the GraspIt simulator

Learning Joint Reconstruction of Hands and Manipulated Objects - ManoGrasp Porting the MANO hand model to GraspIt! simulator Yana Hasson, Gül Varol, D

Lucas Wohlhart 10 Feb 08, 2022