Aggragrating Nested Transformer Official Jax Implementation

Overview

Aggragrating Nested Transformer Official Jax Implementation

NesT is a simple method, which aggragrates nested local transformers on image blocks. The idea makes vision transformers attain better accuracy, data efficiency, and convergence on the ImageNet benchmark. NesT can be scaled to small datasets to match convnet accuracy.

This is not an officially supported Google product.

Pretrained Models and Results

Model Accuracy Checkpoint path
Nest-B 83.8 gs://gresearch/nest-checkpoints/nest-b_imagenet
Nest-S 83.3 gs://gresearch/nest-checkpoints/nest-s_imagenet
Nest-T 81.5 gs://gresearch/nest-checkpoints/nest-t_imagenet

Note: Accuracy is evaluated on the ImageNet2012 validation set.

Tensorbord.dev

See ImageNet training logs at Tensorboard.dev.

Colab

Colab is available for test: https://colab.sandbox.google.com/github/google-research/nested-transformer/blob/main/colab.ipynb

Instruction on Image Classification

Environment setup

virtualenv -p python3 --system-site-packages nestenv
source nestenv/bin/activate

pip install -r requirements.txt

Evaluate on ImageNet

At the first time, download ImageNet following tensorflow_datasets instruction from command lines. Optionally, download all pre-trained checkpoints

bash ./checkpoints/download_checkpoints.sh

Run the evaluation script to evaluate NesT-B.

python main.py --config configs/imagenet_nest.py --config.eval_only=True \
  --config.init_checkpoint="./checkpoints/nest-b_imagenet/ckpt.39" \
  --workdir="./checkpoints/nest-t_imagenet_eval"

Train on ImageNet

The default configuration trains NesT-B on TPUv2 8x8 with per device batch size 16.

python main.py --config configs/imagenet_nest.py --jax_backend_target=<TPU_IP_ADDRESS> --jax_xla_backend="tpu_driver" --workdir="./checkpoints/nest-b_imagenet"

Note: See jax/cloud_tpu_colab for info about TPU_IP_ADDRESS.

Train NesT-T on 8 GPUs.

python main.py --config configs/imagenet_nest_tiny.py --workdir="./checkpoints/nest-t_imagenet_8gpu"

The codebase does not support multi-node GPU training (>8 GPUs). The models reported in our paper is trained using TPU with 1024 total batch size.

Train on CIFAR

# Recommend to train on 2 GPUs. Training NesT-T can use 1 GPU.
CUDA_VISIBLE_DEVICES=0,1 python  main.py --config configs/cifar_nest.py --workdir="./checkpoints/nest_cifar"

Cite

@inproceedings{zhang2021aggregating,
  title={Aggregating Nested Transformers},
  author={Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
  booktitle={arXiv preprint arXiv:2105.12723},
  year={2021}
}
Owner
Google Research
Google Research
An official implementation of "Background-Aware Pooling and Noise-Aware Loss for Weakly-Supervised Semantic Segmentation" (CVPR 2021) in PyTorch.

BANA This is the implementation of the paper "Background-Aware Pooling and Noise-Aware Loss for Weakly-Supervised Semantic Segmentation". For more inf

CV Lab @ Yonsei University 59 Dec 12, 2022
Deep motion transfer

animation-with-keypoint-mask Paper The right most square is the final result. Softmax mask (circles): \ Heatmap mask: \ conda env create -f environmen

9 Nov 01, 2022
A Marvelous ChatBot implement using PyTorch.

PyTorch Marvelous ChatBot [Update] it's 2019 now, previously model can not catch up state-of-art now. So we just move towards the future a transformer

JinTian 223 Oct 18, 2022
Distance Encoding for GNN Design

Distance-encoding for GNN design This repository is the official PyTorch implementation of the DEGNN and DEAGNN framework reported in the paper: Dista

172 Nov 08, 2022
Robust & Reliable Route Recommendation on Road Networks

NeuroMLR: Robust & Reliable Route Recommendation on Road Networks This repository is the official implementation of NeuroMLR: Robust & Reliable Route

4 Dec 20, 2022
StarGAN - Official PyTorch Implementation (CVPR 2018)

StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation

Yunjey Choi 5.1k Dec 30, 2022
JORLDY an open-source Reinforcement Learning (RL) framework provided by KakaoEnterprise

Repository for Open Source Reinforcement Learning Framework JORLDY

Kakao Enterprise Corp. 330 Dec 30, 2022
This repository contains code to run experiments in the paper "Signal Strength and Noise Drive Feature Preference in CNN Image Classifiers."

Signal Strength and Noise Drive Feature Preference in CNN Image Classifiers This repository contains code to run experiments in the paper "Signal Stre

0 Jan 19, 2022
Human annotated noisy labels for CIFAR-10 and CIFAR-100.

Dataloader for CIFAR-N CIFAR-10N noise_label = torch.load('./data/CIFAR-10_human.pt') clean_label = noise_label['clean_label'] worst_label = noise_lab

<a href=[email protected]"> 117 Nov 30, 2022
KGDet: Keypoint-Guided Fashion Detection (AAAI 2021)

KGDet: Keypoint-Guided Fashion Detection (AAAI 2021) This is an official implementation of the AAAI-2021 paper "KGDet: Keypoint-Guided Fashion Detecti

Qian Shenhan 35 Dec 29, 2022
Learning from Synthetic Humans, CVPR 2017

Learning from Synthetic Humans (SURREAL) Gül Varol, Javier Romero, Xavier Martin, Naureen Mahmood, Michael J. Black, Ivan Laptev and Cordelia Schmid,

Gul Varol 538 Dec 18, 2022
Steer OpenAI's Jukebox with Music Taggers

TagBox Steer OpenAI's Jukebox with Music Taggers! The closest thing we have to VQGAN+CLIP for music! Unsupervised Source Separation By Steering Pretra

Ethan Manilow 34 Nov 02, 2022
Event queue (Equeue) dialect is an MLIR Dialect that models concurrent devices in terms of control and structure.

Event Queue Dialect Event queue (Equeue) dialect is an MLIR Dialect that models concurrent devices in terms of control and structure. Motivation The m

Cornell Capra 23 Dec 08, 2022
Pytorch implementation of DeePSiM

Pytorch implementation of DeePSiM

1 Nov 05, 2021
CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped

CSWin-Transformer This repo is the official implementation of "CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows". Th

Microsoft 409 Jan 06, 2023
Sequence Modeling with Structured State Spaces

Structured State Spaces for Sequence Modeling This repository provides implementations and experiments for the following papers. S4 Efficiently Modeli

HazyResearch 896 Jan 01, 2023
Optimizing DR with hard negatives and achieving SOTA first-stage retrieval performance on TREC DL Track (SIGIR 2021 Full Paper).

Optimizing Dense Retrieval Model Training with Hard Negatives Jingtao Zhan, Jiaxin Mao, Yiqun Liu, Jiafeng Guo, Min Zhang, Shaoping Ma This repo provi

Jingtao Zhan 99 Dec 27, 2022
Unified learning approach for egocentric hand gesture recognition and fingertip detection

Unified Gesture Recognition and Fingertip Detection A unified convolutional neural network (CNN) algorithm for both hand gesture recognition and finge

Mohammad 227 Dec 25, 2022
Proximal Backpropagation - a neural network training algorithm that takes implicit instead of explicit gradient steps

Proximal Backpropagation Proximal Backpropagation (ProxProp) is a neural network training algorithm that takes implicit instead of explicit gradient s

Thomas Frerix 40 Dec 17, 2022
A package for music online and offline rhythmic information analysis including music Beat, downbeat, tempo and meter tracking.

BeatNet A package for music online and offline rhythmic information analysis including music Beat, downbeat, tempo and meter tracking. This repository

Mojtaba Heydari 157 Dec 27, 2022