An example of Scatterbrain implementation (combining local attention and Performer)

Overview

We use the template from https://github.com/ashleve/lightning-hydra-template. Please read the instructions there to understand the repo structure.

Implementation & Experiments

An example of Scatterbrain implementation (combining local attention and Performer) is in the file src/models/modules/attention/sblocal.py.

T2T-ViT inference on ImageNet

To run the T2T-ViT inference on ImageNet experiment:

  1. Download the pretrained weights from the [T2T-ViT repo][https://github.com/yitu-opensource/T2T-ViT/releases]:
mkdir -p checkpoints/t2tvit
cd checkpoints/t2tvit
wget https://github.com/yitu-opensource/T2T-ViT/releases/download/main/81.7_T2T_ViTt_14.pth.tar
  1. Convert the weights to the format compatible with our implementation of T2T-ViT:
# cd to scatterbrain path
python scripts/convert_checkpoint_t2t_vit.py checkpoints/t2tvit/81.7_T2T_ViTt_14.pth.tar
  1. Download the ImageNet dataset (just the validation set will suffice). Below, /path/to/imagenet refers to the directory that contains the train and val directories.
  2. Run the inference experiments:
python run.py experiment=imagenet-t2tvit-eval.yaml model/t2tattn_cfg=full datamodule.data_dir=/path/to/imagenet/ eval.ckpt=checkpoints/t2tvit/81.7_T2T_ViTt_14.pth.tar  # 81.7% acc
python run.py experiment=imagenet-t2tvit-eval.yaml model/t2tattn_cfg=local datamodule.data_dir=/path/to/imagenet/ eval.ckpt=checkpoints/t2tvit/81.7_T2T_ViTt_14.pth.tar  # 80.6% acc
python run.py experiment=imagenet-t2tvit-eval.yaml model/t2tattn_cfg=performer datamodule.data_dir=/path/to/imagenet/ eval.ckpt=checkpoints/t2tvit/81.7_T2T_ViTt_14.pth.tar  # 77.8-79.0% acc (there's randomness)
python run.py experiment=imagenet-t2tvit-eval.yaml model/t2tattn_cfg=sblocal datamodule.data_dir=/path/to/imagenet/ eval.ckpt=checkpoints/t2tvit/81.7_T2T_ViTt_14.pth.tar  # 81.1% acc

Requirements

Python 3.8+, Pytorch 1.9+, torchvision, torchtext, pytorch-fast-transformers, munch, einops, timm, hydra-core, hydra-colorlog, python-dotenv, rich, pytorch-lightning, lightning-bolts.

We provide a Dockerfile that lists all the required packages.

Citation

If you use this codebase, or otherwise found our work valuable, please cite:

@inproceedings{chen2021scatterbrain,
  title={Scatterbrain: Unifying Sparse and Low-rank Attention},
  author={Beidi Chen and Tri Dao and Eric Winsor and Zhao Song and Atri Rudra and Christopher R\'{e}},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2021}
}
Owner
HazyResearch
We are a CS research group led by Prof. Chris Ré.
HazyResearch
CVPR 2021 Official Pytorch Code for UC2: Universal Cross-lingual Cross-modal Vision-and-Language Pre-training

UC2 UC2: Universal Cross-lingual Cross-modal Vision-and-Language Pre-training Mingyang Zhou, Luowei Zhou, Shuohang Wang, Yu Cheng, Linjie Li, Zhou Yu,

Mingyang Zhou 28 Dec 30, 2022
An open-source Deep Learning Engine for Healthcare that aims to treat & prevent major diseases

AlphaCare Background AlphaCare is a work-in-progress, open-source Deep Learning Engine for Healthcare that aims to treat and prevent major diseases. T

Siraj Raval 44 Nov 05, 2022
SwinTrack: A Simple and Strong Baseline for Transformer Tracking

SwinTrack This is the official repo for SwinTrack. A Simple and Strong Baseline Prerequisites Environment conda (recommended) conda create -y -n SwinT

LitingLin 196 Jan 04, 2023
Safe Bayesian Optimization

SafeOpt - Safe Bayesian Optimization This code implements an adapted version of the safe, Bayesian optimization algorithm, SafeOpt [1], [2]. It also p

Felix Berkenkamp 111 Dec 11, 2022
Rational Activation Functions - Replacing Padé Activation Units

Rational Activations - Learnable Rational Activation Functions First introduce as PAU in Padé Activation Units: End-to-end Learning of Activation Func

<a href=[email protected]"> 38 Nov 22, 2022
Official code repository for ICCV 2021 paper: Gravity-Aware Monocular 3D Human Object Reconstruction

GraviCap Official code repository for ICCV 2021 paper: Gravity-Aware Monocular 3D Human Object Reconstruction. Gravity-Aware Monocular 3D Human-Object

Rishabh Dabral 15 Dec 09, 2022
Official repository of ICCV21 paper "Viewpoint Invariant Dense Matching for Visual Geolocalization"

Viewpoint Invariant Dense Matching for Visual Geolocalization: PyTorch implementation This is the implementation of the ICCV21 paper: G Berton, C. Mas

Gabriele Berton 44 Jan 03, 2023
A framework for multi-step probabilistic time-series/demand forecasting models

JointDemandForecasting.py A framework for multi-step probabilistic time-series/demand forecasting models File stucture JointDemandForecasting contains

Stanford Intelligent Systems Laboratory 3 Sep 28, 2022
Rotation-Only Bundle Adjustment

ROBA: Rotation-Only Bundle Adjustment Paper, Video, Poster, Presentation, Supplementary Material In this repository, we provide the implementation of

Seong 51 Nov 29, 2022
RLMeta is a light-weight flexible framework for Distributed Reinforcement Learning Research.

RLMeta rlmeta - a flexible lightweight research framework for Distributed Reinforcement Learning based on PyTorch and moolib Installation To build fro

Meta Research 281 Dec 22, 2022
MlTr: Multi-label Classification with Transformer

MlTr: Multi-label Classification with Transformer This is official implement of "MlTr: Multi-label Classification with Transformer". Abstract The task

程星 38 Nov 08, 2022
PyTorch code for our ECCV 2020 paper "Single Image Super-Resolution via a Holistic Attention Network"

HAN PyTorch code for our ECCV 2020 paper "Single Image Super-Resolution via a Holistic Attention Network" This repository is for HAN introduced in the

五维空间 140 Nov 23, 2022
OntoProtein: Protein Pretraining With Ontology Embedding

OntoProtein This is the implement of the paper "OntoProtein: Protein Pretraining With Ontology Embedding". OntoProtein is an effective method that mak

ZJUNLP 80 Dec 14, 2022
Source code for CVPR 2020 paper "Learning to Forget for Meta-Learning"

L2F - Learning to Forget for Meta-Learning Sungyong Baik, Seokil Hong, Kyoung Mu Lee Source code for CVPR 2020 paper "Learning to Forget for Meta-Lear

Sungyong Baik 29 May 22, 2022
Contrastive Loss Gradient Attack (CLGA)

Contrastive Loss Gradient Attack (CLGA) Official implementation of Unsupervised Graph Poisoning Attack via Contrastive Loss Back-propagation, WWW22 Bu

12 Dec 23, 2022
Codes for [NeurIPS'21] You are caught stealing my winning lottery ticket! Making a lottery ticket claim its ownership.

You are caught stealing my winning lottery ticket! Making a lottery ticket claim its ownership Codes for [NeurIPS'21] You are caught stealing my winni

VITA 8 Nov 01, 2022
Galaxy images labelled by morphology (shape). Aimed at ML development and teaching

Galaxy images labelled by morphology (shape). Aimed at ML debugging and teaching.

Mike Walmsley 14 Nov 28, 2022
“Data Augmentation for Cross-Domain Named Entity Recognition” (EMNLP 2021)

Data Augmentation for Cross-Domain Named Entity Recognition Authors: Shuguang Chen, Gustavo Aguilar, Leonardo Neves and Thamar Solorio This repository

<a href=[email protected]"> 18 Sep 10, 2022
Patch-Diffusion Code (AAAI2022)

Patch-Diffusion This is an official PyTorch implementation of "Patch Diffusion: A General Module for Face Manipulation Detection" in AAAI2022. Require

H 7 Nov 02, 2022
QHack—the quantum machine learning hackathon

Official repo for QHack—the quantum machine learning hackathon

Xanadu 72 Dec 21, 2022