DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification

Created by Yongming Rao, Wenliang Zhao, Benlin Liu, Jiwen Lu, Jie Zhou, Cho-Jui Hsieh

This repository contains PyTorch implementation for DynamicViT.

We introduce a dynamic token sparsification framework to prune redundant tokens in vision transformers progressively and dynamically based on the input:

[Figure: overview of the dynamic token sparsification framework]
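For intuition, here is a minimal, self-contained sketch of the idea (not the code in this repository): a lightweight prediction module scores every patch token, a differentiable keep decision is sampled with Gumbel-Softmax during training, and at inference the lowest-scored tokens are simply dropped before the remaining blocks. The `TokenPredictor` and `sparsify` names are illustrative only.

```python
# Minimal sketch of dynamic token sparsification (not the repository's code).
import torch
import torch.nn as nn
import torch.nn.functional as F

class TokenPredictor(nn.Module):
    """Illustrative per-token scoring module (hypothetical name)."""
    def __init__(self, dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim // 4),
            nn.GELU(),
            nn.Linear(dim // 4, 2),
        )

    def forward(self, x):            # x: (B, N, D) patch tokens, CLS excluded
        return self.mlp(x)           # (B, N, 2) logits for (drop, keep)

def sparsify(x, predictor, keep_ratio=0.7, training=True):
    logits = predictor(x)                                     # (B, N, 2)
    if training:
        # Differentiable hard keep decision via straight-through Gumbel-Softmax;
        # in the real model, attention is additionally masked so that dropped
        # tokens cannot influence the kept ones.
        keep = F.gumbel_softmax(logits, hard=True)[..., 1:]   # (B, N, 1), 0/1
        return x * keep, keep
    # Inference: keep the top-k tokens by predicted keep probability.
    num_keep = max(1, int(keep_ratio * x.shape[1]))
    scores = logits.softmax(dim=-1)[..., 1]                   # (B, N)
    idx = scores.topk(num_keep, dim=1).indices                # (B, num_keep)
    kept = torch.gather(x, 1, idx.unsqueeze(-1).expand(-1, -1, x.shape[-1]))
    return kept, idx

# Example: 196 patch tokens of dimension 384, keep roughly 70% at inference.
tokens = torch.randn(2, 196, 384)
pred = TokenPredictor(384)
kept, idx = sparsify(tokens, pred, keep_ratio=0.7, training=False)
print(kept.shape)   # torch.Size([2, 137, 384])
```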

Our code is based on pytorch-image-models, DeiT, and LV-ViT.

[Project Page] [arXiv]

Model Zoo

We provide our DynamicViT models pretrained on ImageNet:

| name                | arch       | rho | acc@1  | acc@5  | FLOPs | url                           |
| ------------------- | ---------- | --- | ------ | ------ | ----- | ----------------------------- |
| DynamicViT-256/0.7  | deit_256   | 0.7 | 76.532 | 93.118 | 1.3G  | Google Drive / Tsinghua Cloud |
| DynamicViT-384/0.7  | deit_small | 0.7 | 79.316 | 94.676 | 2.9G  | Google Drive / Tsinghua Cloud |
| DynamicViT-LV-S/0.5 | lvvit_s    | 0.5 | 81.970 | 95.756 | 3.7G  | Google Drive / Tsinghua Cloud |
| DynamicViT-LV-S/0.7 | lvvit_s    | 0.7 | 83.076 | 96.252 | 4.6G  | Google Drive / Tsinghua Cloud |
| DynamicViT-LV-M/0.7 | lvvit_m    | 0.7 | 83.816 | 96.584 | 8.5G  | Google Drive / Tsinghua Cloud |

Usage

Requirements

  • torch>=1.7.0
  • torchvision>=0.8.1
  • timm==0.4.5

Data preparation: download and extract ImageNet images from http://image-net.org/. The directory structure should be:

│ILSVRC2012/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
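The layout above is the standard torchvision `ImageFolder` layout, so a quick sanity check that your extracted data is readable could look like the following (the actual training and evaluation pipeline in this repository applies its own transforms):

```python
# Sanity-check the extracted ImageNet tree with torchvision's ImageFolder.
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

val_set = datasets.ImageFolder("/path/to/ILSVRC2012/val", transform=transform)
print(f"{len(val_set)} images, {len(val_set.classes)} classes")
```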

Model preparation: download pre-trained DeiT and LV-ViT models for training DynamicViT:

sh download_pretrain.sh

Demo

We provide a Jupyter notebook for visualizing how DynamicViT prunes tokens.

To run the demo, you need to install matplotlib.

[Figure: demo visualization of progressively pruned tokens]
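If you just want a sense of what the notebook shows, the toy sketch below greys out pruned 16x16 patches on a random image using a random keep mask; the notebook derives the mask from the model's actual decisions.

```python
# Toy version of the demo visualization: grey out pruned 16x16 patches.
import numpy as np
import matplotlib.pyplot as plt

img = np.random.rand(224, 224, 3)        # stand-in for a real, preprocessed image
patch = 16
keep = np.random.rand(14, 14) > 0.3      # True = token kept (random here)

vis = img.copy()
for i in range(14):
    for j in range(14):
        if not keep[i, j]:
            vis[i * patch:(i + 1) * patch, j * patch:(j + 1) * patch] *= 0.2

plt.imshow(vis)
plt.axis("off")
plt.show()
```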

Evaluation

To evaluate a pre-trained DynamicViT model on ImageNet val with a single GPU, run:

python infer.py --data-path /path/to/ILSVRC2012/ --arch arch_name --model-path /path/to/model --base_rate 0.7 

Training

To train DynamicViT models on ImageNet, run:

DeiT-small

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_dynamic_vit.py  --output_dir logs/dynamic-vit_deit-small --arch deit_small --input-size 224 --batch-size 96 --data-path /path/to/ILSVRC2012/ --epochs 30 --dist-eval --distill --base_rate 0.7

LV-ViT-S

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_dynamic_vit.py  --output_dir logs/dynamic-vit_lvvit-s --arch lvvit_s --input-size 224 --batch-size 64 --data-path /path/to/ILSVRC2012/ --epochs 30 --dist-eval --distill --base_rate 0.7

LV-ViT-M

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_dynamic_vit.py  --output_dir logs/dynamic-vit_lvvit-m --arch lvvit_m --input-size 224 --batch-size 48 --data-path /path/to/ILSVRC2012/ --epochs 30 --dist-eval --distill --base_rate 0.7

You can train models with different keeping ratios by adjusting base_rate. DynamicViT can also achieve comparable performance with only 15 epochs of training (around 0.1% lower accuracy). As described in the paper, pruning is applied hierarchically at three stages, each keeping a fraction base_rate of the tokens that survived the previous stage; the back-of-the-envelope sketch below shows the resulting token budget.
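```python
# Rough token budget implied by base_rate: three pruning stages, each keeping
# a fraction base_rate of the surviving tokens (CLS token ignored here).
num_patches = 14 * 14                    # 224x224 input with 16x16 patches
for base_rate in (0.5, 0.7, 0.9):
    kept = [round(num_patches * base_rate ** s) for s in (1, 2, 3)]
    print(f"base_rate={base_rate}: tokens kept after each stage = {kept}")
```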

License

MIT License

Citation

If you find our work useful in your research, please consider citing:

@article{rao2021dynamicvit,
  title={DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification},
  author={Rao, Yongming and Zhao, Wenliang and Liu, Benlin and Lu, Jiwen and Zhou, Jie and Hsieh, Cho-Jui},
  journal={arXiv preprint arXiv:2106.02034},
  year={2021}
}