Official code for paper "Demystifying Local Vision Transformer: Sparse Connectivity, Weight Sharing, and Dynamic Weight"

Overview

Demysitifing Local Vision Transformer, arxiv

This is the official PyTorch implementation of our paper. We simply replace local self attention by (dynamic) depth-wise convolution with lower computational cost. The performance is on par with the Swin Transformer.

Besides, the main contribution of our paper is the theorical and detailed comparison between depth-wise convolution and local self attention from three aspects: sparse connectivity, weight sharing and dynamic weight. By this paper, we want community to rethinking the local self attention and depth-wise convolution, and the basic model architeture designing rules.

Codes and models for object detection and semantic segmentation are avaliable in Detection and Segmentation.

We also give MLP based Swin Transformer models and Inhomogenous dynamic convolution in the ablation studies. These codes and models will coming soon.

Reference

@article{han2021demystifying,
  title={Demystifying Local Vision Transformer: Sparse Connectivity, Weight Sharing, and Dynamic Weight},
  author={Han, Qi and Fan, Zejia and Dai, Qi and Sun, Lei and Cheng, Ming-Ming and Liu, Jiaying and Wang, Jingdong},
  journal={arXiv preprint arXiv:2106.04263},
  year={2021}
}

1. Requirements

torch>=1.5.0, torchvision, timm, pyyaml; apex-amp

data perpare: ImageNet dataset with the following structure:

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......

2. Trainning

For tiny model, we train with batch-size 128 on 8 GPUs. When trainning base model, we use batch-size 64 on 16 GPUs with OpenMPI to keep the total batch-size unchanged. (With the same trainning setting, the base model couldn't train with AMP due to the anomalous gradient values.)

Please change the data path in sh scripts first.

For tiny model:

bash scripts/run_dwnet_tiny_patch4_window7_224.sh 
bash scripts/run_dynamic_dwnet_tiny_patch4_window7_224.sh

For base model, use multi node with OpenMPI:

bash scripts/run_dwnet_base_patch4_window7_224.sh 
bash scripts/run_dynamic_dwnet_base_patch4_window7_224.sh

3. Evaluation

python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --cfg configs/change_to_config_file --resume /path/to/model --data-path /path/to/imagenet --eval

4. Models

Models are provided by training on ImageNet with resolution 224.

Model #params FLOPs Top1 Acc Download
dwnet_tiny 24M 3.8G 81.2 github
dynamic_dwnet_tiny 51M 3.8G 81.8 github
dwnet_base 74M 12.9G 83.2 github
dynamic_dwnet_base 162M 13.0G 83.2 github

Detection (see Detection for details):

Backbone Pretrain Lr Schd box mAP mask mAP #params FLOPs config model
DWNet-T ImageNet-1K 3x 49.9 43.4 82M 730G config github
DWNet-B ImageNet-1K 3x 51.0 44.1 132M 924G config github
Dynamic-DWNet-T ImageNet-1K 3x 50.5 43.7 108M 730G config github
Dynamic-DWNet-B ImageNet-1K 3x 51.2 44.4 219M 924G config github

Segmentation (see Segmentation for details):

Backbone Pretrain Lr Schd mIoU #params FLOPs config model
DWNet-T ImageNet-1K 160K 45.5 56M 928G config github
DWNet-B ImageNet-1K 160K 48.3 108M 1129G config github
Dynamic-DWNet-T ImageNet-1K 160K 45.7 83M 928G config github
Dynamic-DWNet-B ImageNet-1K 160K 48.0 195M 1129G config github

LICENSE

This repo is under the MIT license. Some codes are borrow from Swin Transformer.

You might also like...
Official code repository of the paper Learning Associative Inference Using Fast Weight Memory by Schlag et al.

Learning Associative Inference Using Fast Weight Memory This repository contains the offical code for the paper Learning Associative Inference Using F

Official PyTorch code for CVPR 2020 paper
Official PyTorch code for CVPR 2020 paper "Deep Active Learning for Biased Datasets via Fisher Kernel Self-Supervision"

Deep Active Learning for Biased Datasets via Fisher Kernel Self-Supervision https://arxiv.org/abs/2003.00393 Abstract Active learning (AL) aims to min

Official Code for ICML 2021 paper
Official Code for ICML 2021 paper "Revisiting Point Cloud Shape Classification with a Simple and Effective Baseline"

Revisiting Point Cloud Shape Classification with a Simple and Effective Baseline Ankit Goyal, Hei Law, Bowei Liu, Alejandro Newell, Jia Deng Internati

CVPR 2021 - Official code repository for the paper: On Self-Contact and Human Pose.
CVPR 2021 - Official code repository for the paper: On Self-Contact and Human Pose.

selfcontact This repo is part of our project: On Self-Contact and Human Pose. [Project Page] [Paper] [MPI Project Page] It includes the main function

CVPR 2021 - Official code repository for the paper: On Self-Contact and Human Pose.
CVPR 2021 - Official code repository for the paper: On Self-Contact and Human Pose.

SMPLify-XMC This repo is part of our project: On Self-Contact and Human Pose. [Project Page] [Paper] [MPI Project Page] License Software Copyright Lic

Official code of paper "PGT: A Progressive Method for Training Models on Long Videos" on CVPR2021

PGT Code for paper PGT: A Progressive Method for Training Models on Long Videos. Install Run pip install -r requirements.txt. Run python setup.py buil

This is the official code of our paper
This is the official code of our paper "Diversity-based Trajectory and Goal Selection with Hindsight Experience Relay" (PRICAI 2021)

Diversity-based Trajectory and Goal Selection with Hindsight Experience Replay This is the official implementation of our paper "Diversity-based Traje

The official code for paper "R2D2: Recursive Transformer based on Differentiable Tree for Interpretable Hierarchical Language Modeling".

R2D2 This is the official code for paper titled "R2D2: Recursive Transformer based on Differentiable Tree for Interpretable Hierarchical Language Mode

Official repository with code and data accompanying the NAACL 2021 paper "Hurdles to Progress in Long-form Question Answering" (https://arxiv.org/abs/2103.06332).

Hurdles to Progress in Long-form Question Answering This repository contains the official scripts and datasets accompanying our NAACL 2021 paper, "Hur

Comments
  • CVE-2007-4559 Patch

    CVE-2007-4559 Patch

    Patching CVE-2007-4559

    Hi, we are security researchers from the Advanced Research Center at Trellix. We have began a campaign to patch a widespread bug named CVE-2007-4559. CVE-2007-4559 is a 15 year old bug in the Python tarfile package. By using extract() or extractall() on a tarfile object without sanitizing input, a maliciously crafted .tar file could perform a directory path traversal attack. We found at least one unsantized extractall() in your codebase and are providing a patch for you via pull request. The patch essentially checks to see if all tarfile members will be extracted safely and throws an exception otherwise. We encourage you to use this patch or your own solution to secure against CVE-2007-4559. Further technical information about the vulnerability can be found in this blog.

    If you have further questions you may contact us through this projects lead researcher Kasimir Schulz.

    opened by TrellixVulnTeam 0
  • Possible bug?

    Possible bug?

    class DynamicDWConv(nn.Module):
        def __init__(self, dim, kernel_size, bias=True, stride=1, padding=1, groups=1, reduction=4):
            super().__init__()
            self.dim = dim
            self.kernel_size = kernel_size
            self.stride = stride 
            self.padding = padding 
            self.groups = groups 
    
            self.pool = nn.AdaptiveAvgPool2d((1, 1))
            self.conv1 = nn.Conv2d(dim, dim // reduction, 1, bias=False)
            self.bn = nn.BatchNorm2d(dim // reduction)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(dim // reduction, dim * kernel_size * kernel_size, 1)
            if bias:
                self.bias = nn.Parameter(torch.zeros(dim))
            else:
                self.bias = None
    
        def forward(self, x):
            b, c, h, w = x.shape
            weight = self.conv2(self.relu(self.bn(self.conv1(self.pool(x)))))
            weight = weight.view(b * self.dim, 1, self.kernel_size, self.kernel_size)
            x = F.conv2d(x.reshape(1, -1, h, w), weight, self.bias.repeat(b), stride=self.stride, padding=self.padding, groups=b * self.groups)
            x = x.view(b, c, x.shape[-2], x.shape[-1])
            return x
    

    This function seems to give error when groups is not equal to dim.

    opened by yxchng 0
Owner
Attention for Vision and Visualization
Implementation of paper "DCS-Net: Deep Complex Subtractive Neural Network for Monaural Speech Enhancement"

DCS-Net This is the implementation of "DCS-Net: Deep Complex Subtractive Neural Network for Monaural Speech Enhancement" Steps to run the model Edit V

Jack Walters 10 Apr 04, 2022
This is a Keras implementation of a CNN for estimating age, gender and mask from a camera.

face-detector-age-gender This is a Keras implementation of a CNN for estimating age, gender and mask from a camera. Before run face detector app, expr

Devdreamsolution 2 Dec 04, 2021
PyStan, a Python interface to Stan, a platform for statistical modeling. Documentation: https://pystan.readthedocs.io

PyStan NOTE: This documentation describes a BETA release of PyStan 3. PyStan is a Python interface to Stan, a package for Bayesian inference. Stan® is

Stan 229 Dec 29, 2022
A Fast Knowledge Distillation Framework for Visual Recognition

FKD: A Fast Knowledge Distillation Framework for Visual Recognition Official PyTorch implementation of paper A Fast Knowledge Distillation Framework f

Zhiqiang Shen 129 Dec 24, 2022
Implementation of popular SOTA self-supervised learning algorithms as Fastai Callbacks.

Self Supervised Learning with Fastai Implementation of popular SOTA self-supervised learning algorithms as Fastai Callbacks. Install pip install self-

Kerem Turgutlu 276 Dec 23, 2022
CSAC - Collaborative Semantic Aggregation and Calibration for Separated Domain Generalization

CSAC Introduction This repository contains the implementation code for paper: Co

ScottYuan 5 Jul 22, 2022
Codes for building and training the neural network model described in Domain-informed neural networks for interaction localization within astroparticle experiments.

Domain-informed Neural Networks Codes for building and training the neural network model described in Domain-informed neural networks for interaction

DIDACTS 0 Dec 13, 2021
PED: DETR for Crowd Pedestrian Detection

PED: DETR for Crowd Pedestrian Detection Code for PED: DETR For (Crowd) Pedestrian Detection Paper PED: DETR for Crowd Pedestrian Detection Installati

36 Sep 13, 2022
A PyTorch implementation for our paper "Dual Contrastive Learning: Text Classification via Label-Aware Data Augmentation".

Dual-Contrastive-Learning A PyTorch implementation for our paper "Dual Contrastive Learning: Text Classification via Label-Aware Data Augmentation". Y

hoshi-hiyouga 85 Dec 26, 2022
Code for the ICCV2021 paper "Personalized Image Semantic Segmentation"

PSS: Personalized Image Semantic Segmentation Paper PSS: Personalized Image Semantic Segmentation Yu Zhang, Chang-Bin Zhang, Peng-Tao Jiang, Ming-Ming

张宇 15 Jul 09, 2022
Classification of EEG data using Deep Learning

Graduation-Project Classification of EEG data using Deep Learning Epilepsy is the most common neurological disease in the world. Epilepsy occurs as a

Osman Alpaydın 5 Jun 24, 2022
DISTIL: Deep dIverSified inTeractIve Learning.

DISTIL: Deep dIverSified inTeractIve Learning. An active/inter-active learning library built on py-torch for reducing labeling costs.

decile-team 110 Dec 06, 2022
Code for ICLR 2021 Paper, "Anytime Sampling for Autoregressive Models via Ordered Autoencoding"

Anytime Autoregressive Model Anytime Sampling for Autoregressive Models via Ordered Autoencoding , ICLR 21 Yilun Xu, Yang Song, Sahaj Gara, Linyuan Go

Yilun Xu 22 Sep 08, 2022
Differentiable Optimizers with Perturbations in Pytorch

Differentiable Optimizers with Perturbations in PyTorch This contains a PyTorch implementation of Differentiable Optimizers with Perturbations in Tens

Jake Tuero 54 Jun 22, 2022
Code for Fully Context-Aware Image Inpainting with a Learned Semantic Pyramid

SPN: Fully Context-Aware Image Inpainting with a Learned Semantic Pyramid Code for Fully Context-Aware Image Inpainting with a Learned Semantic Pyrami

12 Jun 27, 2022
Automated Melanoma Recognition in Dermoscopy Images via Very Deep Residual Networks

Introduction This repository contains the modified caffe library and network architectures for our paper "Automated Melanoma Recognition in Dermoscopy

Lequan Yu 47 Nov 24, 2022
The source code for CATSETMAT: Cross Attention for Set Matching in Bipartite Hypergraphs

catsetmat The source code for CATSETMAT: Cross Attention for Set Matching in Bipartite Hypergraphs To be able to run it, add catsetmat to PYTHONPATH H

2 Dec 19, 2022
Embeddinghub is a database built for machine learning embeddings.

Embeddinghub is a database built for machine learning embeddings.

Featureform 1.2k Jan 01, 2023
The final project for "Applying AI to Wearable Device Data" course from "AI for Healthcare" - Udacity.

Motion Compensated Pulse Rate Estimation Overview This project has 2 main parts. Develop a Pulse Rate Algorithm on the given training data. Then Test

Omar Laham 2 Oct 25, 2022
Video Autoencoder: self-supervised disentanglement of 3D structure and motion

Video Autoencoder: self-supervised disentanglement of 3D structure and motion This repository contains the code (in PyTorch) for the model introduced

157 Dec 22, 2022