PoolFormer: MetaFormer is Actually What You Need for Vision

Overview

PoolFormer: MetaFormer is Actually What You Need for Vision (arXiv)

This is a PyTorch implementation of PoolFormer, proposed in our paper "MetaFormer is Actually What You Need for Vision".

Figure 1: MetaFormer and the performance of MetaFormer-based models on the ImageNet-1K validation set. We argue that the competence of transformer/MLP-like models primarily stems from the general architecture MetaFormer rather than from the specific token mixers they are equipped with. To demonstrate this, we use an embarrassingly simple non-parametric operator, pooling, to conduct extremely basic token mixing. Surprisingly, the resulting model, PoolFormer, consistently outperforms DeiT and ResMLP, as shown in (b), which strongly supports that MetaFormer is actually what we need to achieve competitive performance.
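To make the MetaFormer abstraction concrete, here is a minimal PyTorch sketch of a generic block that receives its token mixer as a constructor argument. It is an illustration under stated assumptions, not the repository's PoolFormerBlock (models/poolformer.py is the authoritative version): the single-group GroupNorm, the 1x1-convolution MLP with GELU, and the per-channel layer scale initialized to 1e-5 are choices made for this sketch. The layer scale multiplies each residual branch by a small learnable per-channel factor, so at initialization the block stays close to the identity.

```python
import torch
import torch.nn as nn


class MetaFormerBlock(nn.Module):
    """Generic MetaFormer block on (B, C, H, W) tensors (illustrative sketch).

    x = x + s1 * TokenMixer(Norm(x));  x = x + s2 * MLP(Norm(x))
    where s1 and s2 are learnable per-channel layer-scale vectors.
    """

    def __init__(self, dim: int, token_mixer: nn.Module, mlp_ratio: float = 4.0,
                 layer_scale_init: float = 1e-5):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        # Assumption: GroupNorm with a single group, i.e. statistics over (C, H, W).
        self.norm1 = nn.GroupNorm(1, dim)
        self.token_mixer = token_mixer  # any shape-preserving module
        self.norm2 = nn.GroupNorm(1, dim)
        # Channel MLP as 1x1 convolutions, so it acts on each token independently.
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, hidden, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(hidden, dim, kernel_size=1),
        )
        # Per-channel layer scale; the small init keeps each branch's contribution tiny at start.
        self.layer_scale_1 = nn.Parameter(layer_scale_init * torch.ones(dim))
        self.layer_scale_2 = nn.Parameter(layer_scale_init * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # view(-1, 1, 1) broadcasts the (C,) scale over the H and W axes.
        x = x + self.layer_scale_1.view(-1, 1, 1) * self.token_mixer(self.norm1(x))
        x = x + self.layer_scale_2.view(-1, 1, 1) * self.mlp(self.norm2(x))
        return x


if __name__ == "__main__":
    x = torch.randn(2, 64, 14, 14)
    # nn.Identity() is the simplest possible token mixer; swap in pooling, attention, etc.
    block = MetaFormerBlock(64, token_mixer=nn.Identity())
    print(block(x).shape)  # torch.Size([2, 64, 14, 14])
```

Swapping only the `token_mixer` argument while keeping everything else fixed is the kind of comparison the MetaFormer argument rests on.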

Figure 2: (a) The overall framework of PoolFormer. (b) The architecture of the PoolFormer block. Compared with the transformer block, it replaces attention with an extremely simple non-parametric operator, pooling, to conduct only basic token mixing.
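The pooling token mixer itself takes only a few lines. The sketch below assumes a (B, C, H, W) layout, stride-1 average pooling with `count_include_pad=False`, and the subtraction of the input that the comments below discuss for Eq. 4; treat it as an illustration rather than the repository's exact code. Note that `pool_size` must be an integer: a bool passed as the kernel size makes `avg_pool2d` raise a `TypeError` like the one quoted in the comments.

```python
import torch
import torch.nn as nn


class Pooling(nn.Module):
    """Pooling token mixer sketch: AvgPool(x) - x on a (B, C, H, W) tensor."""

    def __init__(self, pool_size: int = 3):
        super().__init__()
        # Stride 1 with padding pool_size // 2 keeps H and W unchanged for odd sizes.
        # count_include_pad=False averages border windows over valid pixels only.
        self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2,
                                 count_include_pad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The surrounding block adds x back via its residual connection, so the
        # mixer itself only contributes "local average minus the token itself".
        return self.pool(x) - x


if __name__ == "__main__":
    mixer = Pooling(pool_size=3)
    print(mixer(torch.randn(1, 16, 7, 7)).shape)  # torch.Size([1, 16, 7, 7])
```

One consequence of the subtraction: on a spatially constant input every local average equals the input, so the mixer outputs zeros and the residual path passes the input through unchanged.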

Bibtex

@article{yu2021metaformer,
  title={MetaFormer is Actually What You Need for Vision},
  author={Yu, Weihao and Luo, Mi and Zhou, Pan and Si, Chenyang and Zhou, Yichen and Wang, Xinchao and Feng, Jiashi and Yan, Shuicheng},
  journal={arXiv preprint arXiv:2111.11418},
  year={2021}
}

1. Requirements

For image classification (configs for detection and segmentation will be available soon):

torch>=1.7.0; torchvision>=0.8.0; pyyaml; apex-amp (if you want to use fp16); timm (pip install git+https://github.com/rwightman/pytorch-image-models.git@…)

Data preparation: ImageNet with the following folder structure. You can extract ImageNet with this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
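Before launching validation or training, it can help to sanity-check the layout above. This is a minimal standard-library sketch that assumes exactly the train/<wnid>/ and val/<wnid>/ structure shown; `check_imagenet_layout` is a hypothetical helper, not part of this repo.

```python
from pathlib import Path


def check_imagenet_layout(root) -> int:
    """Check root/train/<wnid>/ and root/val/<wnid>/ and return the class count.

    Raises FileNotFoundError if a split folder is missing and ValueError if the
    two splits do not contain the same set of class folders.
    """
    root = Path(root)
    classes = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing directory: {split_dir}")
        # Each class is a sub-directory named by its WordNet ID, e.g. n01440764.
        classes[split] = {d.name for d in split_dir.iterdir() if d.is_dir()}
    # Symmetric difference: class folders present in one split but not the other.
    mismatched = classes["train"] ^ classes["val"]
    if mismatched:
        raise ValueError(
            f"class folders differ between train/ and val/, e.g. {sorted(mismatched)[:5]}"
        )
    return len(classes["train"])
```

For the full ImageNet-1K dataset, `check_imagenet_layout("/path/to/imagenet")` should return 1000.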

Directory structure in this repo:

│poolformer/
├──misc/
├──models/
│  ├── __init__.py
│  ├── poolformer.py
├──LICENSE
├──README.md
├──distributed_train.sh
├──train.py
├──validate.py

2. PoolFormer Models

Model            #Params  Image resolution  Top-1 Acc (%)  Download
poolformer_s12   12M      224               77.2           here
poolformer_s24   21M      224               80.3           here
poolformer_s36   31M      224               81.4           here
poolformer_m36   56M      224               82.1           here
poolformer_m48   73M      224               82.5           here

All the pretrained models can also be downloaded from Baidu Yun (password: esac).
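The #Params column counts model parameters in millions. A generic PyTorch helper for reproducing such a count is sketched below; `count_params_millions` is hypothetical, not part of the repository.

```python
import torch.nn as nn


def count_params_millions(model: nn.Module, trainable_only: bool = False) -> float:
    """Total number of parameters in millions (optionally only trainable ones)."""
    # Keep every parameter unless trainable_only is set, then keep requires_grad ones.
    params = (p for p in model.parameters() if p.requires_grad or not trainable_only)
    return sum(p.numel() for p in params) / 1e6
```

Applied to a model built from models/poolformer.py, the result is comparable to the rounded values in the table; a toy `nn.Linear(10, 5)` gives 55 parameters, i.e. 5.5e-05 million.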

Updated ResNet scores in the paper

[Figure: updated ResNet scores]

[1] He et al., "Deep Residual Learning for Image Recognition", CVPR 2016.

[2] Wightman et al., "ResNet strikes back: An improved training procedure in timm", arXiv preprint arXiv:2110.00476, 2021.

Usage

We also provide a Colab notebook that runs the steps to perform inference with PoolFormer.

3. Validation

To evaluate our PoolFormer models, run:

MODEL=poolformer_s12 # poolformer_{s12, s24, s36, m36, m48}
python3 validate.py /path/to/imagenet  --model $MODEL \
  --checkpoint /path/to/checkpoint -b 128
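The validation script reports top-1 accuracy, the metric in the table's Top-1 Acc column. The following is a generic sketch of how a top-k accuracy is computed from logits; it is an assumption-level illustration, not the code inside validate.py.

```python
import torch


@torch.no_grad()
def topk_accuracy(logits: torch.Tensor, target: torch.Tensor, k: int = 1) -> float:
    """Percentage of samples whose true label is among the k highest logits."""
    topk = logits.topk(k, dim=1).indices              # (N, k) predicted class indices
    hits = (topk == target.unsqueeze(1)).any(dim=1)   # (N,) True if the label is in the top k
    return hits.float().mean().item() * 100.0
```

Accumulating correct counts over all batches and dividing once at the end gives the same number as averaging per-batch accuracies only when every batch has the same size.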

4. Train

We show how to train PoolFormers on 8 GPUs. The learning rate scales linearly with the total batch size: lr = bs/1024*1e-3. For convenience, we assume a total batch size of 1024 (8 GPUs x 128 images per GPU), so the learning rate is set to 1e-3. For a batch size of 1024, a learning rate of 2e-3 sometimes gives better performance.
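The linear scaling rule above, written as a small helper. The function name and the split into per-GPU batch size and GPU count are assumptions for illustration; `bs` in the rule is the total batch size across all GPUs.

```python
def scaled_lr(per_gpu_batch: int, num_gpus: int,
              base_lr: float = 1e-3, base_batch: int = 1024) -> float:
    """Linear scaling rule: lr = total_batch / 1024 * 1e-3 by default."""
    total_batch = per_gpu_batch * num_gpus
    return total_batch / base_batch * base_lr
```

With the command below (`-b 128` on 8 GPUs) the total batch is 1024 and `scaled_lr(128, 8)` gives 1e-3; halving the per-GPU batch to 64 gives 5e-4. Passing `base_lr=2e-3` reproduces the alternative setting mentioned above.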

MODEL=poolformer_s12 # poolformer_{s12, s24, s36, m36, m48}
DROP_PATH=0.1 # drop path rates [0.1, 0.1, 0.2, 0.3, 0.4] correspond to models [s12, s24, s36, m36, m48]
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet \
  --model $MODEL -b 128 --lr 1e-3 --drop-path $DROP_PATH --apex-amp
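If you launch runs for several models from Python, the per-model drop path rates and the learning-rate rule above can be encoded once. `train_command` and `DROP_PATH` below are hypothetical helpers, not part of the repo; the function only rebuilds the argument list of the command above and does not run anything.

```python
# Drop path (stochastic depth) rate per model, from the comment in the command above.
DROP_PATH = {
    "poolformer_s12": 0.1,
    "poolformer_s24": 0.1,
    "poolformer_s36": 0.2,
    "poolformer_m36": 0.3,
    "poolformer_m48": 0.4,
}


def train_command(model: str, data_dir: str, num_gpus: int = 8,
                  per_gpu_batch: int = 128) -> list:
    """Return the distributed_train.sh argument list for the given model."""
    if model not in DROP_PATH:
        raise KeyError(f"unknown model {model!r}; expected one of {sorted(DROP_PATH)}")
    # Linear scaling rule: lr = total batch size / 1024 * 1e-3.
    lr = per_gpu_batch * num_gpus / 1024 * 1e-3
    return [
        "./distributed_train.sh", str(num_gpus), data_dir,
        "--model", model, "-b", str(per_gpu_batch),
        "--lr", f"{lr:g}", "--drop-path", str(DROP_PATH[model]), "--apex-amp",
    ]
```

The list can be passed to `subprocess.run` with `CUDA_VISIBLE_DEVICES` set in its `env` argument to launch the same job as the shell command.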

5. Acknowledgment

Our implementation is mainly based on the following codebases. We gratefully thank the authors for their wonderful work.

pytorch-image-models, mmdetection, mmsegmentation.

In addition, Weihao Yu would like to thank the TPU Research Cloud (TRC) program for supporting part of the computational resources.

LICENSE

This repo is under the Apache-2.0 license. For commercial use, please contact the authors.

Comments
  • About Normalization

    Hi, thanks for your excellent work. In your ablation studies (section 4.4), you compared Group Normalization (group number is set as 1 for simplicity), Layer Normalization, and Batch Normalization. The conclusion is that Group Normalization is 0.7% or 0.8% higher than Layer Normalization or Batch Normalization. But when the number of groups is 1, Group Normalization is equivalent to Layer Normalization, right?

    opened by tinyalpha 10
  • Addition of the Organization on HuggingFace Transformers

    Hello PoolFormer team!

    I have been working on porting the PoolFormer implementation to the HuggingFace Transformers library (you can see my PR here), and I was wondering whether I can go ahead and add Sea AI Lab as an organization on the HuggingFace models hub.

    This will allow all model checkpoints to be uploaded onto the hub as well as model cards, etc.

    Kind regards, Tanay Mehta

    opened by heytanay 7
  • How to measure MACs?

    Hi, thanks for your nice work :) I also watched the recording of your presentation at the conference.

    I want to apply PoolFormer in my work. May I ask how you measured the MACs of the architectures introduced in your paper? Or, if it is no trouble, could you share your measurement code?

    opened by DoranLyong 5
  • why use use_layer_scale

    Thanks for your great contribution! In the implementation of PoolFormerBlock, there is a layer_scale after token_mixer. What is the effect of this operation?

    opened by rtfgithub 5
  • Invitation of making PR for OpenMMLab / MMSegmentation.

    Hi, first of all, congratulations on the acceptance at CVPR 2022. This great work deserves it.

    I am a member of OpenMMLab and mainly work on developing MMSegmentation. I think that if PoolFormer were officially supported, many more people would use it for benchmarking, which would promote research in computer vision.

    Would you like to make a PR to OpenMMLab? We could discuss together how to refactor your code, and use our own GPUs to train and re-implement it.

    I think it would be pretty cool because it would let more researchers and community members use this excellent work! Here is our re-implementation work: ConvNeXt.

    We do hope PoolFormer can also be added as a backbone in our codebase so that many researchers can use it directly for downstream tasks.

    Looking forward to your reply!

    Best,

    opened by MengzhangLI 5
  • why the speed slower than pvtv2-b1?

    Recently I trained a transformer-based instance segmentation model and tested it with different backbones. Here are the results and the speed test:

    [image: results and speed test]

    The batch size is the training batch size. Why is PoolFormer the slowest one? Is that normal?

    It is slower than pvtv2-b1, and its precision is lower as well...

    opened by jinfagang 5
  • Checkpoints of the Ablation study

    Hi, thanks for your amazing work. I am reading Table 6, and I am surprised that the method is so simple yet so effective, especially when pooling is replaced with identity mapping: 74.3% top-1 on ImageNet-1K with only Conv1x1 and norm layers. I am thrilled... Can you release this checkpoint so that we can verify it? Thanks again. [image]

    opened by chuong98 5
  • Design on positional embedding?

    Hello authors,

    I really appreciate your work, which has inspired the community. I have a very simple and quick question after checking the code and architecture design.

    I observed that for networks using pooling, MLP or identity as the token mixer, you do not include positional embedding; you only use it with MHA. What is the reasoning behind this design, and why do the other models not rely on this embedding?

    Best,

    discussion 
    opened by jizongFox 4
  • Error: About self.pool(x)

    Hello, I am very interested in the PoolFormer you proposed, but an error occurred while using PoolFormerBlock:

        Traceback (most recent call last):
          File "train.py", line 545, in <module>
            train(hyp, opt, device, tb_writer)
          File "train.py", line 89, in train
            model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
          File "E:\Work\yolov5\models\yolo.py", line 106, in __init__
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
          File "E:\Work\yolov5\models\yolo.py", line 138, in forward
            return self.forward_once(x, profile)  # single-scale inference, train
          File "E:\Work\yolov5\models\yolo.py", line 157, in forward_once
            x = m(x)  # run  # execute network component operations
          File "C:\conda\conda\envs\torch17\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
            result = self.forward(*input, **kwargs)
          File "E:\Work\yolov5_T23\models\common.py", line 194, in forward
            n = self.token_mixer(m)
          File "C:\conda\conda\envs\torch17\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
            result = self.forward(*input, **kwargs)
          File "E:\Work\yolov5_T23\models\Confor_VC.py", line 93, in forward
            x1 = self.pool(x) - x  # x1 = self.pool(x) - x
          File "C:\conda\conda\envs\torch17\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
            result = self.forward(*input, **kwargs)
          File "C:\conda\conda\envs\torch17\lib\site-packages\torch\nn\modules\pooling.py", line 594, in forward
            return F.avg_pool2d(input, self.kernel_size, self.stride,
        TypeError: avg_pool2d(): argument 'kernel_size' (position 2) must be tuple of ints, not bool

    I wanted to put PoolFormer after a ConvBlock, and the above problem occurred. Thank you!

    opened by QY1994-0919 4
  • About MLN(Modified Layer Normalization)

    This paper provides new perspectives on the Transformer block, but I have some questions about one of the details. As far as I know, the LayerNorm provided by PyTorch implements the same function as MLN, computing the mean and variance along the token and channel dimensions. So where is the improvement? [image] The official example:

        # Image Example
        N, C, H, W = 20, 5, 10, 10
        input = torch.randn(N, C, H, W)
        # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
        # as shown in the image below
        layer_norm = nn.LayerNorm([C, H, W])
        output = layer_norm(input)

    opened by youngtboy 3
  • How to achieve the grad-CAM visualization?

    Thanks for your awesome work and for sharing them all.

    I found that the figures in the supplementary material are beautiful, and I would like to produce similar ones.

    Could you share the code for this, or tell me how to generate the Grad-CAM activation maps?

    opened by DoranLyong 3
  • CVE-2007-4559 Patch

    Patching CVE-2007-4559

    Hi, we are security researchers from the Advanced Research Center at Trellix. We have begun a campaign to patch a widespread bug named CVE-2007-4559. CVE-2007-4559 is a 15-year-old bug in the Python tarfile package. By using extract() or extractall() on a tarfile object without sanitizing input, a maliciously crafted .tar file could perform a directory path traversal attack. We found at least one unsanitized extractall() in your codebase and are providing a patch for you via pull request. The patch essentially checks whether all tarfile members will be extracted safely and throws an exception otherwise. We encourage you to use this patch or your own solution to secure against CVE-2007-4559. Further technical information about the vulnerability can be found in this blog.

    If you have further questions, you may contact us through this project's lead researcher, Kasimir Schulz.

    opened by TrellixVulnTeam 0
  • On the use of Apex AMP and hybrid stages

    Is there a specific reason why you used Apex AMP instead of the native AMP provided by PyTorch? Have you tried native AMP?

    I tried to train poolformer_s12 and poolformer_s24 with solo-learn; with native fp16 the loss goes to nan after a few epochs, while with fp32 it works fine. Did you experience similar behavior?

    On a side note, can you provide the implementation and the hyperparameters for the hybrid stage [Pool, Pool, Attention, Attention]? It seems very interesting!

    discussion 
    opened by DonkeyShot21 6
  • Can I say PoolFormer is just a non-trainable MLP-like module?

    Hi! Thanks for sharing the great work! I have some questions about PoolFormer. If I explain PoolFormer like the following attachments, can I say PoolFormer is just a non-trainable MLP-like model?

    [two attached images]

    discussion 
    opened by 072jiajia 8
  • About subtract in pooling

    Hi, thank you for publishing such a nice paper. I just have one question: I do not understand the subtraction of the input in Eq. 4. Is it necessary? What would happen if we just applied average pooling without subtracting the input?

    discussion 
    opened by Dong-Huo 16
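Both normalization questions above come down to which axes the statistics and the affine parameters cover. The sketch below is an illustration, not code from this repository. It compares `nn.GroupNorm(1, C)`, `nn.LayerNorm([C, H, W])` and a token-wise `nn.LayerNorm(C)`. The first two compute the same statistics and differ only in the shape of their affine parameters. The token-wise variant, the form usually meant by LayerNorm in transformer blocks, normalizes each token over its channels only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, H, W = 2, 8, 5, 5
x = torch.randn(N, C, H, W)

# GroupNorm with one group: mean/variance over (C, H, W) for each sample,
# affine weight and bias per channel (2 * C parameters).
gn = nn.GroupNorm(1, C)

# LayerNorm over the last three dims: the same statistics, but an affine
# weight and bias per element (2 * C * H * W parameters), tied to H and W.
ln_chw = nn.LayerNorm([C, H, W])

# Token-wise LayerNorm as commonly used in ViT-style blocks: statistics over
# C only, computed separately at every spatial position (token).
ln_tok = nn.LayerNorm(C)


def tokenwise(t: torch.Tensor) -> torch.Tensor:
    # (N, C, H, W) -> (N, H, W, C) so LayerNorm(C) sees channels last, then back.
    return ln_tok(t.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


with torch.no_grad():
    # At initialization (weight=1, bias=0) GroupNorm(1, C) and LayerNorm([C, H, W]) agree...
    print(torch.allclose(gn(x), ln_chw(x), atol=1e-5))
    # ...while normalizing each token over its channels alone gives different values.
    print(torch.allclose(gn(x), tokenwise(x), atol=1e-3))
print(sum(p.numel() for p in gn.parameters()), sum(p.numel() for p in ln_chw.parameters()))
```

Once training updates the affine parameters, GroupNorm(1, C) and LayerNorm([C, H, W]) diverge, and only GroupNorm stays independent of the input resolution.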
Owner
Sea AI Lab