ResMLP - Pytorch

Implementation of ResMLP, an all-MLP solution to image classification from Facebook AI, in PyTorch

Install

$ pip install res-mlp-pytorch

Usage

import torch
from res_mlp_pytorch import ResMLP

model = ResMLP(
    image_size = 256,
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)

Citations

@misc{touvron2021resmlp,
    title   = {ResMLP: Feedforward networks for image classification with data-efficient training}, 
    author  = {Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou},
    year    = {2021},
    eprint  = {2105.03404},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

Comments
  • torch dataset example

    I wrote this example with a data loader:

    import os
    import natsort
    from PIL import Image
    import torch
    import torchvision.transforms as T
    from res_mlp_pytorch.res_mlp_pytorch import ResMLP
    
    class LPCustomDataSet(torch.utils.data.Dataset):
        '''
            Naive Torch Image Dataset Loader
            with support for Image loading errors
            and Image resizing
        '''
        def __init__(self, main_dir, transform):
            self.main_dir = main_dir
            self.transform = transform
            all_imgs = os.listdir(main_dir)
            self.total_imgs = natsort.natsorted(all_imgs)
    
        def __len__(self):
            return len(self.total_imgs)
    
        def __getitem__(self, idx):
            img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
            try:
                image = Image.open(img_loc).convert("RGB")
                tensor_image = self.transform(image)
                return tensor_image
            except Exception:
                # Unreadable or corrupt image: return None so collate_fn can drop it
                return None
    
        @staticmethod
        def collate_fn(batch):
            '''
                Collate filtering not None images
            '''
            batch = list(filter(lambda x: x is not None, batch))
            return torch.utils.data.dataloader.default_collate(batch)
    
        @staticmethod
        def transform(img):
            '''
                Naive image resizer
            '''
            transform = T.Compose([
                T.Resize(256),
                T.CenterCrop(224),
                T.ToTensor(),
                T.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])
            return transform(img)
    

    to feed ResMLP:

    model = ResMLP(
        image_size = 256,
        patch_size = 16,
        dim = 512,
        depth = 12,
        num_classes = 1000
    )
    batch_size = 2
    my_dataset = LPCustomDataSet(os.path.join(os.path.dirname(
        os.path.abspath(__file__)), 'data'), transform=LPCustomDataSet.transform)
    train_loader = torch.utils.data.DataLoader(my_dataset , batch_size=batch_size, shuffle=False, 
                                   num_workers=4, drop_last=True, collate_fn=LPCustomDataSet.collate_fn)
    for idx, img in enumerate(train_loader):
        pred = model(img) # (batch_size, 1000)
        print(idx, img.shape, pred.shape)
    

    But I get this error

    RuntimeError: Given groups=1, weight of size [256, 256, 1], expected input[1, 196, 512] to have 256 channels, but got 196 channels instead
    

    I am not sure whether LPCustomDataSet.transform produces the correct input size for the model.
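
The shapes in the error message are consistent with a patch-count mismatch rather than a data-loading problem. The model above is built with image_size = 256 and patch_size = 16, so it expects (256 / 16)^2 = 256 patches, while T.CenterCrop(224) yields 224 x 224 images, which give (224 / 16)^2 = 196 patches. A minimal sketch of that arithmetic follows; num_patches is a hypothetical helper written for illustration and is not part of res_mlp_pytorch.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image.

    Hypothetical helper for illustration; not part of res_mlp_pytorch.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    return (image_size // patch_size) ** 2


# Values taken from the model configuration and the transform in the post
expected = num_patches(image_size=256, patch_size=16)  # what the model was built for
actual = num_patches(image_size=224, patch_size=16)    # what CenterCrop(224) produces

print(expected, actual)
```

If this reading is right, making the two sizes agree should remove the mismatch, for example by cropping to 256 instead of 224, or by constructing the model with image_size = 224.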

    opened by loretoparisi 3
  • add dropout and CIFAR100 example notebook

    • According to the ResMLP paper, dropout was used when ResMLP was applied to machine translation:
    "We use Adagrad with learning rate 0.2, 32k steps of linear warmup, label smoothing 0.1, dropout rate 0.15 for En-De and 0.1 for En-Fr."
    
    • The MLP literature often notes that MLPs are susceptible to overfitting, which is one reason the weight decay is set so high, so dropout is a reasonable choice of regularization.
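
For reference, the regularizer under discussion can be sketched in a few lines of plain Python. This is a minimal illustration of inverted dropout (the scaling convention that torch.nn.Dropout also follows), not the code from this pull request; the function name, its signature, and where dropout would sit inside ResMLP are assumptions made for the example.

```python
import random


def dropout(values, rate, training=True, rng=None):
    """Inverted dropout on a list of floats (illustrative sketch only).

    Each value is zeroed with probability `rate`; survivors are scaled by
    1 / (1 - rate) so the expected value is unchanged. At evaluation time
    (training=False) the input passes through untouched.
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    if not training or rate == 0.0:
        return list(values)
    rng = rng or random.Random()
    keep = 1.0 - rate
    return [v / keep if rng.random() < keep else 0.0 for v in values]


# Dropout rate 0.15, as quoted from the paper for En-De translation
activations = [1.0] * 8
dropped = dropout(activations, rate=0.15, rng=random.Random(0))
print(dropped)
```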

    Open in Colab | 🔗 Wandb Log

    • The above is a simple experiment on the CIFAR100 dataset with three dropout rates: 0.0, 0.25, and 0.5.
    • Higher dropout yielded better test metrics (loss, acc1, and acc5).
    opened by snoop2head 0
  • What learning rate/scheduler/optimizer are suitable for training mlp-mixer?

    Thanks for your code!

    I find it very important to set a suitable lr/scheduler/optimizer when training res-mlp models. In my experiments with a small dataset, classification performance is very poor when I train models with lr=1e-3 or 1e-4, weight-decay=5e-4, scheduler=WarmupCosineLrScheduler, optim='sgd'. The results improve markedly with lr=5e-3, weight-decay=0.2, scheduler=WarmupCosineLrScheduler, optim='lamb'.

    However, the results are still much lower than those of CNN models with a comparable number of parameters trained from scratch. Could you provide any suggestions for training res-mlp?
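
To make the schedule in the question concrete, here is a sketch of a generic linear-warmup-then-cosine learning-rate curve in plain Python. The poster's WarmupCosineLrScheduler is not shown in the thread and may behave differently; the function name, warmup_steps, total_steps, and min_lr below are hypothetical values chosen for illustration. Only the peak rate of 5e-3 comes from the post.

```python
import math


def warmup_cosine_lr(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
    """Learning rate at `step` for linear warmup followed by cosine decay.

    Illustrative sketch; not the scheduler used in the post.
    """
    if step < warmup_steps:
        # Linear ramp that reaches base_lr on the last warmup step
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1]
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Peak lr from the post; the step counts are assumptions
schedule = [warmup_cosine_lr(s, base_lr=5e-3, warmup_steps=100, total_steps=1000)
            for s in range(1000)]
print(schedule[0], schedule[99], schedule[-1])
```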

    opened by QiushiYang 0
Releases(0.0.6)
Owner
Phil Wang