PyTorch implementation of "Training an 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet"

Overview

Token Labeling: Training an 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet (arXiv)

This is a PyTorch implementation of our technical report.

Figure: Comparison between the proposed LV-ViT and other recent transformer-based models. Only models with fewer than 100M parameters are shown.

Training Pipeline

Figure: training pipeline.

Our code is based on pytorch-image-models (timm) by Ross Wightman.

LV-ViT Models

Model      Layers  Embed dim  Resolution  Params    Top-1 (%)  Download
LV-ViT-S   16      384        224         26.15M    83.3       link
LV-ViT-S   16      384        384         26.30M    84.4       link
LV-ViT-M   20      512        224         55.83M    84.0       link
LV-ViT-M   20      512        384         56.03M    85.4       link
LV-ViT-L   24      768        448         150.47M   86.2       link

Requirements

torch>=1.4.0
torchvision>=0.5.0
pyyaml
timm==0.4.5

Data preparation: organize ImageNet in the following folder structure (you can extract ImageNet with this script).

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......

Validation

Replace /path/to/imagenet/val with the path to your ImageNet validation set and /path/to/checkpoint with the checkpoint path:

CUDA_VISIBLE_DEVICES=0 bash eval.sh /path/to/imagenet/val /path/to/checkpoint
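The Top-1 column in the model table is the standard top-1 accuracy on the ImageNet validation set. For reference, here is a minimal sketch of how top-k accuracy is commonly computed from a batch of logits. timm ships a similar `accuracy` helper in `timm.utils`; the standalone function below is an illustration, not the code that eval.sh runs.

```python
import torch


def topk_accuracy(logits: torch.Tensor, target: torch.Tensor, topk=(1, 5)):
    """Percentage of samples whose true label is among the k highest-scoring classes.

    logits: (batch, num_classes) scores; target: (batch,) integer class labels.
    """
    maxk = min(max(topk), logits.size(1))
    # Indices of the maxk largest logits per sample, shape (batch, maxk).
    _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)
    correct = pred.eq(target.view(-1, 1))  # (batch, maxk) boolean
    batch = target.size(0)
    return [correct[:, :min(k, maxk)].any(dim=1).float().sum().item() * 100.0 / batch for k in topk]
```

Typical use: `acc1, acc5 = topk_accuracy(model(images), labels)`, averaged over all validation batches weighted by batch size.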

Label data

We provide dense label maps generated by NFNet-F6 here. Since NFNet-F6 is trained on ImageNet data only, no extra training data is involved.
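An issue further down this page mentions a file named label_top5_train_nfnet, which suggests each location of a label map stores only the top-5 class scores and their class indices rather than all 1000 scores. Under that assumption (the (k, H, W) layout and the function name topk_to_dense are hypothetical, not the repo's documented format), a sparse top-k map can be expanded into a dense per-location soft target with a scatter:

```python
import torch


def topk_to_dense(scores: torch.Tensor, indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Expand a top-k label map into a dense soft-label map.

    Assumed (hypothetical) layout: scores and indices both have shape (k, H, W),
    and indices holds class ids in [0, num_classes). Returns (H, W, num_classes).
    """
    k, h, w = scores.shape
    dense = torch.zeros(h, w, num_classes, dtype=scores.dtype)
    # Move k to the last dim so each (row, col) location scatters its own k scores.
    dense.scatter_(2, indices.permute(1, 2, 0).long(), scores.permute(1, 2, 0))
    # Renormalise so every location sums to 1 (the top-k scores alone may not).
    return dense / dense.sum(dim=2, keepdim=True).clamp(min=1e-12)
```

Renormalising the truncated scores is a design choice of this sketch; keeping the raw scores is an equally valid option.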

Training

Coming soon

Reference

If you use this repo or find it useful, please consider citing:

@misc{jiang2021token,
      title={Token Labeling: Training an 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet}, 
      author={Zihang Jiang and Qibin Hou and Li Yuan and Daquan Zhou and Xiaojie Jin and Anran Wang and Jiashi Feng},
      year={2021},
      eprint={2104.10858},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Related projects

T2T-ViT, Re-labeling ImageNet.

Comments
  • Error: the downloaded pretrained model can't be unzipped

    tar -xvf lvvit_s-26M-384-84-4.pth.tar
    tar: This does not look like a tar archive
    tar: Skipping to next header
    tar: Exiting with failure status due to previous errors

    opened by Williamlizl 10
  • The accuracy of the validation set is 0, and the loss is always around 13

    Hello! I use ILSVRC2012_img_train and ILSVRC2012_img_val with the provided label_top5_train_nfnet from Google Drive, and I train lv-vit-s with batch_size 64, without apex, for one epoch. Thanks for your advice.

    opened by yifanQi98 7
  • Pretrained weights for LV-ViT-T

    Hi,

    Thanks for sharing your work. Could you also provide the pre-trained weights for the LV-ViT-T model variant, the one that achieves 79.1% top-1 accuracy as mentioned in Table 1 of your paper?

    All the best, Marc

    opened by marc345 5
  • train error: AttributeError: 'tuple' object has no attribute 'log_softmax'

    Hi, thanks for your great work. When I run the training script, an error occurs: AttributeError: 'tuple' object has no attribute 'log_softmax'

    with amp_autocast():
        output = model(input)
        loss = loss_fn(output, target)  # error occurs

    and the loss function is train_loss_fn = LabelSmoothingCrossEntropy(smoothing=0.0).cuda()

    By the way, could you please tell me why we need to specify smoothing=0.0?

    opened by lxy5513 5
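The error in the issue above suggests that, with token labeling enabled, the model returns a tuple (class-token logits plus auxiliary token outputs) while LabelSmoothingCrossEntropy expects a single logits tensor. As a hedged workaround only, assuming the first tuple element holds the class-token logits, a wrapper can use that element and ignore the rest. Note that this discards the dense token-labeling term entirely, so it is not a substitute for the loss used with --token-label. With smoothing=0.0 the loss reduces to plain cross-entropy. The class name ClassTokenCrossEntropy is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ClassTokenCrossEntropy(nn.Module):
    """Label-smoothed cross-entropy that accepts either a logits tensor or a tuple.

    Hypothetical wrapper: if the model returns a tuple/list, only element 0
    (assumed to be the class-token logits, shape (batch, num_classes)) is used.
    """

    def __init__(self, smoothing: float = 0.0):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, output, target: torch.Tensor) -> torch.Tensor:
        logits = output[0] if isinstance(output, (tuple, list)) else output
        log_probs = F.log_softmax(logits, dim=-1)
        # Negative log-likelihood of the true class for each sample.
        nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        # Uniform-smoothing term: mean negative log-probability over all classes.
        smooth = -log_probs.mean(dim=-1)
        loss = (1.0 - self.smoothing) * nll + self.smoothing * smooth
        return loss.mean()
```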
  • RuntimeError: CUDA error: device-side assert triggered

    I am new to deep learning. When I run the VOLO code with tlt on a single GPU or on multiple GPUs, I get the following error:

    /pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:312: operator(): block: [0,0,0], thread: [25,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
    Traceback (most recent call last):
      File "main.py", line 949, in <module>
        main()
      File "main.py", line 664, in main
        optimizers=optimizers)
      File "main.py", line 773, in train_one_epoch
        label_size=args.token_label_size)
      File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 90, in mixup_target
        y1 = get_labelmaps_with_coords(target, num_classes, on_value=on_value, off_value=off_value, device=device, label_size=label_size)
      File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 64, in get_labelmaps_with_coords
        num_classes=num_classes,device=device)
      File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 16, in get_featuremaps
        _label_topk[1][:, :, :].long(),
    RuntimeError: CUDA error: device-side assert triggered.

    I can't fix this problem right now.

    opened by JIAOJIAYUASD 4
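This assertion fires when a scatter/gather index falls outside the target dimension. In the traceback above it is raised while label maps are being built, so one plausible cause (an assumption, not a confirmed diagnosis) is a class index in the label map that is >= num_classes, for example label maps generated for the 1000 ImageNet classes used together with a smaller --num-classes. CUDA reports such errors asynchronously, so rerunning with CUDA_LAUNCH_BLOCKING=1 gives a more accurate stack trace. Checking the indices before the scatter gives a readable message instead. A minimal sketch with hypothetical helper names:

```python
import torch


def check_class_indices(indices: torch.Tensor, num_classes: int) -> None:
    """Raise a readable error if any class index lies outside [0, num_classes)."""
    lo, hi = int(indices.min()), int(indices.max())
    if lo < 0 or hi >= num_classes:
        raise ValueError(
            f"label indices span [{lo}, {hi}] but num_classes={num_classes}; "
            "the label maps and --num-classes probably disagree"
        )


def one_hot_scatter(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    """One-hot encode integer indices of any shape into shape (..., num_classes)."""
    check_class_indices(indices, num_classes)
    out = torch.zeros(*indices.shape, num_classes, device=indices.device)
    # Scatter along the new last dimension, whose index equals indices.dim().
    return out.scatter_(indices.dim(), indices.long().unsqueeze(-1), 1.0)
```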
  • Generating label for custom dataset

    Hello,

    Thank you for sharing your work. I am currently trying to generate token labels for a custom dataset for the model lvvit_s, but the loss stays close to 7 and the accuracy is 0 (not pre-trained, using 1 GPU in Google Colab). I also tried using the pre-trained model with --transfer, but got 0 for both loss and accuracy. Which options should I use for a custom dataset?

    opened by AleMaiaF 2
  • generate_label.py unable to find model lvvit_s

    Hi,

    When I tried to run the label generation script for the model lvvit_s it returned an error "RuntimeError: Unknown model".

    Solution: It worked when I added the line "import tlt.models" in the file generate_label.py.

    opened by AleMaiaF 2
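The fix in the issue above works because timm's create_model only knows about models that have been registered, and registration happens as a side effect of importing the module that defines them (timm uses a register_model decorator for this). Importing tlt.models therefore presumably registers the lvvit_* names. The toy registry below is a simplified, hypothetical stand-in for timm's, included only to illustrate the mechanism; define_lvvit_models plays the role of the import.

```python
from typing import Callable, Dict

# Simplified stand-in for a model registry (not timm's actual implementation).
_MODEL_ENTRYPOINTS: Dict[str, Callable[..., dict]] = {}


def register_model(fn: Callable[..., dict]) -> Callable[..., dict]:
    """Decorator: record the factory under its function name when it is defined."""
    _MODEL_ENTRYPOINTS[fn.__name__] = fn
    return fn


def create_model(name: str, **kwargs) -> dict:
    if name not in _MODEL_ENTRYPOINTS:
        raise RuntimeError(f"Unknown model ({name})")
    return _MODEL_ENTRYPOINTS[name](**kwargs)


def define_lvvit_models() -> None:
    """Plays the role of `import tlt.models`: defining the factory registers it."""

    @register_model
    def lvvit_s(**kwargs) -> dict:
        return {"name": "lvvit_s", **kwargs}  # placeholder instead of a real network
```

Calling create_model("lvvit_s") before define_lvvit_models() raises the same "Unknown model" error seen in the issue; afterwards the lookup succeeds.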
  • Can Token labeling reach higher than annotator model?

    Greetings,

    Thank you for this incredible research.

    I would like to know whether it is possible to use Token Labeling to achieve scores higher than those of the annotator model. I believe this was the case with the VOLO D5 model, which achieved a higher score than NFNet, the model used for annotation.

    opened by ErenBalatkan 1
  • label_map does not do the same augmentation (random crop) as the input image

    Hi, thanks so much for the nice work! I am curious whether you could share some insight into how the label_map is processed. If I understand correctly, after we load the image and the corresponding label map, we should apply the same cropping/flip/resize to both. But in https://github.com/zihangJiang/TokenLabeling/blob/aa438eff9b9fc2daa8c8b4cc6bfaa6e3721f995e/tlt/data/label_transforms_factory.py#L58-L73 it seems that only the image is cropped, while the label map is not, which would make the label map not match the image?

    Shall we do

        return torchvision_F.resized_crop(
            img, i, j, h, w, self.size, interpolation
        ), torchvision_F.resized_crop(
            label_map, int(i / ratio), int(j / ratio), int(h / ratio), int(w / ratio), self.size, interpolation
        )

    Thanks

    opened by haooooooqi 1
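The point raised in the issue above is that the label map must be cropped with the same box as the image. Assuming the label map covers the whole image on a uniform grid (so one label-map cell spans H / H_label image pixels), a minimal sketch of such a paired crop follows. It works on tensors rather than PIL images and treats the label map as per-class scores (bilinear resizing would be wrong for a map of class indices). The function name paired_resized_crop is hypothetical, and the repo itself may handle this differently, for example by passing the crop coordinates along with the label map.

```python
import torch
import torch.nn.functional as F


def paired_resized_crop(img, label_map, top, left, height, width, size, label_size):
    """Crop the same region from an image and its lower-resolution label map, then resize.

    img: (C, H, W) float tensor; label_map: (K, Hl, Wl) float tensor of class scores.
    (top, left, height, width) is the crop box in image pixels.
    """
    _, h_img, w_img = img.shape
    _, h_lab, w_lab = label_map.shape
    ratio_y, ratio_x = h_img / h_lab, w_img / w_lab  # image pixels per label cell

    img_crop = img[:, top:top + height, left:left + width]

    # Map the box into label-map cells, keeping at least one cell per axis.
    l_top, l_left = int(top / ratio_y), int(left / ratio_x)
    l_h = max(1, round(height / ratio_y))
    l_w = max(1, round(width / ratio_x))
    lab_crop = label_map[:, l_top:l_top + l_h, l_left:l_left + l_w]

    img_out = F.interpolate(img_crop[None], size=size, mode="bilinear", align_corners=False)[0]
    lab_out = F.interpolate(lab_crop[None], size=label_size, mode="bilinear", align_corners=False)[0]
    return img_out, lab_out
```

A horizontal flip can be kept consistent the same way, by flipping both outputs with `torch.flip(x, dims=[-1])` whenever the image is flipped.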
  • Python3.6, ok; Python3.8, error

    Test: [ 0/1] Time: 11.293 (11.293) Loss: 0.7043 (0.7043) Acc@1: 42.1875 (42.1875) Acc@5: 100.0000 (100.0000)
    Test: [ 1/1] Time: 0.108 (5.701) Loss: 0.5847 (0.6689) Acc@1: 89.8148 (56.3187) Acc@5: 100.0000 (100.0000)
    free(): invalid pointer
    free(): invalid pointer
    Traceback (most recent call last):
      File "/opt/conda/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/opt/conda/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launch.py", line 303, in <module>
        main()
      File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launch.py", line 294, in main
        raise subprocess.CalledProcessError(returncode=process.returncode,
    subprocess.CalledProcessError: Command '['/opt/conda/bin/python3.8', '-u', 'main.py', '--local_rank=1', './dataset/c/c', '--model', 'lvvit_s', '-b', '128', '--apex-amp', '--img-size', '224', '--drop-path', '0.1', '--token-label', '--token-label-size', '14', '--dense-weight', '0.0', '--num-classes', '2', '--finetune', './pretrained/lvvit_s-26M-384-84-4.pth.tar']' died with <Signals.SIGABRT: 6>.

    Command (run in /puxin_libochao/TokenLabeling):

    CUDA_VISIBLE_DEVICES=0,1 bash ./distributed_train.sh 2 ./dataset/c/c --model lvvit_s -b 128 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-size 14 --dense-weight 0.0 --num-classes 2 --finetune ./pretrained/lvvit_s-26M-384-84-4.pth.tar

    opened by Williamlizl 1
  • A Bag of Training Techniques for ViT

    Hi, thanks for your wonderful work. I have a question: can the training techniques mentioned in the LV-ViT paper be used in other downstream tasks such as object detection? In your paper, I see that most of these techniques are applied on ImageNet. Thanks!

    opened by qdd1234 1
  • How to apply token labeling to a CNN?

    Hello! I'm interested in your token labeling technique, and I want to apply it to a CNN-based model because ViTs are very heavy to train.

    Could I get your code for token labeling with a CNN? If not, could you give me some details on how to implement it?

    Thank you.

    opened by HoJ00n2 0
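One way to adapt token labeling to a CNN, sketched here as an illustration and not as the authors' code, is to treat each spatial position of the final feature map as a token. A global head classifies the pooled features as usual, a 1x1 convolution produces per-location logits, and the loss adds a weighted soft-label cross-entropy between those logits and a dense label map resized to the feature-map grid. The weight mirrors the --dense-weight option seen in a training command on this page; the default of 0.5 and all names below are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def soft_cross_entropy(logits: torch.Tensor, soft_targets: torch.Tensor) -> torch.Tensor:
    """Cross-entropy against soft (probability) targets, averaged over all leading dims."""
    return -(soft_targets * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()


class CNNTokenLabelHead(nn.Module):
    """Hypothetical head: global logits from pooled features plus per-location logits."""

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(in_channels, num_classes)
        self.dense = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, feats: torch.Tensor):
        # feats: (B, C, H, W) output of the CNN backbone's last stage.
        global_logits = self.fc(feats.mean(dim=(2, 3)))        # (B, num_classes)
        dense_logits = self.dense(feats).permute(0, 2, 3, 1)    # (B, H, W, num_classes)
        return global_logits, dense_logits


def token_labeling_loss(global_logits, dense_logits, y_global, y_dense, dense_weight: float = 0.5):
    """y_global: (B, K) soft image-level targets; y_dense: (B, H, W, K) soft per-location targets."""
    return soft_cross_entropy(global_logits, y_global) + dense_weight * soft_cross_entropy(dense_logits, y_dense)
```

Averaging the dense term over all B * H * W locations keeps its scale independent of the feature-map resolution, so the same weight can be reused across backbones.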
  • Model settings for Cifar10

    I am interested in whether you have tested any LV-ViT model setup on CIFAR-10. I would like to know the proper configuration of all blocks when training without pretrained weights.

    opened by Aminullah6264 0
Owner
蒋子航
Now a Ph.D. student supervised by Prof. Feng Jiashi in ECE, NUS.