ICCV2021, Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet

Overview

Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021

Update:

2021/03/11: updated our results. T2T-ViT-14, with 21.5M parameters, now reaches 81.5% top-1 accuracy at 224x224 resolution and 83.3% top-1 accuracy at 384x384 resolution.

2021/02/21: T2T-ViT can be trained stably with '--amp' (Automatic Mixed Precision) on most common GPUs: 1080Ti, 2080Ti, TITAN V and V100. On some specific GPUs, such as the Tesla T4, AMP can cause NaN loss when training T2T-ViT. If you get NaN loss during training, disable AMP by removing '--amp' from the training scripts.

2021/01/28: released the code and uploaded most of the pretrained T2T-ViT models.

Reference

If you find this repo useful, please consider citing:

@InProceedings{Yuan_2021_ICCV,
    author    = {Yuan, Li and Chen, Yunpeng and Wang, Tao and Yu, Weihao and Shi, Yujun and Jiang, Zi-Hang and Tay, Francis E.H. and Feng, Jiashi and Yan, Shuicheng},
    title     = {Tokens-to-Token ViT: Training Vision Transformers From Scratch on ImageNet},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {558-567}
}

Our code is based on the official PyTorch ImageNet example and on pytorch-image-models by Ross Wightman.

1. Requirements

timm: pip install timm==0.3.4

torch>=1.4.0

torchvision>=0.5.0

pyyaml

Data preparation: ImageNet with the following folder structure (you can extract ImageNet with this script):

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
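Before training, it can help to sanity-check that both splits follow this layout. Below is a minimal sketch using only the Python standard library; the function name summarize_imagenet_folder is hypothetical (it is not part of this repo), and it assumes the validation images have already been sorted into per-class folders as shown above.

```python
from pathlib import Path


def summarize_imagenet_folder(root, ext=".JPEG"):
    """Count class folders and images in root/train and root/val.

    Hypothetical helper. Assumes the layout shown above: one sub-folder per
    WordNet ID in BOTH splits (validation images already sorted by class).
    """
    root = Path(root)
    summary = {}
    class_names = {}
    for split in ("train", "val"):
        class_dirs = sorted(p for p in (root / split).iterdir() if p.is_dir())
        class_names[split] = {p.name for p in class_dirs}
        # Count only files with the expected extension inside each class folder.
        n_images = sum(1 for d in class_dirs for f in d.iterdir() if f.suffix == ext)
        summary[split] = {"classes": len(class_dirs), "images": n_images}
    # Classes present in train but missing from val usually mean the
    # validation set was not reorganized into per-class sub-folders.
    summary["missing_in_val"] = sorted(class_names["train"] - class_names["val"])
    return summary
```

For the standard ILSVRC2012 release, summarize_imagenet_folder('path/to/imagenet') should report 1000 classes in each split, 1,281,167 training images and 50,000 validation images. A non-empty missing_in_val list usually points to an unsorted validation folder.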

2. T2T-ViT Models

Model                        T2T Transformer  Top1 Acc (%)  #params  MACs   Download
T2T-ViT-14                   Performer        81.5          21.5M    4.8G   here
T2T-ViT-19                   Performer        81.9          39.2M    8.5G   here
T2T-ViT-24                   Performer        82.3          64.1M    13.8G  here
T2T-ViT-14, 384              Performer        83.3          21.7M    -      here
T2T-ViT-24, Token Labeling   Performer        84.2          65M      -      here
T2T-ViT_t-14                 Transformer      81.7          21.5M    6.1G   here
T2T-ViT_t-19                 Transformer      82.4          39.2M    9.8G   here
T2T-ViT_t-24                 Transformer      82.6          64.1M    15.0G  here

'T2T-ViT-14, 384' denotes T2T-ViT-14 trained with an image size of 384x384.

'T2T-ViT-24, Token Labeling' denotes T2T-ViT-24 trained with Token Labeling.

The three lite variants of T2T-ViT (compared with MobileNets):

Model        T2T Transformer  Top1 Acc (%)  #params  MACs  Download
T2T-ViT-7    Performer        71.7          4.3M     1.1G  here
T2T-ViT-10   Performer        75.2          5.9M     1.5G  here
T2T-ViT-12   Performer        76.5          6.9M     1.8G  here

Usage

To use our pretrained T2T-ViT models:

from models.t2t_vit import *  # provides t2t_vit_14 and the other T2T-ViT constructors
from utils import load_for_transfer_learning

# create model
model = t2t_vit_14()

# load the pretrained weights
# - change num_classes to match your dataset
# - other image sizes also work, because the position embedding is interpolated to the new size
load_for_transfer_learning(model, '/path/to/pretrained/weights', use_ema=True, strict=False, num_classes=1000)
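The comment in the usage example notes that the position embedding is interpolated when the image size changes. As an illustration only, here is a pure-Python sketch of one common way to do that: bilinear resizing of the square patch-token grid while leaving the class token untouched. The function resize_pos_embed_grid is hypothetical, and the layout it assumes (class token first, then a row-major square grid) and the half-pixel sampling are assumptions, not a description of the repo's utils code.

```python
def resize_pos_embed_grid(pos_embed, new_side):
    """Bilinearly resize position embeddings to a new_side x new_side grid.

    Hypothetical sketch. pos_embed is a list of tokens (each a list of floats);
    token 0 is assumed to be the class token, the rest a row-major square grid.
    Sampling uses half-pixel centers (like align_corners=False).
    """
    cls_token, grid = pos_embed[0], pos_embed[1:]
    old_side = int(round(len(grid) ** 0.5))
    if old_side * old_side != len(grid):
        raise ValueError("patch tokens must form a square grid")
    scale = old_side / new_side

    def src_coord(i):
        # Map output index i to a continuous input coordinate, clamped to the grid.
        x = (i + 0.5) * scale - 0.5
        return min(max(x, 0.0), old_side - 1.0)

    out = [list(cls_token)]  # the class token is copied unchanged
    for r in range(new_side):
        y = src_coord(r)
        y0 = int(y)
        y1 = min(y0 + 1, old_side - 1)
        wy = y - y0
        for c in range(new_side):
            x = src_coord(c)
            x0 = int(x)
            x1 = min(x0 + 1, old_side - 1)
            wx = x - x0
            top_l, top_r = grid[y0 * old_side + x0], grid[y0 * old_side + x1]
            bot_l, bot_r = grid[y1 * old_side + x0], grid[y1 * old_side + x1]
            # Interpolate every embedding channel independently.
            out.append([
                (1 - wy) * ((1 - wx) * tl + wx * tr) + wy * ((1 - wx) * bl + wx * br)
                for tl, tr, bl, br in zip(top_l, top_r, bot_l, bot_r)
            ])
    return out
```

For example, assuming an overall tokenization stride of 16, moving from 224x224 to 384x384 input would correspond to resize_pos_embed_grid(pos_embed, 24) on a table of 1 + 14*14 tokens, giving 1 + 24*24 tokens.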

3. Validation

Test T2T-ViT-14 (with Performer in the T2T module):

Download the T2T-ViT-14 checkpoint, then test it by running:

CUDA_VISIBLE_DEVICES=0 python main.py path/to/data --model t2t_vit_14 -b 100 --eval_checkpoint path/to/checkpoint

The results look like:

Test: [   0/499]  Time: 2.083 (2.083)  Loss:  0.3578 (0.3578)  Acc@1: 96.0000 (96.0000)  Acc@5: 99.0000 (99.0000)
Test: [  50/499]  Time: 0.166 (0.202)  Loss:  0.5823 (0.6404)  Acc@1: 85.0000 (86.1569)  Acc@5: 99.0000 (97.5098)
...
Test: [ 499/499]  Time: 0.272 (0.172)  Loss:  1.3983 (0.8261)  Acc@1: 62.0000 (81.5000)  Acc@5: 93.0000 (95.6660)
Top-1 accuracy of the model is: 81.5%
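In these log lines, each metric shows the value for the current batch followed, in parentheses, by the running average over all batches evaluated so far; the last parenthesized top-1 value is the reported accuracy. timm-style scripts keep this with an AverageMeter-like helper. The class below is a minimal stand-in written for illustration (not the repo's or timm's code), and the per-batch values fed to it are hypothetical.

```python
class AverageMeter:
    """Track the latest value and a sample-weighted running average."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # n is the number of images in the batch, so the average is per image.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


top1 = AverageMeter()
for batch_acc in (96.0, 85.0, 62.0):  # hypothetical per-batch top-1 values
    top1.update(batch_acc, n=100)      # -b 100 images per batch
    print(f"Acc@1: {top1.val:.4f} ({top1.avg:.4f})")
```

Weighting by n keeps the average correct when the final batch is smaller than the others.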

Test the three lite variants T2T-ViT-7, T2T-ViT-10 and T2T-ViT-12 (with Performer in the T2T module):

Download the T2T-ViT-7, T2T-ViT-10 or T2T-ViT-12 checkpoint, then test it by running:

CUDA_VISIBLE_DEVICES=0 python main.py path/to/data --model t2t_vit_7 -b 100 --eval_checkpoint path/to/checkpoint

Test the model T2T-ViT-14, 384 with 83.3% top-1 accuracy:

CUDA_VISIBLE_DEVICES=0 python main.py path/to/data --model t2t_vit_14 --img-size 384 -b 100 --eval_checkpoint path/to/T2T-ViT-14-384 

4. Train

Train the three lite variants T2T-ViT-7, T2T-ViT-10 and T2T-ViT-12 (with Performer in the T2T module):

If only 4 GPUs are available:

CUDA_VISIBLE_DEVICES=0,1,2,3 ./distributed_train.sh 4 path/to/data --model t2t_vit_7 -b 128 --lr 1e-3 --weight-decay .03 --amp --img-size 224

With 4 GPUs, the top-1 accuracy is slightly lower than with 8 GPUs (by around 0.1%-0.3%).

If 8 GPUs are available:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 path/to/data --model t2t_vit_7 -b 64 --lr 1e-3 --weight-decay .03 --amp --img-size 224
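Both commands use the same global batch size: in distributed training each process (one per GPU) loads its own '-b' images, so one optimizer step sees (number of GPUs) x '-b' images. A quick check with hypothetical helper names; steps_per_epoch assumes the last incomplete batch is dropped and uses the 1,281,167 ImageNet-1k training images.

```python
IMAGENET_TRAIN_IMAGES = 1_281_167  # ILSVRC2012 training set size


def global_batch_size(num_gpus, per_gpu_batch):
    # One process per GPU, each loading per_gpu_batch ('-b') images,
    # so a single optimizer step sees num_gpus * per_gpu_batch images.
    return num_gpus * per_gpu_batch


def steps_per_epoch(num_images, global_batch):
    # Assumes the last incomplete batch is dropped.
    return num_images // global_batch


for gpus, b in ((4, 128), (8, 64)):
    g = global_batch_size(gpus, b)
    print(f"{gpus} GPUs x -b {b} = {g} images/step, "
          f"{steps_per_epoch(IMAGENET_TRAIN_IMAGES, g)} steps/epoch")
```

The resulting 2502 iterations per epoch match the "Train: 0 [ 0/2502 ...]" lines in the 8-GPU training log quoted in the comments.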

Train the T2T-ViT-14 and T2T-ViT_t-14 (run on 4 or 8 GPUs):

CUDA_VISIBLE_DEVICES=0,1,2,3 ./distributed_train.sh 4 path/to/data --model t2t_vit_14 -b 128 --lr 1e-3 --weight-decay .05 --amp --img-size 224
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 path/to/data --model t2t_vit_14 -b 64 --lr 5e-4 --weight-decay .05 --amp --img-size 224

If you want to train our T2T-ViT on images with 384x384 resolution, please use '--img-size 384'.
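Changing '--img-size' changes how many tokens reach the transformer layers. The sketch below counts them, assuming the three soft splits are sliding windows with (kernel, stride, padding) = (7, 4, 2), (3, 2, 1) and (3, 2, 1). The repo's own code comment mentions strides 4, 2, 2, but the kernel and padding values here are assumptions to check against models/t2t_vit.py, and the helper names are hypothetical.

```python
# Assumed (kernel, stride, padding) for the three soft splits.
SOFT_SPLITS = ((7, 4, 2), (3, 2, 1), (3, 2, 1))


def window_out(size, kernel, stride, padding):
    # Number of sliding-window positions along one axis (dilation 1),
    # the same formula used for nn.Unfold / Conv2d output sizes.
    return (size + 2 * padding - kernel) // stride + 1


def token_grid(side):
    """Tokens along one image side after all three soft splits."""
    for kernel, stride, padding in SOFT_SPLITS:
        side = window_out(side, kernel, stride, padding)
    return side


for size in (224, 384):
    n = token_grid(size)
    print(f"{size}x{size} input -> {n}x{n} = {n * n} tokens")
```

For non-square inputs, apply token_grid to the height and width separately. Under these assumptions a side that is not a multiple of 16 can give a different count than img_size // 16; for instance, token_grid(56) is 4, not 3.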

Train the T2T-ViT-19, T2T-ViT-24 or T2T-ViT_t-19, T2T-ViT_t-24:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 path/to/data --model t2t_vit_19 -b 64 --lr 5e-4 --weight-decay .065 --amp --img-size 224

5. Transfer T2T-ViT to CIFAR10/CIFAR100

Model        ImageNet top-1 (%)  CIFAR10 (%)  CIFAR100 (%)  #params
T2T-ViT-14   81.5                98.3         88.4          21.5M
T2T-ViT-19   81.9                98.4         89.0          39.2M

We resize CIFAR10/100 images to 224x224 and fine-tune our pretrained T2T-ViT-14/19 on CIFAR10/100 by running:

CUDA_VISIBLE_DEVICES=0,1 python transfer_learning.py --lr 0.05 --b 64 --num-classes 10 --img-size 224 --transfer-learning True --transfer-model /path/to/pretrained/T2T-ViT-19

6. Visualization

To visualize the image features of ResNet50, open and run the visualization_resnet.ipynb notebook in Jupyter Notebook or JupyterLab; some results are shown below:

To visualize the image features of ViT, open and run the visualization_vit.ipynb notebook in Jupyter Notebook or JupyterLab; some results are shown below:

To visualize attention maps, refer to this file. A simple example that visualizes the attention maps in attention blocks 4 and 5:

Comments
  • NaN during training even without '--amp'

    Hello! I would like to train the T2t_vit_14 model on the ImageNet-100 dataset with 3 Quadro RTX 5000 GPUs, but I get NaN loss and the error below. Could you please help me get the code running?

    I run the following command: CUDA_VISIBLE_DEVICES=1,2,6,7 bash distributed_train.sh 4 /data/datasets/imagenet-100/ --model T2t_vit_14 -b 128 --lr 1e-3 --weight-decay .03 --cutmix 0.0 --reprob 0.25 --img-size 224

    Some printout:

    200,4.4025687376658125,4.13563300743103,7.760000015258789,24.43999983520508
    201,4.374216079711914,4.1354576759338375,7.639999951171875,24.200000134277342
    202,4.392171382904053,4.136218957519532,7.7399999694824215,24.439999853515626
    203,4.384297768274943,4.140018928909302,7.659999923706055,24.29999981689453
    204,4.371897141138713,4.14544691696167,7.619999905395508,23.91999990234375
    205,4.374680519104004,4.1505038471221924,7.7600000366210935,23.93999978027344
    206,4.359750032424927,4.154387146377563,7.619999943542481,24.040000201416017
    207,4.37085485458374,4.158743778991699,7.820000009155273,24.019999743652345
    208,4.367704391479492,4.161868629837036,7.62000002746582,23.86000018310547
    209,nan,4.160795520019532,7.7200000274658205,24.01999976196289
    210,nan,nan,1.0,5.0
    211,nan,nan,1.0,5.0
    212,nan,nan,1.0,5.0
    213,nan,nan,1.0,5.0

    Error:

    File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/handle.py", line 123, in scale_loss
      optimizer._post_amp_backward(loss_scaler)
    File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 249, in post_backward_no_master_weights
      post_backward_models_are_masters(scaler, params, stashed_grads)
    File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 135, in post_backward_models_are_masters
      scale_override=(grads_have_scale, stashed_have_scale, out_scale))
    File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/scaler.py", line 183, in unscale_with_stashed
      out_scale/grads_have_scale,
    ZeroDivisionError: float division by zero
    Traceback (most recent call last):
      File "main.py", line 764, in <module>
        main()
      File "main.py", line 560, in main
        amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
      File "main.py", line 637, in train_epoch
        loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/timm/utils/cuda.py", line 20, in __call__
        scaled_loss.backward(create_graph=create_graph)
      File "/usr/lib/python3.6/contextlib.py", line 88, in __exit__
        next(self.gen)
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/handle.py", line 123, in scale_loss
        optimizer._post_amp_backward(loss_scaler)
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 249, in post_backward_no_master_weights
        post_backward_models_are_masters(scaler, params, stashed_grads)
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 135, in post_backward_models_are_masters
        scale_override=(grads_have_scale, stashed_have_scale, out_scale))
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/scaler.py", line 183, in unscale_with_stashed
        out_scale/grads_have_scale,
    ZeroDivisionError: float division by zero
    Traceback (most recent call last):
      File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
        "__main__", mod_spec)
      File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
        exec(code, run_globals)
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/torch/distributed/launch.py", line 261, in <module>
        main()
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/torch/distributed/launch.py", line 257, in main
        cmd=cmd)
    subprocess.CalledProcessError: Command '['/home/ekrivosheev/cv_env/bin/python', '-u', 'main.py', '--local_rank=3', '/data/datasets/imagenet-100/', '--model', 'T2t_vit_14', '-b', '128', '--lr', '1e-3', '--weight-decay', '.03', '--cutmix', '0.0', '--reprob', '0.25', '--img-size', '224']' returned non-zero exit status 1.

    opened by Evgeneus 8
  • NaN loss for the provided model

    I trained the model with the following two scripts. Both result in NaN loss after one epoch of training. Any thoughts on how to address this issue?

    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 path/to/data --model T2t_vit_7 -b 64 --lr 1e-3 --weight-decay .03 --cutmix 0.0 --reprob 0.25 --img-size 224

    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 path/to/data --model T2t_vit_14 -b 64 --lr 5e-4 --weight-decay .05 --img-size 224

    Training in distributed mode with multiple processes, 1 GPU per process. Process 0, total 8.
    Training in distributed mode with multiple processes, 1 GPU per process. Process 6, total 8.
    Training in distributed mode with multiple processes, 1 GPU per process. Process 7, total 8.
    Training in distributed mode with multiple processes, 1 GPU per process. Process 3, total 8.
    adopt performer encoder for tokens-to-token
    adopt performer encoder for tokens-to-token
    adopt performer encoder for tokens-to-token
    adopt performer encoder for tokens-to-token
    adopt performer encoder for tokens-to-token
    adopt performer encoder for tokens-to-token
    Training in distributed mode with multiple processes, 1 GPU per process. Process 1, total 8.
    adopt performer encoder for tokens-to-token
    Model T2t_vit_14 created, param count: 21545550
    Data processing configuration for current model + dataset:
      input_size: (3, 224, 224)
      interpolation: bicubic
      mean: (0.485, 0.456, 0.406)
      std: (0.229, 0.224, 0.225)
      crop_pct: 0.9
    Using native Torch AMP. Training in mixed precision.
    Using native Torch DistributedDataParallel.
    Scheduled epochs: 310
    Train: 0 [ 0/2502 ( 0%)] Loss: 7.023479 (7.0235) Time: 3.680s, 139.14/s (3.680s, 139.14/s) LR: 1.000e-06 Data: 1.776 (1.776)
    Reducer buckets have been rebuilt in this iteration.
    Reducer buckets have been rebuilt in this iteration.
    Reducer buckets have been rebuilt in this iteration.
    Reducer buckets have been rebuilt in this iteration.
    Reducer buckets have been rebuilt in this iteration.
    Reducer buckets have been rebuilt in this iteration.
    Reducer buckets have been rebuilt in this iteration.
    Reducer buckets have been rebuilt in this iteration.
    Train: 0 [ 50/2502 ( 2%)] Loss: 6.971423 (6.9975) Time: 0.323s, 1586.02/s (0.385s, 1330.47/s) LR: 1.000e-06 Data: 0.006 (0.041)
    Train: 0 [ 100/2502 ( 4%)] Loss: 6.978786 (6.9912) Time: 0.305s, 1679.64/s (0.351s, 1457.64/s) LR: 1.000e-06 Data: 0.006 (0.024)
    Train: 0 [ 150/2502 ( 6%)] Loss: 6.975621 (6.9873) Time: 0.300s, 1705.67/s (0.340s, 1507.75/s) LR: 1.000e-06 Data: 0.005 (0.018)
    Train: 0 [ 200/2502 ( 8%)] Loss: 6.966157 (6.9831) Time: 0.360s, 1422.92/s (0.334s, 1530.97/s) LR: 1.000e-06 Data: 0.006 (0.015)
    Train: 0 [ 250/2502 ( 10%)] Loss: 6.980019 (6.9826) Time: 0.309s, 1657.73/s (0.331s, 1545.27/s) LR: 1.000e-06 Data: 0.005 (0.013)
    Train: 0 [ 300/2502 ( 12%)] Loss: 6.964942 (6.9801) Time: 0.327s, 1565.87/s (0.329s, 1556.59/s) LR: 1.000e-06 Data: 0.006 (0.012)
    Train: 0 [ 350/2502 ( 14%)] Loss: 6.957265 (6.9772) Time: 0.332s, 1541.96/s (0.327s, 1563.37/s) LR: 1.000e-06 Data: 0.005 (0.011)
    Train: 0 [ 400/2502 ( 16%)] Loss: 6.953742 (6.9746) Time: 0.318s, 1609.71/s (0.326s, 1570.11/s) LR: 1.000e-06 Data: 0.006 (0.011)
    Train: 0 [ 450/2502 ( 18%)] Loss: 6.967467 (6.9739) Time: 0.309s, 1658.46/s (0.325s, 1573.87/s) LR: 1.000e-06 Data: 0.007 (0.010)
    Train: 0 [ 500/2502 ( 20%)] Loss: 6.970360 (6.9736) Time: 0.322s, 1590.08/s (0.325s, 1577.36/s) LR: 1.000e-06 Data: 0.007 (0.010)
    Train: 0 [ 550/2502 ( 22%)] Loss: 6.931087 (6.9700) Time: 0.313s, 1637.96/s (0.324s, 1579.20/s) LR: 1.000e-06 Data: 0.005 (0.009)
    Train: 0 [ 600/2502 ( 24%)] Loss: 6.939621 (6.9677) Time: 0.329s, 1555.19/s (0.324s, 1580.93/s) LR: 1.000e-06 Data: 0.007 (0.009)
    Train: 0 [ 650/2502 ( 26%)] Loss: 6.943333 (6.9660) Time: 0.318s, 1607.70/s (0.324s, 1582.42/s) LR: 1.000e-06 Data: 0.005 (0.009)
    Train: 0 [ 700/2502 ( 28%)] Loss: 6.940698 (6.9643) Time: 0.316s, 1621.93/s (0.323s, 1584.56/s) LR: 1.000e-06 Data: 0.006 (0.009)
    Train: 0 [ 750/2502 ( 30%)] Loss: 6.941026 (6.9628) Time: 0.323s, 1584.28/s (0.323s, 1586.07/s) LR: 1.000e-06 Data: 0.006 (0.008)
    Train: 0 [ 800/2502 ( 32%)] Loss: 6.936088 (6.9612) Time: 0.310s, 1649.05/s (0.323s, 1587.13/s) LR: 1.000e-06 Data: 0.006 (0.008)
    Train: 0 [ 850/2502 ( 34%)] Loss: 6.931849 (6.9596) Time: 0.308s, 1662.24/s (0.322s, 1588.20/s) LR: 1.000e-06 Data: 0.005 (0.008)
    Train: 0 [ 900/2502 ( 36%)] Loss: 6.947849 (6.9590) Time: 0.320s, 1599.60/s (0.322s, 1589.06/s) LR: 1.000e-06 Data: 0.005 (0.008)
    Train: 0 [ 950/2502 ( 38%)] Loss: 6.928242 (6.9575) Time: 0.308s, 1659.89/s (0.322s, 1590.35/s) LR: 1.000e-06 Data: 0.005 (0.008)
    Train: 0 [1000/2502 ( 40%)] Loss: 6.926805 (6.9560) Time: 0.310s, 1649.80/s (0.322s, 1591.55/s) LR: 1.000e-06 Data: 0.006 (0.008)
    Train: 0 [1050/2502 ( 42%)] Loss: 6.950564 (6.9557) Time: 0.308s, 1660.43/s (0.322s, 1592.16/s) LR: 1.000e-06 Data: 0.005 (0.008)
    Train: 0 [1100/2502 ( 44%)] Loss: 6.930144 (6.9546) Time: 0.300s, 1707.17/s (0.321s, 1593.30/s) LR: 1.000e-06 Data: 0.005 (0.008)
    Train: 0 [1150/2502 ( 46%)] Loss: 6.919596 (6.9532) Time: 0.331s, 1547.59/s (0.321s, 1593.54/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [1200/2502 ( 48%)] Loss: 6.922656 (6.9520) Time: 0.310s, 1652.26/s (0.321s, 1594.28/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [1250/2502 ( 50%)] Loss: 6.919957 (6.9507) Time: 0.311s, 1645.52/s (0.321s, 1595.21/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1300/2502 ( 52%)] Loss: 6.930165 (6.9500) Time: 0.333s, 1539.73/s (0.321s, 1595.62/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1350/2502 ( 54%)] Loss: 6.918827 (6.9488) Time: 0.331s, 1544.88/s (0.321s, 1596.13/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1400/2502 ( 56%)] Loss: 6.923580 (6.9480) Time: 0.311s, 1644.41/s (0.321s, 1596.67/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [1450/2502 ( 58%)] Loss: 6.924307 (6.9472) Time: 0.333s, 1538.95/s (0.321s, 1597.32/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1500/2502 ( 60%)] Loss: 6.909927 (6.9460) Time: 0.309s, 1659.58/s (0.320s, 1597.74/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1550/2502 ( 62%)] Loss: 6.924455 (6.9453) Time: 0.339s, 1512.00/s (0.320s, 1598.03/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1600/2502 ( 64%)] Loss: 6.931414 (6.9449) Time: 0.315s, 1623.24/s (0.320s, 1598.55/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [1650/2502 ( 66%)] Loss: 6.916759 (6.9441) Time: 0.332s, 1542.18/s (0.320s, 1599.07/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1700/2502 ( 68%)] Loss: 6.941891 (6.9440) Time: 0.314s, 1632.83/s (0.320s, 1599.53/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [1750/2502 ( 70%)] Loss: 6.922241 (6.9434) Time: 0.312s, 1640.83/s (0.320s, 1599.91/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1800/2502 ( 72%)] Loss: 6.918221 (6.9427) Time: 0.315s, 1625.92/s (0.320s, 1600.40/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1850/2502 ( 74%)] Loss: 6.903537 (6.9417) Time: 0.322s, 1587.80/s (0.320s, 1600.59/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1900/2502 ( 76%)] Loss: 6.934650 (6.9415) Time: 0.315s, 1623.17/s (0.320s, 1601.00/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1950/2502 ( 78%)] Loss: 6.916628 (6.9409) Time: 0.315s, 1625.91/s (0.320s, 1601.38/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [2000/2502 ( 80%)] Loss: 6.907085 (6.9401) Time: 0.302s, 1695.00/s (0.320s, 1601.57/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [2050/2502 ( 82%)] Loss: 6.915219 (6.9395) Time: 0.331s, 1547.05/s (0.320s, 1601.70/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [2100/2502 ( 84%)] Loss: 6.920197 (6.9390) Time: 0.337s, 1520.82/s (0.320s, 1601.97/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [2150/2502 ( 86%)] Loss: 6.924037 (6.9387) Time: 0.325s, 1574.30/s (0.320s, 1602.26/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [2200/2502 ( 88%)] Loss: 6.920416 (6.9383) Time: 0.300s, 1705.11/s (0.319s, 1602.63/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [2250/2502 ( 90%)] Loss: 6.898316 (6.9374) Time: 0.310s, 1649.44/s (0.319s, 1602.97/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [2300/2502 ( 92%)] Loss: 6.924686 (6.9371) Time: 0.309s, 1655.87/s (0.319s, 1602.88/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [2350/2502 ( 94%)] Loss: 6.907205 (6.9365) Time: 0.326s, 1572.94/s (0.319s, 1602.90/s) LR: 1.000e-06 Data: 0.005 (0.007)
    /home/shawn/anaconda3/envs/deit/lib/python3.8/site-packages/PIL/TiffImagePlugin.py:788: UserWarning: Corrupt EXIF data. Expecting to read 4 bytes but only got 0.
      warnings.warn(str(msg))
    Train: 0 [2400/2502 ( 96%)] Loss: 6.908824 (6.9359) Time: 0.310s, 1652.27/s (0.319s, 1603.15/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [2450/2502 ( 98%)] Loss: 6.911987 (6.9355) Time: 0.317s, 1615.97/s (0.319s, 1603.37/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [2500/2502 (100%)] Loss: 6.918730 (6.9351) Time: 0.312s, 1641.96/s (0.319s, 1603.78/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [2501/2502 (100%)] Loss: 6.918357 (6.9348) Time: 0.644s, 795.44/s (0.319s, 1603.13/s) LR: 1.000e-06 Data: 0.344 (0.007)
    Test: [ 0/97] Time: 1.865 (1.865) Loss: 6.8164 (6.8164) Acc@1: 0.0000 ( 0.0000) Acc@5: 0.0000 ( 0.0000)
    Test: [ 50/97] Time: 0.100 (0.192) Loss: 6.8828 (6.8914) Acc@1: 0.0000 ( 0.0613) Acc@5: 0.0000 ( 0.5859)
    Test: [ 97/97] Time: 0.220 (0.162) Loss: 6.7188 (6.8880) Acc@1: 0.0000 ( 0.1820) Acc@5: 0.0000 ( 0.9180)
    Test (EMA): [ 0/97] Time: 2.051 (2.051) Loss: 7.0312 (7.0312) Acc@1: 0.0000 ( 0.0000) Acc@5: 1.1719 ( 1.1719)
    Test (EMA): [ 50/97] Time: 0.109 (0.193) Loss: 6.9570 (6.9737) Acc@1: 0.0000 ( 0.1072) Acc@5: 0.0000 ( 0.5093)
    Test (EMA): [ 97/97] Time: 0.224 (0.163) Loss: 7.0273 (6.9708) Acc@1: 0.0000 ( 0.0900) Acc@5: 0.0000 ( 0.5080)
    Current checkpoints: ('./output/train/20210219-222319-T2t_vit_14-224/checkpoint-0.pth.tar', 0.09)

    Train: 1 [ 0/2502 ( 0%)] Loss: 6.897799 (6.8978) Time: 2.695s, 189.97/s (2.695s, 189.97/s) LR: 1.673e-04 Data: 2.323 (2.323)
    Train: 1 [ 50/2502 ( 2%)] Loss: nan ( nan) Time: 0.279s, 1834.73/s (0.337s, 1518.12/s) LR: 1.673e-04 Data: 0.005 (0.051)
    Train: 1 [ 100/2502 ( 4%)] Loss: nan ( nan) Time: 0.276s, 1857.70/s (0.309s, 1655.29/s) LR: 1.673e-04 Data: 0.006 (0.029)
    Train: 1 [ 150/2502 ( 6%)] Loss: nan ( nan) Time: 0.289s, 1773.38/s (0.300s, 1705.98/s) LR: 1.673e-04 Data: 0.007 (0.021)
    Train: 1 [ 200/2502 ( 8%)] Loss: nan ( nan) Time: 0.273s, 1877.76/s (0.295s, 1733.59/s) LR: 1.673e-04 Data: 0.005 (0.018)
    Train: 1 [ 250/2502 ( 10%)] Loss: nan ( nan) Time: 0.268s, 1912.76/s (0.292s, 1752.17/s) LR: 1.673e-04 Data: 0.005 (0.015)
    Train: 1 [ 300/2502 ( 12%)] Loss: nan ( nan) Time: 0.285s, 1793.85/s (0.290s, 1764.29/s) LR: 1.673e-04 Data: 0.005 (0.014)
    Train: 1 [ 350/2502 ( 14%)] Loss: nan ( nan) Time: 0.281s, 1819.69/s (0.289s, 1769.46/s) LR: 1.673e-04 Data: 0.006 (0.013)
    Train: 1 [ 400/2502 ( 16%)] Loss: nan ( nan) Time: 0.268s, 1908.61/s (0.290s, 1767.59/s) LR: 1.673e-04 Data: 0.005 (0.012)
    Train: 1 [ 450/2502 ( 18%)] Loss: nan ( nan) Time: 0.287s, 1783.58/s (0.289s, 1773.71/s) LR: 1.673e-04 Data: 0.006 (0.011)
    Train: 1 [ 500/2502 ( 20%)] Loss: nan ( nan) Time: 0.285s, 1796.56/s (0.288s, 1778.22/s) LR: 1.673e-04 Data: 0.005 (0.011)
    Train: 1 [ 550/2502 ( 22%)] Loss: nan ( nan) Time: 0.280s, 1825.68/s (0.287s, 1781.91/s) LR: 1.673e-04 Data: 0.005 (0.010)
    Train: 1 [ 600/2502 ( 24%)] Loss: nan ( nan) Time: 0.275s, 1859.97/s (0.287s, 1785.50/s) LR: 1.673e-04 Data: 0.009 (0.010)
    Train: 1 [ 650/2502 ( 26%)] Loss: nan ( nan) Time: 0.278s, 1841.99/s (0.286s, 1788.40/s) LR: 1.673e-04 Data: 0.005 (0.010)
    Train: 1 [ 700/2502 ( 28%)] Loss: nan ( nan) Time: 0.275s, 1860.43/s (0.286s, 1790.68/s) LR: 1.673e-04 Data: 0.006 (0.009)
    Train: 1 [ 750/2502 ( 30%)] Loss: nan ( nan) Time: 0.287s, 1784.59/s (0.286s, 1792.93/s) LR: 1.673e-04 Data: 0.006 (0.009)
    Train: 1 [ 800/2502 ( 32%)] Loss: nan ( nan) Time: 0.277s, 1848.72/s (0.285s, 1794.68/s) LR: 1.673e-04 Data: 0.006 (0.009)
    Train: 1 [ 850/2502 ( 34%)] Loss: nan ( nan) Time: 0.286s, 1792.44/s (0.285s, 1795.76/s) LR: 1.673e-04 Data: 0.006 (0.009)
    Train: 1 [ 900/2502 ( 36%)] Loss: nan ( nan) Time: 0.279s, 1833.06/s (0.285s, 1795.15/s) LR: 1.673e-04 Data: 0.006 (0.008)
    Train: 1 [ 950/2502 ( 38%)] Loss: nan ( nan) Time: 0.277s, 1847.88/s (0.285s, 1795.23/s) LR: 1.673e-04 Data: 0.005 (0.008)
    Train: 1 [1000/2502 ( 40%)] Loss: nan ( nan) Time: 0.286s, 1789.41/s (0.285s, 1796.69/s) LR: 1.673e-04 Data: 0.005 (0.008)
    Train: 1 [1050/2502 ( 42%)] Loss: nan ( nan) Time: 0.277s, 1848.11/s (0.285s, 1798.21/s) LR: 1.673e-04 Data: 0.005 (0.008)
    Train: 1 [1100/2502 ( 44%)] Loss: nan ( nan) Time: 0.284s, 1799.80/s (0.285s, 1799.40/s) LR: 1.673e-04 Data: 0.005 (0.008)
    Train: 1 [1150/2502 ( 46%)] Loss: nan ( nan) Time: 0.285s, 1799.56/s (0.284s, 1800.19/s) LR: 1.673e-04 Data: 0.006 (0.008)
    Train: 1 [1200/2502 ( 48%)] Loss: nan ( nan) Time: 0.294s, 1742.39/s (0.284s, 1801.04/s) LR: 1.673e-04 Data: 0.006 (0.008)
    Train: 1 [1250/2502 ( 50%)] Loss: nan ( nan) Time: 0.285s, 1796.71/s (0.284s, 1802.07/s) LR: 1.673e-04 Data: 0.005 (0.008)
    Train: 1 [1300/2502 ( 52%)] Loss: nan ( nan) Time: 0.274s, 1870.25/s (0.284s, 1802.85/s) LR: 1.673e-04 Data: 0.006 (0.008)
    Train: 1 [1350/2502 ( 54%)] Loss: nan ( nan) Time: 0.271s, 1886.95/s (0.284s, 1803.84/s) LR: 1.673e-04 Data: 0.006 (0.008)
    Train: 1 [1400/2502 ( 56%)] Loss: nan ( nan) Time: 0.288s, 1776.96/s (0.284s, 1804.18/s) LR: 1.673e-04 Data: 0.006 (0.008)
    Train: 1 [1450/2502 ( 58%)] Loss: nan ( nan) Time: 0.282s, 1818.29/s (0.284s, 1802.31/s) LR: 1.673e-04 Data: 0.006 (0.007)
    Train: 1 [1500/2502 ( 60%)] Loss: nan ( nan) Time: 0.262s, 1952.51/s (0.284s, 1803.01/s) LR: 1.673e-04 Data: 0.007 (0.007)

    opened by yix081 8
  • Cannot get the reported MACs in paper

    Hi,

    I've calculated the MACs of the model and found that they are not consistent with those reported in the paper.

    If I understand correctly, the T2T-ViTt-14 model has the T2T module plus 14 original ViT blocks. The MACs for those 14 ViT blocks would be 0.321 x 14 = 4.494 G.

    For the first tokens-to-token attention, you compute attention over 56x56 = 3136 tokens with feature dim = 64. Counting only the affinity matrix and the aggregation of the values gives 3136 * 3136 * 64 + 3136 * 3136 * 64 = 1.26 G MACs, which already brings the total to 5.754 G, higher than the reported 5.2G. My full calculation for the T2T-ViTt-14 model gives 6.09 G MACs. Could you tell me whether I miscalculated something?

    Best, Haiping

    opened by happywu 5
  • How to visualize the attention map of t2t-vit?

    The referenced example file uses ViT, which outputs attention weights. But the T2T-ViT model only outputs logits, so I can't reuse that code. I would really like a more detailed way to visualize the attention maps of T2T-ViT; it's really important to me. Thanks a lot!

    opened by Salen158 4
  • Hard to train

    Hi, dear @yuanli2333. I am trying to use T2T-ViT for downstream semantic segmentation tasks. However, as we know, ViT backbones are very hard to train, and the default number of training epochs on ImageNet is 300. I have tried two different network structures with T2T-ViT-14. The first is trained with the SGD optimizer and cosine warmup; after 120 epochs, the loss curves are as follows:

    [Figure: loss curves of the SGD run after 120 epochs]

    The second is trained with the Adam optimizer and cosine warmup (I don't use timm.create_optimizer to set AdamW, since I need different learning rates for different blocks). The learning-rate setting is similar to yours. After 40 epochs, the loss curves are as follows:

    [Figure: loss curves of the Adam run after 40 epochs]

    It looks like the second run is much better and its loss is still decreasing, but I'm not sure it is on the right path. (By my calculation, training 300 epochs on a single 3090 GPU will take 6 days, so I don't have time for trial and error.) Could you show me your training log as a reference or give me some advice? Thank you very much.

    opened by huixiancheng 3
  • small question about lr_scheduler

    Thanks for the open-source code! Could you tell me the meaning of the metric passed to the lr step? https://github.com/yitu-opensource/T2T-ViT/blob/f436fe4043069989ec5e0c2d07407b6d898493a7/main.py#L577-L579 https://github.com/yitu-opensource/T2T-ViT/blob/f436fe4043069989ec5e0c2d07407b6d898493a7/main.py#L688-L689

    As I understand it, it doesn't seem to have any special meaning in timm.

    opened by huixiancheng 3
  • Questions about feature visualization of vit_large_patch16_384

    For Figure 2 in the paper, I tried to plot the feature visualization of T2T-ViT-24 trained on ImageNet using the code provided in visualization_vit.ipynb and the same input image “dog.png”. The input image was resized to (1024, 1024), and I found that the feature maps have a size of (64, 64). However, the plotted feature maps look very different from those in your paper. The following figure shows my feature maps from T2T-ViT-24 block 1:

    [Figure: feature maps from T2T-ViT-24 block 1 (layer_0)]

    There is a lot of noise in my feature maps, and low-level structural features such as edges and lines are not clear. I'm not sure what caused the discrepancy. Also, the resolution of the feature maps in the paper looks higher than 64x64. Could you please provide more instructions on feature visualization for this model? That would help me understand your work better. Thank you in advance!

    opened by Hongyu-He 3
  • Very low performance within the first 10 epochs

    @yuanli2333 Really impressive results with fully transformer architecture!

    I have tried to reproduce the results of T2t_vit_t_14, T2t_vit_t_19, and T2t_vit_t_24 while finding their top-1 accuracy is very low within the first few epochs:

    # results based on T2t_vit_t_14
    epoch,train_loss,eval_loss,eval_top1,eval_top5
    0,6.9310056154544535,6.97243375,0.09599999996185303,0.4740000003051758
    1,6.575615681134737,6.9578225,0.104,0.512
    2,6.16587602175199,6.94662625,0.116,0.5079999998855591
    3,5.808463848554171,6.9398025,0.114,0.572
    4,5.42472545000223,6.93104625,0.156,0.632
    5,5.137583054029024,6.92024125,0.142,0.77
    6,4.931810901715205,6.90476,0.194,0.8419999999809266
    7,4.7973018517861,6.874345,0.246,1.044
    8,4.646611140324519,6.82100625,0.358,1.554
    

    where we can see that the top-1 accuracy is only 0.358 at the 8th epoch. I am wondering whether this result is reasonable.

    Thanks!

    opened by PkuRainBow 3
  • The released models seem broken and cannot be opened?

    Thanks a lot for sharing your codes. But it seems the T2T-ViT Models are broken and I can't open them. Would you like to upload them again? Many thanks :-)

    opened by QiushiYang 3
  • No forward with  Token_performer?

    I tried to run T2T-ViT but ran into an error: Token_performer has no forward method. Will you provide the forward code for Token_performer? Hoping for your reply.

    opened by GuideWsp 3
  • could you share the log of training T2T-ViT_t-24 and T2T-ViT_t-19?

    Thanks for your wonderful work, and I wonder whether you could share the log of training T2T-ViT_t-24 and T2T-ViT_t-19? As I want to compare my method with yours from the training point. Many thanks.

    opened by wangpichao 2
  • Input size is not square: what should I change in this line?

    Thanks for your excellent work! I have a tensor of size (3, 56, 112). What should I change in this line? self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2)) # there are 3 sfot split, stride are 4,2,2 seperately

    opened by JerryKingQAQ 0
  • colab

    Hi, I want to train your model on my dataset, which consists of 15,000 sample training images and 120,000 training images. I also want to use Google Colab: is it possible to train this model in Colab, and how do I run your GitHub code there? Is it better to train the model on the sample training images or on the full training images?

    opened by Maryam-Hosseini 0
  • Downloads all say "tar: This does not look like a tar archive" when I try to un-tar

    I've tried several tools, and downloaded all the files just to see if maybe the first one was corrupt, but every file I try I get this error (from multiple archive utilities including classic tar).

    Anyone else seeing this, or do I just have a problem on my end?

    opened by lowfuel 2
  • [maybe a bug] loss nan

    https://github.com/yitu-opensource/T2T-ViT/blob/main/models/token_performer.py#L18 My code has fp16 turned on, so the 1e-8 on this line, which is meant to prevent division by zero, is not enough for my code: the loss of the network becomes NaN because of this code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/token_performer.py#L50

    opened by xmy0916 4
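The fp16 point raised in the "[maybe a bug] loss nan" comment is easy to check: an epsilon of 1e-8 is below the smallest positive float16 value (2^-24, about 6e-8), so it rounds to exactly 0 in half precision and no longer guards the division. The sketch below round-trips a few candidate epsilons through IEEE 754 half precision using the standard struct module ('e' format); the helper name to_float16 is hypothetical.

```python
import struct


def to_float16(x):
    # Round-trip a Python float through IEEE 754 binary16 (struct format 'e').
    return struct.unpack("<e", struct.pack("<e", x))[0]


for eps in (1e-8, 1e-6, 1e-4):
    # 1e-8 underflows to 0.0; the larger values survive (possibly as subnormals).
    print(f"{eps:g} -> {to_float16(eps)!r}")
```

Typical workarounds are a larger epsilon or computing that normalization in float32; these are general suggestions, not a fix confirmed by the authors.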
Owner
YITUTech

György Kovács 0 Dec 24, 2021
Membership Inference Attack against Graph Neural Networks

MIA GNN Project Starter If you meet the version mismatch error for Lasagne library, please use following command to upgrade Lasagne library. pip insta

6 Nov 09, 2022
[CVPR 2022] Semi-Supervised Semantic Segmentation Using Unreliable Pseudo-Labels

Using Unreliable Pseudo Labels Official PyTorch implementation of Semi-Supervised Semantic Segmentation Using Unreliable Pseudo Labels, CVPR 2022. Ple

Haochen Wang 268 Dec 24, 2022
HMLET (Hybrid-Method-of-Linear-and-non-linEar-collaborative-filTering-method)

Methods HMLET (Hybrid-Method-of-Linear-and-non-linEar-collaborative-filTering-method) Dynamically selecting the best propagation method for each node

Yong 7 Dec 18, 2022
Nonuniform-to-Uniform Quantization: Towards Accurate Quantization via Generalized Straight-Through Estimation. In CVPR 2022.

Nonuniform-to-Uniform Quantization This repository contains the training code of N2UQ introduced in our CVPR 2022 paper: "Nonuniform-to-Uniform Quanti

Zechun Liu 60 Dec 28, 2022
(CVPR 2022 Oral) Official implementation for "Surface Representation for Point Clouds"

RepSurf - Surface Representation for Point Clouds [CVPR 2022 Oral] By Haoxi Ran* , Jun Liu, Chengjie Wang ( * : corresponding contact) The pytorch off

Haoxi Ran 264 Dec 23, 2022
Fiddle is a Python-first configuration library particularly well suited to ML applications.

Fiddle Fiddle is a Python-first configuration library particularly well suited to ML applications. Fiddle enables deep configurability of parameters i

Google 227 Dec 26, 2022
A curated list of awesome deep long-tailed learning resources.

A curated list of awesome deep long-tailed learning resources.

vanint 210 Dec 25, 2022
Image Processing, Image Smoothing, Edge Detection and Transforms

opevcvdl-hw1 This project uses openCV and Qt to achieve the requirements. Version Python 3.7 opencv-contrib-python 3.4.2.17 Matplotlib 3.1.1 pyqt5 5.1

Kenny Cheng 3 Aug 17, 2022
POT : Python Optimal Transport

POT: Python Optimal Transport This open source Python library provide several solvers for optimization problems related to Optimal Transport for signa

Python Optimal Transport 1.7k Dec 31, 2022
Implementation of the Triangle Multiplicative module, used in Alphafold2 as an efficient way to mix rows or columns of a 2d feature map, as a standalone package for Pytorch

Triangle Multiplicative Module - Pytorch Implementation of the Triangle Multiplicative module, used in Alphafold2 as an efficient way to mix rows or c

Phil Wang 22 Oct 28, 2022
Img-process-manual - Utilize Python Numpy and Matplotlib to realize OpenCV baisc image processing function

Img-process-manual - Opencv Library basic graphic processing algorithm coding reproduction based on Numpy and Matplotlib library

Jack_Shaw 2 Dec 12, 2022