Paddle PiT - Rethinking Spatial Dimensions of Vision Transformers

Overview

A Paddle implementation of PiT: Rethinking Spatial Dimensions of Vision Transformers (arXiv).

  • Official original code (PyTorch): pit.

  • This project builds on PaddleViT, aligns it further with the original code, and reproduces the pit_ti model through a complete training and evaluation run.

1. Introduction

Starting from the design principles behind successful CNNs, the authors study the role of spatial-dimension transformations and their effectiveness in Transformer-based architectures.

Specifically, mirroring the dimension-reduction principle of CNNs (a conventional CNN increases the channel dimension and shrinks the spatial dimensions as depth grows), the authors show experimentally that the same principle also improves Transformer performance, and propose the pooling-based Vision Transformer, PiT (schematic below).

[Figure: PiT model schematic]
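The pooling layer at each stage boundary can be sketched in Paddle as follows. This is a minimal illustration of the idea rather than the exact code in pit.py: a depthwise strided convolution shrinks the spatial token grid while widening the channels, and a separate Linear layer widens the class token, which has no spatial extent.

import paddle
import paddle.nn as nn

class PoolingSketch(nn.Layer):
    """Stage-boundary pooling in the spirit of PiT (illustrative)."""
    def __init__(self, in_dim, out_dim, stride=2):
        super().__init__()
        # Depthwise strided conv: shrinks H and W by `stride` while
        # projecting channels from in_dim to out_dim (out_dim must be
        # a multiple of in_dim for the grouped conv; PiT doubles it).
        self.conv = nn.Conv2D(in_dim, out_dim, kernel_size=stride + 1,
                              stride=stride, padding=stride // 2,
                              groups=in_dim)
        # The class token has no spatial dims, so it is widened
        # with a plain Linear layer instead.
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x, cls_token):
        # x: [B, in_dim, H, W] spatial tokens; cls_token: [B, 1, in_dim]
        x = self.conv(x)                # -> [B, out_dim, H/stride, W/stride]
        cls_token = self.fc(cls_token)  # -> [B, 1, out_dim]
        return x, cls_token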

2. Dataset and Reproduced Accuracy

Dataset

The paper uses ImageNet-1k 2012 (ILSVRC2012): 1000 classes, with a training/validation split of 1,281,167/50,000 images and a total size of 144 GB.

This project instead uses the officially recommended Light_ILSVRC2012, a lighter variant with compressed images (65 GB). On AI Studio it is available as Light_ILSVRC2012_part_0.tar and Light_ILSVRC2012_part_1.tar.
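Assuming the two parts are ordinary tar archives that unpack into the train/val layout expected by datasets.py (verify the archive contents first), extraction would look roughly like this:

mkdir -p datasets/ImageNet1K
tar -xf Light_ILSVRC2012_part_0.tar -C datasets/ImageNet1K
tar -xf Light_ILSVRC2012_part_1.tar -C datasets/ImageNet1K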

Reproduced accuracy

| Model  | Target Acc@1 | Achieved Acc@1 | Image Size | batch_size  | Crop_pct | Epochs             | #Params |
|--------|--------------|----------------|------------|-------------|----------|--------------------|---------|
| pit_ti | 73.0         | 73.01          | 224        | 256 × 4 GPUs | 0.9      | 300 (+10 cooldown) | 4.8M    |

Note: the accuracy reported above was measured on the original ILSVRC2012 validation set. Notably, this project also reaches a validation Acc@1 of 73.17 on the Light_ILSVRC2012 validation set.

The best model weights and the training logs from this project are stored in the output folder.

Log files

This project was run as AI Studio script tasks and was interrupted 4 times, so there are 5 training log files. For easier review, I manually renamed the logs to the format log_<start epoch>-<end epoch>.txt. Specifically:

  • output/log_1-76.txt: epoch 1 to epoch 76. This version of the code saved the model weights every 10 epochs, ran validation every 2 epochs, and saved the weights as Best_PiT.pdparams whenever validation accuracy beat the previous best (the scheme is sketched after this list). The run was interrupted after epoch 76 had finished training but before validation, so the next run had to resume from epoch 74, the epoch with the best validation accuracy.

  • output/log_75-142.txt: epoch 75 to epoch 142. Starting with this version of the code, the weights are additionally saved as PiT-Latest.pdparams after every epoch, so training can resume no matter which epoch is interrupted.

  • output/log_143-225.txt: epoch 143 to epoch 225.

  • output/log_226-303.txt: epoch 226 to epoch 303.

  • output/log_304-310.txt: epoch 304 to epoch 310.

  • output/log_eval.txt: log of evaluating the best model (from epoch 308) on the original ILSVRC2012 validation set.
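The saving scheme described above looks roughly like the following sketch. The names train_one_epoch, validate, and the periodic snapshot filename are illustrative stand-ins, not the exact code in main_multi_gpu.py:

import paddle

best_acc = 0.0
for epoch in range(1, last_epoch + 1):
    train_one_epoch(model, train_loader)          # hypothetical helper
    # Since log_75-142.txt: always keep a rolling checkpoint so an
    # interrupted run can resume from any epoch.
    paddle.save(model.state_dict(), 'output/PiT-Latest.pdparams')
    if epoch % 10 == 0:                           # periodic snapshot
        paddle.save(model.state_dict(), f'output/PiT-Epoch-{epoch}.pdparams')
    if epoch % 2 == 0:                            # validate every 2 epochs
        acc = validate(model, val_loader)         # hypothetical helper
        if acc > best_acc:
            best_acc = acc
            paddle.save(model.state_dict(), 'output/Best_PiT.pdparams')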

3. Environment Setup

Recommended environment configuration (the setup used for this project):

  • Hardware: Tesla V100 × 4 (sincere thanks to the Baidu PaddlePaddle platform for providing the compute)

  • PaddlePaddle==2.2.1

  • Python==3.7

4. Quick Start

This project is deployed as a script task on AI Studio; you can fork it and simply run sh run.sh, with dataset preparation scripts already wired up. Link: paddle_pit

Alternatively, you can clone this repo and run it locally:

Step 1: Clone this project

git clone https://github.com/hatimwen/paddle_pit.git
cd paddle_pit

Step 2: Adjust parameters

Edit the scripts under scripts/ to fit your setup (e.g. gpu, the dataset path data_path, batch_size); an illustrative excerpt follows.
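For reference, a multi-GPU training script in the PaddleViT style typically looks like the excerpt below. Treat this as an illustration; the exact flags in scripts/run_train_multi.sh may differ:

python -m paddle.distributed.launch --gpus='0,1,2,3' \
    main_multi_gpu.py \
    -cfg='configs/pit_ti.yaml' \
    -data_path='datasets/ImageNet1K' \
    -batch_size=256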

Step 3: Evaluate the model

For multiple GPUs, run:

sh scripts/run_eval_multi.sh

For a single GPU, run:

sh scripts/run_eval.sh

Step 4: Train the model

For multiple GPUs, run:

sh scripts/run_train_multi.sh

For a single GPU, run:

sh scripts/run_train.sh

Step 5: Verify a prediction

python predict.py \
-pretrained='output/Best_PiT' \
-img_path='images/ILSVRC2012_val_00004506.JPEG'

Test image (class: Tibetan mastiff, id: 244)

The output is:

class_id: 244, prob: 9.12291145324707

Checking id 244 against the ImageNet class list (the mapping from ImageNet ids to class names), 244 is indeed Tibetan mastiff, so the prediction is correct.
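Note that the printed prob is larger than 1, so it is a raw score rather than a softmax probability. To get a normalized probability, post-processing along these lines works (a sketch; logits stands for the model's [1, 1000] output tensor, and predict.py's internals may differ):

import paddle
import paddle.nn.functional as F

probs = F.softmax(logits, axis=-1)               # normalize over 1000 classes
class_id = int(paddle.argmax(probs, axis=-1))
print(f'class_id: {class_id}, prob: {float(probs[0, class_id]):.4f}')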

5. Code Structure

|-- paddle_pit
    |-- output              # logs and model weights
    |-- configs             # configuration files
        |-- pit_ti.yaml
    |-- datasets
        |-- ImageNet1K      # dataset path
    |-- scripts             # run scripts
        |-- run_train.sh
        |-- run_train_multi.sh
        |-- run_eval.sh
        |-- run_eval_multi.sh
    |-- augment.py          # data augmentation
    |-- config.py           # lowest-level configuration
    |-- datasets.py         # dataset and dataloader
    |-- droppath.py         # DropPath definition
    |-- losses.py           # loss definitions
    |-- main_multi_gpu.py   # multi-GPU training/evaluation code
    |-- main_single_gpu.py  # single-GPU training/evaluation code
    |-- mixup.py            # Mixup definition
    |-- model_ema.py        # EMA definition
    |-- pit.py              # PiT model definition
    |-- random_erasing.py   # RandomErasing definition
    |-- regnet.py           # teacher model definition (not validated in this project; kept for reference)
    |-- transforms.py       # RandomHorizontalFlip definition
    |-- utils.py            # CosineLRScheduler and AverageMeter definitions
    |-- README.md
    |-- requirements.txt

6. References and Citation

@InProceedings{Heo_2021_ICCV,
    author    = {Heo, Byeongho and Yun, Sangdoo and Han, Dongyoon and Chun, Sanghyuk and Choe, Junsuk and Oh, Seong Joon},
    title     = {Rethinking Spatial Dimensions of Vision Transformers},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021}
}

Finally, sincere thanks to Baidu for hosting the PaddlePaddle Paper Reproduction Challenge (5th edition), which deepened my understanding of Paddle. Many thanks as well to Dr. Zhu Yu's team for their Paddle implementation, PaddleViT: most of the code in this project is copied from it, and this project merely adds closer alignment with the original training procedure plus a complete training run; I benefited a great deal! ♥️

Contact

Hongtao Wen
The "breathing k-means" algorithm with datasets and example notebooks

The Breathing K-Means Algorithm (with examples) The Breathing K-Means is an approximation algorithm for the k-means problem that (on average) is bette

Bernd Fritzke 75 Nov 17, 2022
This repository contains the implementation of the following paper: Cross-Descriptor Visual Localization and Mapping

Cross-Descriptor Visual Localization and Mapping This repository contains the implementation of the following paper: "Cross-Descriptor Visual Localiza

Mihai Dusmanu 81 Oct 06, 2022
This repository contains implementations and illustrative code to accompany DeepMind publications

DeepMind Research This repository contains implementations and illustrative code to accompany DeepMind publications. Along with publishing papers to a

DeepMind 11.3k Dec 31, 2022
A research toolkit for particle swarm optimization in Python

PySwarms is an extensible research toolkit for particle swarm optimization (PSO) in Python. It is intended for swarm intelligence researchers, practit

Lj Miranda 1k Dec 30, 2022
Pytorch implementation of Implicit Behavior Cloning.

Implicit Behavior Cloning - PyTorch (wip) Pytorch implementation of Implicit Behavior Cloning. Install conda create -n ibc python=3.8 pip install -r r

Kevin Zakka 49 Dec 25, 2022
A Lightweight Face Recognition and Facial Attribute Analysis (Age, Gender, Emotion and Race) Library for Python

deepface Deepface is a lightweight face recognition and facial attribute analysis (age, gender, emotion and race) framework for python. It is a hybrid

Sefik Ilkin Serengil 5.2k Jan 02, 2023
Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax

Clockwork VAEs in JAX/Flax Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax, ported

Julius Kunze 26 Oct 05, 2022
Proximal Backpropagation - a neural network training algorithm that takes implicit instead of explicit gradient steps

Proximal Backpropagation Proximal Backpropagation (ProxProp) is a neural network training algorithm that takes implicit instead of explicit gradient s

Thomas Frerix 40 Dec 17, 2022
Let Python optimize the best stop loss and take profits for your TradingView strategy.

TradingView Machine Learning TradeView is a free and open source Trading View bot written in Python. It is designed to support all major exchanges. It

Robert Roman 473 Jan 09, 2023
Laplacian Score-regularized Concrete Autoencoders

Laplacian Score-regularized Concrete Autoencoders Requirements: torch = 1.9 scikit-learn = 0.24 omegaconf = 2.0.6 scipy = 1.6.0 matplotlib How to

JS 6 Dec 07, 2022
Material related to the Principles of Cloud Computing course.

CloudComputingCourse Material related to the Principles of Cloud Computing course. This repository comprises material that I use to teach my Principle

Aniruddha Gokhale 15 Dec 02, 2022
U-Net Brain Tumor Segmentation

U-Net Brain Tumor Segmentation 🚀 :Feb 2019 the data processing implementation in this repo is not the fastest way (code need update, contribution is

Hao 448 Jan 02, 2023
IOT: Instance-wise Layer Reordering for Transformer Structures

Introduction This repository contains the code for Instance-wise Ordered Transformer (IOT), which is introduced in the ICLR2021 paper IOT: Instance-wi

IOT 19 Nov 15, 2022
Koç University deep learning framework.

Knet Knet (pronounced "kay-net") is the Koç University deep learning framework implemented in Julia by Deniz Yuret and collaborators. It supports GPU

1.4k Dec 31, 2022
OpenMMLab Computer Vision Foundation

English | 简体中文 Introduction MMCV is a foundational library for computer vision research and supports many research projects as below: MMCV: OpenMMLab

OpenMMLab 4.6k Jan 09, 2023
PyTorch ,ONNX and TensorRT implementation of YOLOv4

PyTorch ,ONNX and TensorRT implementation of YOLOv4

4.2k Jan 01, 2023
A curated list and survey of awesome Vision Transformers.

English | 简体中文 A curated list and survey of awesome Vision Transformers. You can use mind mapping software to open the mind mapping source file. You c

OpenMMLab 281 Dec 21, 2022
OpenL3: Open-source deep audio and image embeddings

OpenL3 OpenL3 is an open-source Python library for computing deep audio and image embeddings. Please refer to the documentation for detailed instructi

Music and Audio Research Laboratory - NYU 326 Jan 02, 2023
VR-Caps: A Virtual Environment for Active Capsule Endoscopy

VR-Caps: A Virtual Environment for Capsule Endoscopy Overview We introduce a virtual active capsule endoscopy environment developed in Unity that prov

DeepMIA Lab 90 Dec 27, 2022
For medical image segmentation

LeViT_UNet For medical image segmentation Our model is based on LeViT (https://github.com/facebookresearch/LeViT). You'd better gitclone its codes. Th

13 Dec 24, 2022