《Rethinking Sptil Dimensions of Vision Trnsformers》(2021)

Related tags

Deep Learningpit
Overview

Rethinking Spatial Dimensions of Vision Transformers

Byeongho Heo, Sangdoo Yun, Dongyoon Han, Sanghyuk Chun, Junsuk Choe, Seong Joon Oh | Paper

NAVER AI LAB

teaser

Abstract

Vision Transformer (ViT) extends the application range of transformers from language processing to computer vision tasks as being an alternative architecture against the existing convolutional neural networks (CNN). Since the transformer-based architecture has been innovative for computer vision modeling, the design convention towards an effective architecture has been less studied yet. From the successful design principles of CNN, we investigate the role of the spatial dimension conversion and its effectiveness on the transformer-based architecture. We particularly attend the dimension reduction principle of CNNs; as the depth increases, a conventional CNN increases channel dimension and decreases spatial dimensions. We empirically show that such a spatial dimension reduction is beneficial to a transformer architecture as well, and propose a novel Pooling-based Vision Transformer (PiT) upon the original ViT model. We show that PiT achieves the improved model capability and generalization performance against ViT. Throughout the extensive experiments, we further show PiT outperforms the baseline on several tasks such as image classification, object detection and robustness evaluation.

Model performance

We compared performance of PiT with DeiT models in various training settings. Throughput (imgs/sec) values are measured in a machine with single V100 gpu with 128 batche size.

Network FLOPs # params imgs/sec Vanilla +CutMix +DeiT +Distill
DeiT-Ti 1.3 G 5.7 M 2564 68.7 68.5 72.2 74.5
PiT-Ti 0.71 G 4.9 M 3030 71.3 72.6 73.0 74.6
PiT-XS 1.4 G 10.6 M 2128 72.4 76.8 78.1 79.1
DeiT-S 4.6 G 22.1 M 980 68.7 76.5 79.8 81.2
PiT-S 2.9 G 23.5 M 1266 73.3 79.0 80.9 81.9
DeiT-B 17.6 G 86.6 M 303 69.3 75.3 81.8 83.4
PiT-B 12.5 G 73.8 M 348 76.1 79.9 82.0 84.0

Pretrained weights

Model name FLOPs accuracy weights
pit_ti 0.71 G 73.0 link
pit_xs 1.4 G 78.1 link
pit_s 2.9 G 80.9 link
pit_b 12.5 G 82.0 link
pit_ti_distilled 0.71 G 74.6 link
pit_xs_distilled 1.4 G 79.1 link
pit_s_distilled 2.9 G 81.9 link
pit_b_distilled 12.5 G 84.0 link

Dependancies

Our implementations are tested on following libraries with Python 3.6.9 and CUDA 10.1.

torch: 1.7.1
torchvision: 0.8.2
timm: 0.3.4
einops: 0.3.0

Install other dependencies using the following command.

pip install -r requirements.txt

How to use models

You can build PiT models directly

import torch
import pit

model = pit.pit_s(pretrained=False)
model.load_state_dict(torch.load('./weights/pit_s_809.pth'))
print(model(torch.randn(1, 3, 224, 224)))

Or using timm function

import torch
import timm
import pit

model = timm.create_model('pit_s', pretrained=False)
model.load_state_dict(torch.load('./weights/pit_s_809.pth'))
print(model(torch.randn(1, 3, 224, 224)))

To use models trained with distillation, you should use _distilled model and weights.

import torch
import pit

model = pit.pit_s_distilled(pretrained=False)
model.load_state_dict(torch.load('./weights/pit_s_distill_819.pth'))
print(model(torch.randn(1, 3, 224, 224)))

License

Copyright 2021-present NAVER Corp.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Citation

@article{heo2021pit,
    title={Rethinking Spatial Dimensions of Vision Transformers},
    author={Byeongho Heo and Sangdoo Yun and Dongyoon Han and Sanghyuk Chun and Junsuk Choe and Seong Joon Oh},
    journal={arXiv: 2103.16302},
    year={2021},
}
Owner
NAVER AI
Official account of NAVER AI, Korea No.1 Industrial AI Research Group
NAVER AI
Code for unmixing audio signals in four different stems "drums, bass, vocals, others". The code is adapted from "Jukebox: A Generative Model for Music"

Status: Archive (code is provided as-is, no updates expected) Disclaimer This code is a based on "Jukebox: A Generative Model for Music" Paper We adju

Wadhah Zai El Amri 24 Dec 29, 2022
Cognition-aware Cognate Detection

Cognition-aware Cognate Detection The repository which contains our code for our EACL 2021 paper titled, "Cognition-aware Cognate Detection". This wor

Prashant K. Sharma 1 Feb 01, 2022
ThunderGBM: Fast GBDTs and Random Forests on GPUs

Documentations | Installation | Parameters | Python (scikit-learn) interface What's new? ThunderGBM won 2019 Best Paper Award from IEEE Transactions o

Xtra Computing Group 647 Jan 04, 2023
Talk covering the features of skorch

Skorch Talk Skorch - A Union of Scikit-learn and PyTorch Presentation The slides can be downloaded at: download link. Google Colab Part One - MNIST Pa

Thomas J. Fan 3 Oct 20, 2020
Patch Rotation: A Self-Supervised Auxiliary Task for Robustness and Accuracy of Supervised Models

Patch-Rotation(PatchRot) Patch Rotation: A Self-Supervised Auxiliary Task for Robustness and Accuracy of Supervised Models Submitted to Neurips2021 To

4 Jul 12, 2021
generate-2D-quadrilateral-mesh-with-neural-networks-and-tree-search

generate-2D-quadrilateral-mesh-with-neural-networks-and-tree-search This repository contains single-threaded TreeMesh code. I'm Hua Tong, a senior stu

Hua Tong 18 Sep 21, 2022
An implementation on "Curved-Voxel Clustering for Accurate Segmentation of 3D LiDAR Point Clouds with Real-Time Performance"

Lidar-Segementation An implementation on "Curved-Voxel Clustering for Accurate Segmentation of 3D LiDAR Point Clouds with Real-Time Performance" from

Wangxu1996 135 Jan 06, 2023
The source code of the paper "SHGNN: Structure-Aware Heterogeneous Graph Neural Network"

SHGNN: Structure-Aware Heterogeneous Graph Neural Network The source code and dataset of the paper: SHGNN: Structure-Aware Heterogeneous Graph Neural

Wentao Xu 7 Nov 13, 2022
The project is an official implementation of our CVPR2019 paper "Deep High-Resolution Representation Learning for Human Pose Estimation"

Deep High-Resolution Representation Learning for Human Pose Estimation (CVPR 2019) News [2020/07/05] A very nice blog from Towards Data Science introd

Leo Xiao 3.9k Jan 05, 2023
Do Smart Glasses Dream of Sentimental Visions? Deep Emotionship Analysis for Eyewear Devices

EMOShip This repository contains the EMO-Film dataset described in the paper "Do Smart Glasses Dream of Sentimental Visions? Deep Emotionship Analysis

1 Nov 18, 2022
Repositório para arquivos sobre o Módulo 1 do curso Top Coders da Let's Code + Safra

850-Safra-DS-ModuloI Repositório para arquivos sobre o Módulo 1 do curso Top Coders da Let's Code + Safra Para aprender mais Git https://learngitbranc

Brian Nunes 7 Dec 10, 2022
TextBPN Adaptive Boundary Proposal Network for Arbitrary Shape Text Detection

TextBPN Adaptive Boundary Proposal Network for Arbitrary Shape Text Detection; Accepted by ICCV2021. Note: The complete code (including training and t

S.X.Zhang 84 Dec 13, 2022
Moment-DETR code and QVHighlights dataset

Moment-DETR QVHighlights: Detecting Moments and Highlights in Videos via Natural Language Queries Jie Lei, Tamara L. Berg, Mohit Bansal For dataset de

Jie Lei 雷杰 133 Dec 22, 2022
Real-time Joint Semantic Reasoning for Autonomous Driving

MultiNet MultiNet is able to jointly perform road segmentation, car detection and street classification. The model achieves real-time speed and state-

Marvin Teichmann 518 Dec 12, 2022
Code for the paper "SmoothMix: Training Confidence-calibrated Smoothed Classifiers for Certified Robustness" (NeurIPS 2021)

SmoothMix: Training Confidence-calibrated Smoothed Classifiers for Certified Robustness (NeurIPS2021) This repository contains code for the paper "Smo

Jongheon Jeong 17 Dec 27, 2022
PyTorch implementation of MoCo: Momentum Contrast for Unsupervised Visual Representation Learning

MoCo: Momentum Contrast for Unsupervised Visual Representation Learning This is a PyTorch implementation of the MoCo paper: @Article{he2019moco, aut

Meta Research 3.7k Jan 02, 2023
Efficient Deep Learning Systems course

Efficient Deep Learning Systems This repository contains materials for the Efficient Deep Learning Systems course taught at the Faculty of Computer Sc

Max Ryabinin 173 Dec 29, 2022
A minimal solution to hand motion capture from a single color camera at over 100fps. Easy to use, plug to run.

Minimal Hand A minimal solution to hand motion capture from a single color camera at over 100fps. Easy to use, plug to run. This project provides the

Yuxiao Zhou 824 Jan 07, 2023
Compositional Sketch Search

Compositional Sketch Search Official repository for ICIP 2021 Paper: Compositional Sketch Search Requirements Install and activate conda environment c

Alexander Black 8 Sep 06, 2021
Qcover is an open source effort to help exploring combinatorial optimization problems in Noisy Intermediate-scale Quantum(NISQ) processor.

Qcover is an open source effort to help exploring combinatorial optimization problems in Noisy Intermediate-scale Quantum(NISQ) processor. It is devel

33 Nov 11, 2022