Implementation of ResMLP, an all-MLP solution to image classification, in Pytorch

Overview

ResMLP - Pytorch

Implementation of ResMLP, an all-MLP solution to image classification out of Facebook AI, in Pytorch

Install

$ pip install res-mlp-pytorch

Usage

import torch
from res_mlp_pytorch import ResMLP

model = ResMLP(
    image_size = 256,
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)
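ResMLP mixes information across patches with a layer sized by the number of patches, so a model built with image_size = 256 and patch_size = 16 only accepts inputs that produce (256 / 16)^2 = 256 patches. The sketch below checks this before the forward pass. num_patches and check_input are hypothetical helpers (not part of res-mlp-pytorch), and the sketch assumes square images cut into non-overlapping square patches.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError(
            f"image_size {image_size} is not divisible by patch_size {patch_size}"
        )
    return (image_size // patch_size) ** 2


def check_input(model_image_size: int, patch_size: int, input_hw: tuple) -> None:
    """Raise early if an input resolution does not match the model's resolution."""
    h, w = input_hw
    if h != w:
        raise ValueError(f"expected a square input, got {h}x{w}")
    expected = num_patches(model_image_size, patch_size)
    got = num_patches(h, patch_size)
    if got != expected:
        raise ValueError(
            f"model expects {expected} patches ({model_image_size}x{model_image_size} input), "
            f"but a {h}x{w} input gives {got} patches"
        )


# the Usage example above: 256x256 images, 16x16 patches
check_input(256, 16, (256, 256))  # passes silently
```

For example, a 224x224 input (such as the output of a CenterCrop(224) transform) gives 196 patches against the 256 this model expects, which matches the RuntimeError reported in the comments; either crop to 256 or build the model with image_size = 224.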

Citations

@misc{touvron2021resmlp,
    title   = {ResMLP: Feedforward networks for image classification with data-efficient training}, 
    author  = {Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou},
    year    = {2021},
    eprint  = {2105.03404},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
You might also like...
Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch

Segformer - Pytorch Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch. Install $ pip install segformer-pytorch

🍀 Pytorch implementation of various Attention Mechanisms, MLP, Re-parameter, Convolution, which is helpful to further understand papers.⭐⭐⭐

Pytorch implementation of MLP-Mixer with loading pre-trained models.

MLP-Mixer-Pytorch PyTorch implementation of MLP-Mixer: An all-MLP Architecture for Vision with the function of loading official ImageNet pre-trained p

Image Classification - Research on image classification and auto insurance claim prediction: systematic experiments on modeling techniques and approaches

MLP-Like Vision Permutator for Visual Recognition (PyTorch)

Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition (arxiv) This is a Pytorch implementation of our paper. We present Vision

Xview3 solution - XView3 challenge, 2nd place solution

Xview3, 2nd place solution https://iuu.xview.us/ test split aggregate score publ

Unofficial Implementation of MLP-Mixer in TensorFlow

mlp-mixer-tf Unofficial Implementation of MLP-Mixer [abs, pdf] in TensorFlow. Note: This project may have some bugs in it. I'm still learning how to i

Implementation of "A MLP-like Architecture for Dense Prediction"

A MLP-like Architecture for Dense Prediction (arXiv) Updates (22/07/2021) Initial release. Model Zoo We provide CycleMLP models pretrained on ImageNet

MLP-Numpy - A simple modular implementation of Multi Layer Perceptron in pure Numpy.

MLP-Numpy A simple modular implementation of Multi Layer Perceptron in pure Numpy. I used the Iris dataset from scikit-learn library for the experimen

Comments
  • torch dataset example

    I wrote this example with a data loader:

    import os
    import natsort
    from PIL import Image
    import torch
    import torchvision.transforms as T
    from res_mlp_pytorch.res_mlp_pytorch import ResMLP
    
    class LPCustomDataSet(torch.utils.data.Dataset):
        '''
            Naive Torch Image Dataset Loader
            with support for image loading errors
            and image resizing
        '''
        # preprocessing pipeline, built once rather than on every call:
        # resize the short side to 256, center-crop to 224x224, normalize with ImageNet stats
        _preprocess = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
    
        def __init__(self, main_dir, transform):
            self.main_dir = main_dir
            self.transform = transform
            # natural sort, so that e.g. img2.jpg comes before img10.jpg
            self.total_imgs = natsort.natsorted(os.listdir(main_dir))
    
        def __len__(self):
            return len(self.total_imgs)
    
        def __getitem__(self, idx):
            img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
            try:
                image = Image.open(img_loc).convert("RGB")
            except OSError:
                # missing, unreadable or non-image file: dropped later by collate_fn
                return None
            return self.transform(image)
    
        @staticmethod
        def collate_fn(batch):
            '''
                Collate only the images that loaded successfully
            '''
            batch = [x for x in batch if x is not None]
            return torch.utils.data.dataloader.default_collate(batch)
    
        @classmethod
        def transform(cls, img):
            '''
                Naive image resizer
            '''
            return cls._preprocess(img)
    

    I then use it to feed ResMLP:

    model = ResMLP(
        image_size = 256,
        patch_size = 16,
        dim = 512,
        depth = 12,
        num_classes = 1000
    )
    batch_size = 2
    data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
    my_dataset = LPCustomDataSet(data_dir, transform=LPCustomDataSet.transform)
    train_loader = torch.utils.data.DataLoader(my_dataset, batch_size=batch_size, shuffle=False,
                                   num_workers=4, drop_last=True, collate_fn=LPCustomDataSet.collate_fn)
    for idx, img in enumerate(train_loader):
        pred = model(img) # (batch_size, 1000)
        print(idx, img.shape, pred.shape)
    

    But I get this error:

    RuntimeError: Given groups=1, weight of size [256, 256, 1], expected input[1, 196, 512] to have 256 channels, but got 196 channels instead
    

    I'm not sure whether LPCustomDataSet.transform produces the correct input for the model.

    opened by loretoparisi 3
  • add dropout and CIFAR100 example notebook

    • According to the ResMLP paper, dropout was used when ResMLP was applied to machine translation:
    "We use Adagrad with learning rate 0.2, 32k steps of linear warmup, label smoothing 0.1, dropout rate 0.15 for En-De and 0.1 for En-Fr."
    
    • Since the MLP literature often notes that MLPs are prone to overfitting (one reason weight decay is set so high), dropout is a reasonable choice of regularization.

    Open in Colab | 🔗 Wandb Log

    • Above is my simple experiment on the CIFAR100 dataset, with three different dropout rates: [0.0, 0.25, 0.5].
    • Higher dropout yielded better test metrics (loss, acc1 and acc5).
    opened by snoop2head 0
  • What learning rate/scheduler/optimizer are suitable for training mlp-mixer?

    Thanks for your code!

    I've found that choosing a suitable lr/scheduler/optimizer is very important when training res-mlp models. In my experiments on a small dataset, classification performance is very poor when I train with lr=1e-3 or 1e-4, weight-decay=05e-4, scheduler=WarmupCosineLrScheduler, optim='sgd'. The results improve remarkably with lr=5e-3, weight-decay=0.2, scheduler=WarmupCosineLrScheduler, optim='lamb'.

    However, the results are still much lower than those of CNN models with a comparable number of parameters trained from scratch. Could you provide any suggestions for training res-mlp?

    opened by QiushiYang 0
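The dropout proposal in the comments can be sketched as a ResMLP-style residual block following the paper's description: an Affine transform (per-channel scale and shift) in place of normalization, a linear layer across patches, and a two-layer GELU MLP across channels, each branch added back residually. The dropout placement, the 0.1 initial scale of the post-branch Affine (in the spirit of LayerScale), and the names Affine and ResMLPBlockWithDropout are assumptions for illustration, not the repository's code or the proposed change.

```python
import torch
from torch import nn


class Affine(nn.Module):
    """Per-channel scale and shift, used by ResMLP instead of LayerNorm."""

    def __init__(self, dim, init_value=1.0):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((dim,), init_value))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        return x * self.alpha + self.beta


class ResMLPBlockWithDropout(nn.Module):
    """Hypothetical ResMLP-style block with dropout on both residual branches."""

    def __init__(self, num_patches, dim, expansion_factor=4, dropout=0.0, init_scale=0.1):
        super().__init__()
        # cross-patch branch: one linear map over the patch axis
        self.pre_patch = Affine(dim)
        self.patch_mix = nn.Linear(num_patches, num_patches)
        self.post_patch = Affine(dim, init_scale)
        # cross-channel branch: two-layer GELU MLP over the channel axis
        self.pre_channel = Affine(dim)
        self.channel_mlp = nn.Sequential(
            nn.Linear(dim, dim * expansion_factor),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * expansion_factor, dim),
        )
        self.post_channel = Affine(dim, init_scale)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):  # x: (batch, num_patches, dim)
        y = self.pre_patch(x).transpose(1, 2)   # (batch, dim, num_patches)
        y = self.patch_mix(y).transpose(1, 2)   # back to (batch, num_patches, dim)
        x = x + self.drop(self.post_patch(y))
        y = self.channel_mlp(self.pre_channel(x))
        x = x + self.drop(self.post_channel(y))
        return x


block = ResMLPBlockWithDropout(num_patches=196, dim=64, dropout=0.25)
out = block(torch.randn(2, 196, 64))  # shape (2, 196, 64)
```

Dropout is only active in training mode; after block.eval() the block is deterministic, so the regularization does not affect test-time predictions.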
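The learning-rate question in the comments reports much better results with a warmup-cosine schedule, lr=5e-3 and weight-decay=0.2. Below is a minimal sketch of such a schedule using only core PyTorch, assuming the schedule is stepped once per optimizer step. warmup_cosine and build_optimizer are hypothetical helpers; WarmupCosineLrScheduler from the question is not part of torch.optim.lr_scheduler, and AdamW is swapped in for LAMB because torch.optim does not ship LAMB. This is not a recommendation from the repository.

```python
import math

import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR


def warmup_cosine(step, warmup_steps, total_steps):
    """LR multiplier: linear warmup up to 1.0, then cosine decay down to 0.0."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))


def build_optimizer(model, lr=5e-3, weight_decay=0.2, warmup_steps=500, total_steps=10_000):
    # AdamW stands in for LAMB, which core PyTorch does not provide
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = LambdaLR(
        optimizer, lambda step: warmup_cosine(step, warmup_steps, total_steps)
    )
    return optimizer, scheduler


optimizer, scheduler = build_optimizer(nn.Linear(4, 2), warmup_steps=10, total_steps=100)
# in the training loop: loss.backward(); optimizer.step(); scheduler.step()
```

LambdaLR multiplies the base learning rate by the returned factor, so the warmup length and total length are expressed in optimizer steps rather than epochs.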
Releases(0.0.6)
Owner
Phil Wang
Working with Attention.
This repository provides data for the VAW dataset as described in the CVPR 2021 paper titled "Learning to Predict Visual Attributes in the Wild"

Visual Attributes in the Wild (VAW) This repository provides data for the VAW dataset as described in the CVPR 2021 Paper: Learning to Predict Visual

Adobe Research 36 Dec 30, 2022
Use deep learning, genetic programming and other methods to predict stock and market movements

StockPredictions Use classic tricks, neural networks, deep learning, genetic programming and other methods to predict stock and market movements. Both

Linda MacPhee-Cobb 386 Jan 03, 2023
Compares various time-series feature sets on computational performance, within-set structure, and between-set relationships.

feature-set-comp Compares various time-series feature sets on computational performance, within-set structure, and between-set relationships. Reposito

Trent Henderson 7 May 25, 2022
Learning to Estimate Hidden Motions with Global Motion Aggregation

Learning to Estimate Hidden Motions with Global Motion Aggregation (GMA) This repository contains the source code for our paper: Learning to Estimate

Shihao Jiang (Zac) 221 Dec 18, 2022
Multi-Task Temporal Shift Attention Networks for On-Device Contactless Vitals Measurement (NeurIPS 2020)

MTTS-CAN: Multi-Task Temporal Shift Attention Networks for On-Device Contactless Vitals Measurement Paper Xin Liu, Josh Fromm, Shwetak Patel, Daniel M

Xin Liu 106 Dec 30, 2022
harmonic-percussive-residual separation algorithm wrapped as a VST3 plugin (iPlug2)

Harmonic-percussive-residual separation plug-in This work is a study on the plausibility of a sines-transients-noise decomposition inspired algorithm

Derp Learning 9 Sep 01, 2022
Small-bets - Ergodic Experiment With Python

Ergodic Experiment Based on this video. Run this experiment with this command: p

Michael Brant 3 Jan 11, 2022
This repository implements Douzero's interface to ICGA.

douzero-interface-for-ICGA This repository implements Douzero's interface to ICGA. ./douzero: This directory stores Doudizhu AI projects. ./interface:

zhanggenjin 4 Aug 07, 2022
Code for paper "Which Training Methods for GANs do actually Converge? (ICML 2018)"

GAN stability This repository contains the experiments in the supplementary material for the paper Which Training Methods for GANs do actually Converg

Lars Mescheder 885 Jan 01, 2023
End-to-end beat and downbeat tracking in the time domain.

WaveBeat End-to-end beat and downbeat tracking in the time domain. | Paper | Code | Video | Slides | Setup First clone the repo. git clone https://git

Christian J. Steinmetz 60 Dec 24, 2022
Luminous is a framework for testing the performance of Embodied AI (EAI) models in indoor tasks.

Luminous is a framework for testing the performance of Embodied AI (EAI) models in indoor tasks. Generally, we integrate different kinds of functional

28 Jan 08, 2023
Real-time pose estimation accelerated with NVIDIA TensorRT

trt_pose Want to detect hand poses? Check out the new trt_pose_hand project for real-time hand pose and gesture recognition! trt_pose is aimed at enab

NVIDIA AI IOT 803 Jan 06, 2023
Mitsuba 2: A Retargetable Forward and Inverse Renderer

Mitsuba Renderer 2 Documentation Mitsuba 2 is a research-oriented rendering system written in portable C++17. It consists of a small set of core libra

Mitsuba Physically Based Renderer 2k Jan 07, 2023
The backbone CSPDarkNet of YOLOX.

YOLOX-Backbone The backbone CSPDarkNet of YOLOX. In this project, you can enjoy: CSPDarkNet-S CSPDarkNet-M CSPDarkNet-L CSPDarkNet-X CSPDarkNet-Tiny C

Jianhua Yang 9 Aug 22, 2022
Transfer SemanticKITTI labels into other dataset/sensor formats.

LiDAR-Transfer Transfer SemanticKITTI labels into other dataset/sensor formats. Content Convert datasets (NUSCENES, FORD, NCLT) to KITTI format Minim

Photogrammetry & Robotics Bonn 64 Nov 21, 2022
An attempt at the implementation of GLOM, Geoffrey Hinton's paper for emergent part-whole hierarchies from data

GLOM TensorFlow This Python package attempts to implement GLOM in TensorFlow, which allows advances made by several different groups transformers, neu

Rishit Dagli 32 Feb 21, 2022
Auto Seg-Loss: Searching Metric Surrogates for Semantic Segmentation

Auto-Seg-Loss By Hao Li, Chenxin Tao, Xizhou Zhu, Xiaogang Wang, Gao Huang, Jifeng Dai This is the official implementation of the ICLR 2021 paper Auto

61 Dec 21, 2022
Exadel CompreFace is a free and open-source face recognition GitHub project

Exadel CompreFace is a leading free and open-source face recognition system Exadel CompreFace is a free and open-source face recognition service that

Exadel 2.6k Jan 04, 2023
UNAVOIDS: Unsupervised and Nonparametric Approach for Visualizing Outliers and Invariant Detection Scoring

UNAVOIDS: Unsupervised and Nonparametric Approach for Visualizing Outliers and Invariant Detection Scoring Code Summary aggregate.py: this script aggr

1 Dec 28, 2021
Iowa Project - My second project done at General Assembly, focused on feature engineering and understanding Linear Regression as a concept

Project 2 - Ames Housing Data and Kaggle Challenge PROBLEM STATEMENT Inferring or Predicting? What's more valuable for a housing model? When creating

Adam Muhammad Klesc 1 Jan 03, 2022