The official implementation of ELSA: Enhanced Local Self-Attention for Vision Transformer

ELSA: Enhanced Local Self-Attention for Vision Transformer

By Jingkai Zhou, Pichao Wang*, Fan Wang, Qiong Liu, Hao Li, Rong Jin

This repo is the official implementation of "ELSA: Enhanced Local Self-Attention for Vision Transformer".

Introduction

Self-attention is powerful in modeling long-range dependencies, but it is weak in local finer-level feature learning. As shown in Figure 1, the performance of local self-attention (LSA) is just on par with convolution and inferior to dynamic filters. This leaves researchers unsure whether to use LSA or its counterparts, which one is better, and what makes LSA mediocre. In this work, we comprehensively investigate LSA and its counterparts. We find that the devil lies in the generation and application of spatial attention.

Based on these findings, we propose the enhanced local self-attention (ELSA) with Hadamard attention and the ghost head, as illustrated in Figure 2. Experiments demonstrate the effectiveness of ELSA. Without any architecture or hyperparameter modification, using ELSA as a drop-in replacement consistently boosts baseline methods in both upstream and downstream tasks.

Please refer to our paper for more details.

Model zoo

ImageNet Classification

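| Model       | #Params | Pretrain    | Resolution | Top1 Acc | Download       |
|-------------|---------|-------------|------------|----------|----------------|
| ELSA-Swin-T | 28M     | ImageNet-1K | 224        | 82.7     | google / baidu |
| ELSA-Swin-S | 53M     | ImageNet-1K | 224        | 83.5     | google / baidu |
| ELSA-Swin-B | 93M     | ImageNet-1K | 224        | 84.0     | google / baidu |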

COCO Object Detection

| Backbone    | Method             | Pretrain    | Lr Schd | Box mAP | Mask mAP | #Params | Download       |
|-------------|--------------------|-------------|---------|---------|----------|---------|----------------|
| ELSA-Swin-T | Mask R-CNN         | ImageNet-1K | 1x      | 45.7    | 41.1     | 49M     | google / baidu |
| ELSA-Swin-T | Mask R-CNN         | ImageNet-1K | 3x      | 47.5    | 42.7     | 49M     | google / baidu |
| ELSA-Swin-S | Mask R-CNN         | ImageNet-1K | 1x      | 48.3    | 43.0     | 72M     | google / baidu |
| ELSA-Swin-S | Mask R-CNN         | ImageNet-1K | 3x      | 49.2    | 43.6     | 72M     | google / baidu |
| ELSA-Swin-T | Cascade Mask R-CNN | ImageNet-1K | 1x      | 49.8    | 43.0     | 86M     | google / baidu |
| ELSA-Swin-T | Cascade Mask R-CNN | ImageNet-1K | 3x      | 51.0    | 44.2     | 86M     | google / baidu |
| ELSA-Swin-S | Cascade Mask R-CNN | ImageNet-1K | 1x      | 51.6    | 44.4     | 110M    | google / baidu |
| ELSA-Swin-S | Cascade Mask R-CNN | ImageNet-1K | 3x      | 52.3    | 45.2     | 110M    | google / baidu |

ADE20K Semantic Segmentation

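| Backbone    | Method  | Pretrain    | Crop Size | Lr Schd | mIoU (ms+flip) | #Params | Download       |
|-------------|---------|-------------|-----------|---------|----------------|---------|----------------|
| ELSA-Swin-T | UPerNet | ImageNet-1K | 512x512   | 160K    | 47.9           | 61M     | google / baidu |
| ELSA-Swin-S | UPerNet | ImageNet-1K | 512x512   | 160K    | 50.4           | 85M     | google / baidu |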

Install

  • Clone this repo:
git clone https://github.com/damo-cv/ELSA.git elsa
cd elsa
  • Create a conda virtual environment and activate it:
conda create -n elsa python=3.7 -y
conda activate elsa
  • Install PyTorch==1.8.0 and torchvision==0.9.0 with CUDA==10.1:
conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.1 -c pytorch
  • Install Apex:
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ../
  • Install mmcv-full==1.3.0:
pip install mmcv-full==1.3.0 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.0/index.html
  • Install other requirements:
pip install -r requirements.txt
  • Install mmdet and mmseg:
cd ./det
pip install -v -e .
cd ../seg
pip install -v -e .
cd ../
  • Build the elsa operation:
cd ./cls/models/elsa
python setup.py install
mv build/lib*/* .
cp *.so ../../../det/mmdet/models/backbones/elsa/
cp *.so ../../../seg/mmseg/models/backbones/elsa/
cd ../../../
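
After the steps above, a quick look at the pinned versions can catch a mismatched environment early. The following is a minimal sketch and not part of the repository: the function names `installed_version` and `find_mismatches` are hypothetical, and the pins are simply copied from the install commands above.

```python
"""Compare installed package versions against the versions pinned in this README."""

# Pins copied from the install steps above.
EXPECTED = {"torch": "1.8.0", "torchvision": "0.9.0", "mmcv-full": "1.3.0"}


def installed_version(name):
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        from importlib.metadata import version, PackageNotFoundError  # Python >= 3.8
    except ImportError:
        import pkg_resources  # fallback for Python 3.7 (ships with setuptools)
        try:
            return pkg_resources.get_distribution(name).version
        except pkg_resources.DistributionNotFound:
            return None
    try:
        return version(name)
    except PackageNotFoundError:
        return None


def find_mismatches(installed, expected=EXPECTED):
    """Return {name: (found, wanted)} for packages that are missing or differ."""
    mismatches = {}
    for name, wanted in expected.items():
        found = installed.get(name)
        # Builds may append a local tag such as "+cu101", so compare the prefix only.
        if found is None or not found.startswith(wanted):
            mismatches[name] = (found, wanted)
    return mismatches


if __name__ == "__main__":
    installed = {name: installed_version(name) for name in EXPECTED}
    for name, (found, wanted) in find_mismatches(installed).items():
        print(f"{name}: found {found}, README pins {wanted}")
```

Run inside the activated `elsa` environment, the script prints nothing when all three packages match the pins.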

Data preparation

We use the standard ImageNet dataset, which you can download from http://image-net.org/. Please arrange it in the following file structure:

$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...

Also, please prepare the COCO and ADE20K datasets following the instructions on their respective sites, then link them to det/data and seg/data.
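
The ImageNet layout shown above can be checked before launching a long job. The sketch below is an illustration only, not a script shipped with this repo: the function names are hypothetical, and the accepted file extensions are an assumption. It counts class folders and images under `train` and `val`.

```python
import os
import sys

IMAGE_EXTENSIONS = (".jpeg", ".jpg", ".png")  # assumption: adjust to your copy of the data


def summarize_split(root, split):
    """Return (num_classes, num_images) for root/<split>/<class>/<image>."""
    split_dir = os.path.join(root, split)
    if not os.path.isdir(split_dir):
        raise FileNotFoundError(f"missing directory: {split_dir}")
    # Every sub-directory of the split is treated as one class.
    classes = sorted(
        d for d in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, d))
    )
    num_images = 0
    for cls in classes:
        files = os.listdir(os.path.join(split_dir, cls))
        num_images += sum(f.lower().endswith(IMAGE_EXTENSIONS) for f in files)
    return len(classes), num_images


def check_imagenet_layout(root):
    """Summarize both splits; raises if either directory is absent."""
    return {split: summarize_split(root, split) for split in ("train", "val")}


if __name__ == "__main__" and len(sys.argv) > 1:
    for split, (n_cls, n_img) in check_imagenet_layout(sys.argv[1]).items():
        print(f"{split}: {n_cls} classes, {n_img} images")
```

For ImageNet-1K, both splits should report 1000 classes; a different count usually means the archive was unpacked into the wrong level of the tree.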

Evaluation

ImageNet Classification

Run the following scripts to evaluate pre-trained models on ImageNet-1K:

cd cls

python validate.py <PATH_TO_IMAGENET> --model elsa_swin_tiny --checkpoint <CHECKPOINT_FILE> \
  --no-test-pool --apex-amp --img-size 224 -b 128

python validate.py <PATH_TO_IMAGENET> --model elsa_swin_small --checkpoint <CHECKPOINT_FILE> \
  --no-test-pool --apex-amp --img-size 224 -b 128

python validate.py <PATH_TO_IMAGENET> --model elsa_swin_base --checkpoint <CHECKPOINT_FILE> \
  --no-test-pool --apex-amp --img-size 224 -b 128 --use-ema

COCO Detection

Run the following scripts to evaluate a detector on COCO:

cd det

# single-gpu testing
python tools/test.py <CONFIG_FILE> <DET_CHECKPOINT_FILE> --eval bbox segm

# multi-gpu testing
tools/dist_test.sh <CONFIG_FILE> <DET_CHECKPOINT_FILE> <GPU_NUM> --eval bbox segm

ADE20K Semantic Segmentation

Run the following scripts to evaluate a model on ADE20K:

cd seg

# single-gpu testing
python tools/test.py <CONFIG_FILE> <SEG_CHECKPOINT_FILE> --aug-test --eval mIoU

# multi-gpu testing
tools/dist_test.sh <CONFIG_FILE> <SEG_CHECKPOINT_FILE> <GPU_NUM> --aug-test --eval mIoU

Training from scratch

Due to randomness, retraining results may differ from the numbers in the paper by about 0.1~0.2%.

ImageNet Classification

Run the following scripts to train classifiers on ImageNet-1K:

cd cls

bash ./distributed_train.sh 8 <PATH_TO_IMAGENET> --model elsa_swin_tiny \
  --epochs 300 -b 128 -j 8 --opt adamw --lr 1e-3 --sched cosine --weight-decay 5e-2 \
  --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5 --drop-path 0.1 --aa rand-m9-mstd0.5-inc1 \
  --mixup 0.8 --cutmix 1. --remode pixel --reprob 0.25 --clip-grad 5. --amp

bash ./distributed_train.sh 8 <PATH_TO_IMAGENET> --model elsa_swin_small \
  --epochs 300 -b 128 -j 8 --opt adamw --lr 1e-3 --sched cosine --weight-decay 5e-2 \
  --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5 --drop-path 0.3 --aa rand-m9-mstd0.5-inc1 \
  --mixup 0.8 --cutmix 1. --remode pixel --reprob 0.25 --clip-grad 5. --amp

bash ./distributed_train.sh 8 <PATH_TO_IMAGENET> --model elsa_swin_base \
  --epochs 300 -b 128 -j 8 --opt adamw --lr 1e-3 --sched cosine --weight-decay 5e-2 \
  --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5 --drop-path 0.5 --aa rand-m9-mstd0.5-inc1 \
  --mixup 0.8 --cutmix 1. --remode pixel --reprob 0.25 --clip-grad 5. --amp --model-ema
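
The schedule flags in the commands above (`--sched cosine`, `--warmup-epochs 20`, `--warmup-lr 1e-6`, `--lr 1e-3`, `--min-lr 1e-5`, `--epochs 300`) describe a linear warmup followed by cosine decay. The sketch below is a simplified per-epoch illustration of that shape, not the scheduler used by the training script, which may count warmup epochs or step the rate differently. The function name `lr_at_epoch` is hypothetical.

```python
import math


def lr_at_epoch(epoch, epochs=300, base_lr=1e-3, min_lr=1e-5,
                warmup_epochs=20, warmup_lr=1e-6):
    """Simplified per-epoch learning rate: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        # Linear ramp from warmup_lr toward base_lr.
        return warmup_lr + (base_lr - warmup_lr) * epoch / warmup_epochs
    # Cosine decay from base_lr down to min_lr over the remaining epochs.
    progress = (epoch - warmup_epochs) / (epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for epoch in (0, 10, 20, 160, 300):
        print(f"epoch {epoch:3d}: lr = {lr_at_epoch(epoch):.2e}")
```

Under these assumptions the rate starts at the warmup value, reaches `--lr` at epoch 20, and decays to `--min-lr` by the final epoch.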

If GPU memory is not enough when training elsa_swin_base, you can use two nodes (2 * 8 GPUs), each with a batch size of 64 images/GPU.
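
The two-node suggestion keeps the global batch size unchanged, which is why the same learning rate can be reused. A small worked example, assuming `-b` is the per-GPU batch size (as the "images/GPU" wording above implies); the helper name is hypothetical:

```python
def global_batch_size(nodes, gpus_per_node, per_gpu_batch):
    """Total number of images processed per optimizer step across all GPUs."""
    return nodes * gpus_per_node * per_gpu_batch


# Default setup from the commands above: one node, 8 GPUs, -b 128.
single_node = global_batch_size(nodes=1, gpus_per_node=8, per_gpu_batch=128)

# Memory-saving setup: two nodes, 8 GPUs each, 64 images per GPU.
two_nodes = global_batch_size(nodes=2, gpus_per_node=8, per_gpu_batch=64)

print(single_node, two_nodes)  # both are 1024
```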

COCO Detection / ADE20K Semantic Segmentation

Run the following scripts to train models on COCO / ADE20K:

cd det 
# (or cd seg)

# multi-gpu training
tools/dist_train.sh <CONFIG_FILE> <GPU_NUM> --cfg-options model.pretrained=<PRETRAIN_MODEL> [model.backbone.use_checkpoint=True] [other optional arguments] 

Acknowledgement

This work was supported by Alibaba Group through the Alibaba Research Intern Program and the National Natural Science Foundation of China (No. 61976094).

The codebase is derived from pytorch-image-models, ddfnet, VOLO, Swin-Transformer, Swin-Transformer-Detection, and Swin-Transformer-Semantic-Segmentation.

Citing ELSA

@article{zhou2021ELSA,
  title={ELSA: Enhanced Local Self-Attention for Vision Transformer},
  author={Zhou, Jingkai and Wang, Pichao and Wang, Fan and Liu, Qiong and Li, Hao and Jin, Rong},
  journal={arXiv preprint arXiv:2112.12786},
  year={2021}
}