(Preprint) Official PyTorch implementation of "How Do Vision Transformers Work?"

Overview

How Do Vision Transformers Work?

This repository provides a PyTorch implementation of "How Do Vision Transformers Work?" In the paper, we show that multi-head self-attentions (MSAs) for computer vision is NOT for capturing long-range dependency. In particular, we address the following three key questions of MSAs and Vision Transformers (ViTs):

  1. What properties of MSAs do we need to better optimize NNs? Do the long-range dependencies of MSAs help NNs learn?
  2. Do MSAs act like Convs? If not, how are they different?
  3. How can we harmonize MSAs with Convs? Can we just leverage their advantages?

We demonstrate that (1) MSAs flatten the loss landscapes, (2) MSA and Convs are complementary because MSAs are low-pass filters and convolutions (Convs) are high-pass filter, and (3) MSAs at the end of a stage significantly improve the accuracy.

Let's find the detailed answers below!

I. What Properties of MSAs Do We Need to Improve Optimization?

MSAs improve not only accuracy but also generalization by flattening the loss landscapes. Such improvement is primarily attributable to their data specificity, NOT long-range dependency 😱 Their weak inductive bias disrupts NN training. On the other hand, ViTs suffers from non-convex losses. MSAs allow negative Hessian eigenvalues in small data regimes. Large datasets and loss landscape smoothing methods alleviate this problem.

II. Do MSAs Act Like Convs?

MSAs and Convs exhibit opposite behaviors. For example, MSAs are low-pass filters, but Convs are high-pass filters. In addition, Convs are vulnerable to high-frequency noise but that MSAs are not. Therefore, MSAs and Convs are complementary.

III. How Can We Harmonize MSAs With Convs?

Multi-stage neural networks behave like a series connection of small individual models. In addition, MSAs at the end of a stage play a key role in prediction. Based on these insights, we propose design rules to harmonize MSAs with Convs. NN stages using this design pattern consists of a number of CNN blocks and one (or a few) MSA block. The design pattern naturally derives the structure of canonical Transformer, which has one MLP block for one MSA block.


In addition, we also introduce AlterNet, a model in which Conv blocks at the end of a stage are replaced with MSA blocks. Surprisingly, AlterNet outperforms CNNs not only in large data regimes but also in small data regimes. This contrasts with canonical ViTs, models that perform poorly on small amounts of data.

This repository is based on the official implementation of "Blurs Make Results Clearer: Spatial Smoothings to Improve Accuracy, Uncertainty, and Robustness". In this paper, we show that a simple (non-trainable) 2 ✕ 2 box blur filter improves accuracy, uncertainty, and robustness simultaneously by ensembling spatially nearby feature maps of CNNs. MSA is not simply generalized Conv, but rather a generalized (trainable) blur filter that complements Conv. Please check it out!

Getting Started

The following packages are required:

  • pytorch
  • matplotlib
  • notebook
  • ipywidgets
  • timm
  • einops
  • tensorboard
  • seaborn (optional)

We mainly use docker images pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime for the code.

See classification.ipynb for image classification. Run all cells to train and test models on CIFAR-10, CIFAR-100, and ImageNet.

Metrics. We provide several metrics for measuring accuracy and uncertainty: Acuracy (Acc, ↑) and Acc for 90% certain results (Acc-90, ↑), negative log-likelihood (NLL, ↓), Expected Calibration Error (ECE, ↓), Intersection-over-Union (IoU, ↑) and IoU for certain results (IoU-90, ↑), Unconfidence (Unc-90, ↑), and Frequency for certain results (Freq-90, ↑). We also define a method to plot a reliability diagram for visualization.

Models. We provide AlexNet, VGG, pre-activation VGG, ResNet, pre-activation ResNet, ResNeXt, WideResNet, ViT, PiT, Swin, MLP-Mixer, and Alter-ResNet by default.

Visualizing the Loss Landscapes

Refer to losslandscape.ipynb for exploring the loss landscapes. It requires a trained model. Run all cells to get predictive performance of the model for weight space grid. We provide a sample loss landscape result.

Evaluating Robustness on Corrupted Datasets

Refer to robustness.ipynb for evaluation corruption robustness on corrupted datasets such as CIFAR-10-C and CIFAR-100-C. It requires a trained model. Run all cells to get predictive performance of the model on datasets which consist of data corrupted by 15 different types with 5 levels of intensity each. We provide a sample robustness result.

How to Apply MSA to Your Own Model

We find that MSA complements Conv (not replaces Conv), and MSA closer to the end of stage improves predictive performance significantly. Based on these insights, we propose the following build-up rules:

  1. Alternately replace Conv blocks with MSA blocks from the end of a baseline CNN model.
  2. If the added MSA block does not improve predictive performance, replace a Conv block located at the end of an earlier stage with an MSA
  3. Use more heads and higher hidden dimensions for MSA blocks in late stages.

In the animation above, we replace Convs of ResNet with MSAs one by one according to the build-up rules. Note that several MSAs in c3 harm the accuracy, but the MSA at the end of c2 improves it. As a result, surprisingly, the model with MSAs following the appropriate build-up rule outperforms CNNs even in the small data regime, e.g., CIFAR!

Caution: Investigate Loss Landscapes and Hessians With l2 Regularization on Augmented Datasets

Two common mistakes ⚠️ are investigating loss landscapes and Hessians (1) 'without considering l2 regularization' on (2) 'clean datasets'. However, note that NNs are optimized with l2 regularization on augmented datasets. Therefore, it is appropriate to visualize 'NLL + l2' on 'augmented datasets'. Measuring criteria without l2 on clean dataset would give incorrect (even opposite) results.

Citation

If you find this useful, please consider citing 📑 the paper and starring 🌟 this repository. Please do not hesitate to contact Namuk Park (email: namuk.park at gmail dot com, twitter: xxxnell) with any comments or feedback.

BibTex is TBD.

License

All code is available to you under Apache License 2.0. CNN models build off the torchvision models which are BSD licensed. ViTs build off the PyTorch Image Models and Vision Transformer - Pytorch which are Apache 2.0 and MIT licensed.

Copyright the maintainers.

Owner
xxxnell
Programmer & ML researcher
xxxnell
Rainbow: Combining Improvements in Deep Reinforcement Learning

Rainbow Rainbow: Combining Improvements in Deep Reinforcement Learning [1]. Results and pretrained models can be found in the releases. DQN [2] Double

Kai Arulkumaran 1.4k Dec 29, 2022
DL & CV-based indicator toolset for the vehicle drivers via live dash-cam footage.

Vehicle Indicator Toolset Deep Learning and Computer Vision based indicator toolset for vehicle drivers using live dash-cam footages. Tracking of vehi

Alex Xu 12 Dec 28, 2021
Learning to Segment Instances in Videos with Spatial Propagation Network

Learning to Segment Instances in Videos with Spatial Propagation Network This paper is available at the 2017 DAVIS Challenge website. Check our result

Jingchun Cheng 145 Sep 28, 2022
The official implementation of Theme Transformer

Theme Transformer This is the official implementation of Theme Transformer. Checkout our demo and paper : Demo | arXiv Environment: using python versi

Ian Shih 85 Dec 08, 2022
A time series processing library

Timeseria Timeseria is a time series processing library which aims at making it easy to handle time series data and to build statistical and machine l

Stefano Alberto Russo 11 Aug 08, 2022
DeepVoxels is an object-specific, persistent 3D feature embedding.

DeepVoxels is an object-specific, persistent 3D feature embedding. It is found by globally optimizing over all available 2D observations of

Vincent Sitzmann 196 Dec 25, 2022
Deep universal probabilistic programming with Python and PyTorch

Getting Started | Documentation | Community | Contributing Pyro is a flexible, scalable deep probabilistic programming library built on PyTorch. Notab

7.7k Dec 30, 2022
MVFNet: Multi-View Fusion Network for Efficient Video Recognition (AAAI 2021)

MVFNet: Multi-View Fusion Network for Efficient Video Recognition (AAAI 2021) Overview We release the code of the MVFNet (Multi-View Fusion Network).

2 Jan 29, 2022
FNet Implementation with TensorFlow & PyTorch

FNet Implementation with TensorFlow & PyTorch. TensorFlow & PyTorch implementation of the paper "FNet: Mixing Tokens with Fourier Transforms". Overvie

Abdelghani Belgaid 1 Feb 12, 2022
codes for Self-paced Deep Regression Forests with Consideration on Ranking Fairness

Self-paced Deep Regression Forests with Consideration on Ranking Fairness This is official codes for paper Self-paced Deep Regression Forests with Con

Learning in Vision 4 Sep 11, 2022
Proof of concept GnuCash Webinterface

Proof of Concept GnuCash Webinterface This may one day be a something truly great. Milestones [ ] Browse accounts and view transactions [ ] Record sim

Josh 14 Dec 28, 2022
With this package, you can generate mixed-integer linear programming (MIP) models of trained artificial neural networks (ANNs) using the rectified linear unit (ReLU) activation function

With this package, you can generate mixed-integer linear programming (MIP) models of trained artificial neural networks (ANNs) using the rectified linear unit (ReLU) activation function. At the momen

ChemEngAI 40 Dec 27, 2022
Combinatorially Hard Games where the levels are procedurally generated

puzzlegen Implementation of two procedurally simulated environments with gym interfaces. IceSlider: the agent needs to reach and stop on the pink squa

Autonomous Learning Group 3 Jun 26, 2022
A PaddlePaddle implementation of Time Interval Aware Self-Attentive Sequential Recommendation.

TiSASRec.paddle A PaddlePaddle implementation of Time Interval Aware Self-Attentive Sequential Recommendation. Introduction 论文:Time Interval Aware Sel

Paddorch 2 Nov 28, 2021
HiddenMarkovModel implements hidden Markov models with Gaussian mixtures as distributions on top of TensorFlow

Class HiddenMarkovModel HiddenMarkovModel implements hidden Markov models with Gaussian mixtures as distributions on top of TensorFlow 2.0 Installatio

Susara Thenuwara 2 Nov 03, 2021
Official and maintained implementation of the paper "OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data" [BMVC 2021].

OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data Christoph Reich, Tim Prangemeier, Özdemir Cetin & Heinz Koeppl | Pr

Christoph Reich 23 Sep 21, 2022
MohammadReza Sharifi 27 Dec 13, 2022
Tutel MoE: An Optimized Mixture-of-Experts Implementation

Project Tutel Tutel MoE: An Optimized Mixture-of-Experts Implementation. Supported Framework: Pytorch Supported GPUs: CUDA(fp32 + fp16), ROCm(fp32) Ho

Microsoft 344 Dec 29, 2022
Code implementing "Improving Deep Learning Interpretability by Saliency Guided Training"

Saliency Guided Training Code implementing "Improving Deep Learning Interpretability by Saliency Guided Training" by Aya Abdelsalam Ismail, Hector Cor

8 Sep 22, 2022
code from "Tensor decomposition of higher-order correlations by nonlinear Hebbian plasticity"

Code associated with the paper "Tensor decomposition of higher-order correlations by nonlinear Hebbian learning," Ocker & Buice, Neurips 2021. "plot_f

Gabriel Koch Ocker 4 Oct 16, 2022