[NeurIPS 2021]: Are Transformers More Robust Than CNNs? (PyTorch implementation & checkpoints)

Overview

PyTorch implementation for the NeurIPS 2021 paper: Are Transformers More Robust Than CNNs?

Our implementation is based on DeiT.

Introduction

Transformers have emerged as a powerful tool for visual recognition. In addition to demonstrating competitive performance on a broad range of visual benchmarks, recent works also argue that Transformers are much more robust than Convolutional Neural Networks (CNNs). Surprisingly, however, we find that these conclusions are drawn from unfair experimental settings, where Transformers and CNNs are compared at different scales and trained with distinct training frameworks. In this paper, we aim to provide the first fair and in-depth comparison between Transformers and CNNs, focusing on robustness evaluations.

With our unified training setup, we first challenge the previous belief that Transformers outshine CNNs in adversarial robustness. More surprisingly, we find that CNNs can easily be as robust as Transformers against adversarial attacks if they properly adopt Transformers' training recipes. Regarding generalization to out-of-distribution samples, we show that pre-training on (external) large-scale datasets is not a fundamental requirement for Transformers to outperform CNNs. Moreover, our ablations suggest that this stronger generalization largely stems from the Transformer's self-attention-like architecture itself, rather than from other training setups. We hope this work helps the community better understand and benchmark the robustness of Transformers and CNNs.

Pretrained models

We provide both pretrained vanilla models and adversarially trained models.

Vanilla Training

Main Results

Model        Pretrained Model  ImageNet  ImageNet-A  ImageNet-C  Stylized-ImageNet
Res50-Ori    download link     76.9      3.2         57.9        8.3
Res50-Align  download link     76.3      4.5         55.6        8.2
Res50-Best   download link     75.7      6.3         52.3        10.8
DeiT-Small   download link     76.8      12.2        48.0        13.0

Model Size

ResNets:

  • ResNets trained with DeiT's training recipe (fully aligned), denoted res*:
Model    Size (params)  Pretrained Model  ImageNet  ImageNet-A  ImageNet-C  Stylized-ImageNet
Res18*   11.69M         download link     67.83     1.92        64.14       7.92
Res50*   25.56M         download link     76.28     4.53        55.62       8.17
Res101*  44.55M         download link     77.97     8.84        49.19       11.60
  • ResNets with the best out-of-distribution (OOD) generalization, denoted res-best:
Model        Size (params)  Pretrained Model  ImageNet  ImageNet-A  ImageNet-C  Stylized-ImageNet
Res18-best   11.69M         download link     66.81     2.03        62.65       9.45
Res50-best   25.56M         download link     75.74     6.32        52.25       10.77
Res101-best  44.55M         download link     77.83     11.49       47.35       13.28

DeiTs:

Model       Size (params)  Pretrained Model  ImageNet  ImageNet-A  ImageNet-C  Stylized-ImageNet
DeiT-Mini   9.98M          download link     72.89     8.19        54.68       9.88
DeiT-Small  22.05M         download link     76.82     12.21       47.99       12.98

Model Distillation

Role     Architecture    Pretrained Model  ImageNet  ImageNet-A  ImageNet-C  Stylized-ImageNet
Teacher  DeiT-Small      download link     76.8      12.2        48.0        13.0
Student  Res50*-Distill  download link     76.7      5.2         54.2        9.8
Teacher  Res50*          download link     76.3      4.5         55.6        8.2
Student  DeiT-S-Distill  download link     76.2      10.9        49.3        11.9

Adversarial Training

Model       Pretrained Model  Clean Acc  PGD-100  Auto Attack
Res50-ReLU  download link     66.77      32.26    26.41
Res50-GELU  download link     67.38      40.27    35.51
DeiT-Small  download link     66.50      40.32    35.50

Vanilla Training

Data preparation

Download and extract the ImageNet train and val images from http://image-net.org/. The directory structure is the standard layout expected by torchvision's datasets.ImageFolder, with the training and validation data in the train and val folders, respectively:

/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
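
Before launching a long training run, it can help to confirm that the folders match the layout above. The following is a minimal sketch using only the Python standard library; the function names `summarize_imagefolder` and `check_layout` are hypothetical helpers for illustration and are not part of this repository.

```python
from pathlib import Path


def summarize_imagefolder(root):
    """Return {split: {class_name: file_count}} for the train/ and val/ folders."""
    root = Path(root)
    summary = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        classes = {}
        for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
            # Count regular files only; nested folders are ignored in this sketch.
            classes[class_dir.name] = sum(1 for f in class_dir.iterdir() if f.is_file())
        summary[split] = classes
    return summary


def check_layout(root):
    """Return a list of problems; an empty list means the layout looks consistent."""
    summary = summarize_imagefolder(root)
    problems = []
    if set(summary["train"]) != set(summary["val"]):
        problems.append("train/ and val/ do not contain the same class folders")
    for split, classes in summary.items():
        for name, count in classes.items():
            if count == 0:
                problems.append(f"{split}/{name} is empty")
    return problems
```

Calling `check_layout("/path/to/imagenet")` returns an empty list when every class folder appears in both splits and none is empty. It does not verify that the files are valid images.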

Environment

Install dependencies:

pip3 install -r requirements.txt

Training Scripts

To train a ResNet model on ImageNet run:

bash script/res.sh

To train a DeiT model on ImageNet run:

bash script/deit.sh

Generalization to Out-of-Distribution Samples

Data Preparation

Download and extract the ImageNet-A, ImageNet-C, and Stylized-ImageNet val images, arranged as follows:

/path/to/datasets/
  val/
    class1/
      img1.jpeg
    class2/
      img2.jpeg

Evaluation Scripts

To evaluate pre-trained models, run:

bash script/generation_to_ood.sh

Note that for ImageNet-C evaluation, the error rate is computed over the Noise, Blur, Weather, and Digital corruption categories.
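
As an illustration of how an ImageNet-C score can be aggregated over these four categories, here is a minimal sketch. It makes several assumptions: the corruption names follow the folder names of the public ImageNet-C release, each corruption has one error value per severity level, and normalization by a reference model (as in the mCE metric of the ImageNet-C benchmark) is optional. Whether the evaluation script in this repository normalizes in this way is not stated here, and `corruption_error` is a hypothetical helper, not a function from this codebase.

```python
from statistics import mean

# The 15 corruption types grouped into the four categories (assumed names).
CORRUPTIONS = {
    "Noise": ["gaussian_noise", "shot_noise", "impulse_noise"],
    "Blur": ["defocus_blur", "glass_blur", "motion_blur", "zoom_blur"],
    "Weather": ["snow", "frost", "fog", "brightness"],
    "Digital": ["contrast", "elastic_transform", "pixelate", "jpeg_compression"],
}


def corruption_error(errors, baseline=None):
    """Aggregate per-corruption top-1 errors (in percent) into a single score.

    errors:   {corruption_name: [error at each severity level]}
    baseline: same structure for a reference model, or None.
              With a baseline, each corruption's summed error is divided by the
              baseline's summed error (mCE-style); without one, the plain mean
              error over severities is used.
    """
    per_type = []
    for names in CORRUPTIONS.values():
        for name in names:
            if baseline is None:
                per_type.append(sum(errors[name]) / len(errors[name]))
            else:
                per_type.append(100.0 * sum(errors[name]) / sum(baseline[name]))
    # Every corruption type contributes equally to the final score.
    return mean(per_type)
```

Lower is better in both variants. Corruption types outside the four categories are ignored by this sketch even if they are present in `errors`.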

Adversarial Training

To perform adversarial training on ResNet run:

bash script/advres.sh

To perform adversarial training on DeiT run:

bash script/advdeit.sh

Robustness to Adversarial Examples

PGD Attack Evaluation

To evaluate the pre-trained models, run:

bash script/eval_advtraining.sh
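
For readers unfamiliar with the attack, the sketch below shows the core loop of an L-infinity PGD attack on a toy linear classifier, written in pure Python so that it runs without a deep learning framework. It is not the evaluation code in this repository: the logistic model, the helper names, and the values of `eps` and `alpha` are placeholders chosen for illustration, and the sketch omits the random start that PGD implementations often use. Only the step count of 100 mirrors the PGD-100 column reported above.

```python
import math


def sign(v):
    """Return -1, 0, or 1 according to the sign of v."""
    return (v > 0) - (v < 0)


def logistic_loss_and_grad(x, y, w, b):
    """Binary cross-entropy of a linear classifier and its gradient w.r.t. the input x."""
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    p = 1.0 / (1.0 + math.exp(-z))
    loss = -(y * math.log(p) + (1 - y) * math.log(1 - p))
    grad = [(p - y) * wi for wi in w]
    return loss, grad


def pgd_linf(x0, y, w, b, eps=8 / 255, alpha=2 / 255, steps=100):
    """Maximize the loss within an L-infinity ball of radius eps around x0.

    Inputs are assumed to lie in [0, 1]; eps and alpha are placeholder values.
    """
    x = list(x0)
    for _ in range(steps):
        _, grad = logistic_loss_and_grad(x, y, w, b)
        # Gradient-sign ascent step on the loss.
        x = [xi + alpha * sign(gi) for xi, gi in zip(x, grad)]
        # Project back into the eps-ball around x0 and the valid range [0, 1].
        x = [min(max(xi, x0i - eps, 0.0), x0i + eps, 1.0) for xi, x0i in zip(x, x0)]
    return x
```

In a real evaluation the gradient comes from backpropagation through the network, and robust accuracy is the fraction of test images still classified correctly after the attack.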

AutoAttack Evaluation

./autoattack contains the public AutoAttack package, with minor modifications to better support ImageNet evaluation.

cd autoattack/
bash autoattack.sh

Patch Attack Evaluation

Please refer to PatchAttack

Citation

If you use our code or models, or wish to refer to our results, please use the following BibTeX entry:

@inproceedings{bai2021transformers,
  title     = {Are Transformers More Robust Than CNNs?},
  author    = {Bai, Yutong and Mei, Jieru and Yuille, Alan and Xie, Cihang},
  booktitle = {Thirty-Fifth Conference on Neural Information Processing Systems},
  year      = {2021},
}