Implementation of ResMLP, an all MLP solution to image classification, in Pytorch

Overview

ResMLP - Pytorch

Implementation of ResMLP, an all MLP solution to image classification out of Facebook AI, in Pytorch

Install

$ pip install res-mlp-pytorch

Usage

import torch
from res_mlp_pytorch import ResMLP

model = ResMLP(
    image_size = 256,
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)

Citations

@misc{touvron2021resmlp,
    title   = {ResMLP: Feedforward networks for image classification with data-efficient training}, 
    author  = {Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou},
    year    = {2021},
    eprint  = {2105.03404},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
You might also like...
Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch
Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch

Segformer - Pytorch Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch. Install $ pip install segformer-pytorch

🍀 Pytorch implementation of various Attention Mechanisms, MLP, Re-parameter, Convolution, which is helpful to further understand papers.⭐⭐⭐
🍀 Pytorch implementation of various Attention Mechanisms, MLP, Re-parameter, Convolution, which is helpful to further understand papers.⭐⭐⭐

🍀 Pytorch implementation of various Attention Mechanisms, MLP, Re-parameter, Convolution, which is helpful to further understand papers.⭐⭐⭐

Pytorch implementation of MLP-Mixer with loading pre-trained models.

MLP-Mixer-Pytorch PyTorch implementation of MLP-Mixer: An all-MLP Architecture for Vision with the function of loading official ImageNet pre-trained p

Image Classification - A research on image classification and auto insurance claim prediction, a systematic experiments on modeling techniques and approaches

A research on image classification and auto insurance claim prediction, a systematic experiments on modeling techniques and approaches

MLP-Like Vision Permutator for Visual Recognition (PyTorch)
MLP-Like Vision Permutator for Visual Recognition (PyTorch)

Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition (arxiv) This is a Pytorch implementation of our paper. We present Vision

Xview3 solution - XView3 challenge, 2nd place solution
Xview3 solution - XView3 challenge, 2nd place solution

Xview3, 2nd place solution https://iuu.xview.us/ test split aggregate score publ

Unofficial Implementation of MLP-Mixer in TensorFlow
Unofficial Implementation of MLP-Mixer in TensorFlow

mlp-mixer-tf Unofficial Implementation of MLP-Mixer [abs, pdf] in TensorFlow. Note: This project may have some bugs in it. I'm still learning how to i

Implementation of
Implementation of "A MLP-like Architecture for Dense Prediction"

A MLP-like Architecture for Dense Prediction (arXiv) Updates (22/07/2021) Initial release. Model Zoo We provide CycleMLP models pretrained on ImageNet

MLP-Numpy - A simple modular implementation of Multi Layer Perceptron in pure Numpy.

MLP-Numpy A simple modular implementation of Multi Layer Perceptron in pure Numpy. I used the Iris dataset from scikit-learn library for the experimen

Comments
  • torch dataset example

    torch dataset example

    I wrote this examples with a data loader:

    import os
    import natsort
    from PIL import Image
    import torch
    import torchvision.transforms as T
    from res_mlp_pytorch.res_mlp_pytorch import ResMLP
    
    class LPCustomDataSet(torch.utils.data.Dataset):
        '''
            Naive Torch Image Dataset Loader
            with support for Image loading errors
            and Image resizing
        '''
        def __init__(self, main_dir, transform):
            self.main_dir = main_dir
            self.transform = transform
            all_imgs = os.listdir(main_dir)
            self.total_imgs = natsort.natsorted(all_imgs)
    
        def __len__(self):
            return len(self.total_imgs)
    
        def __getitem__(self, idx):
            img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
            try:
                image = Image.open(img_loc).convert("RGB")
                tensor_image = self.transform(image)
                return tensor_image
            except:
                pass
                return None
    
        @classmethod
        def collate_fn(self, batch):
            '''
                Collate filtering not None images
            '''
            batch = list(filter(lambda x: x is not None, batch))
            return torch.utils.data.dataloader.default_collate(batch)
    
        @classmethod
        def transform(self,img):
            '''
                Naive image resizer
            '''
            transform = T.Compose([
                T.Resize(256),
                T.CenterCrop(224),
                T.ToTensor(),
                T.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])
            return transform(img)
    

    to feed ResMLP:

    model = ResMLP(
        image_size = 256,
        patch_size = 16,
        dim = 512,
        depth = 12,
        num_classes = 1000
    )
    batch_size = 2
    my_dataset = LPCustomDataSet(os.path.join(os.path.dirname(
        os.path.abspath(__file__)), 'data'), transform=LPCustomDataSet.transform)
    train_loader = torch.utils.data.DataLoader(my_dataset , batch_size=batch_size, shuffle=False, 
                                   num_workers=4, drop_last=True, collate_fn=LPCustomDataSet.collate_fn)
    for idx, img in enumerate(train_loader):
        pred = model(img) # (1, 1000)
        print(idx, img.shape, pred.shape
    

    But I get this error

    RuntimeError: Given groups=1, weight of size [256, 256, 1], expected input[1, 196, 512] to have 256 channels, but got 196 channels instead
    

    not sure if LPCustomDataSet.transform has the correct for the input image

    opened by loretoparisi 3
  • add dropout and CIFAR100 example notebook

    add dropout and CIFAR100 example notebook

    • According to ResMLP paper, it appears that dropout layer has been implemented in Machine translation when using ResMLP.
    We use Adagrad with learning rate 0.2, 32k steps of linear warmup, label smoothing 0.1, dropout rate 0.15 for En-De and 0.1 for En-Fr.
    
    • Since MLP literatures often mention that MLP is susceptible to overfitting, which is one of the reason why weight decay is so high, implementing dropout will be reasonable choice of regularization.

    Open in Colab | 🔗 Wandb Log

    • Above is my simple experimentation on CIFAR100 dataset, with three different dropout rates: [0.0, 0.25, 0.5].
    • Higher dropout yielded better test metrics(loss, acc1 and acc5).
    opened by snoop2head 0
  • What learning rate/scheduler/optimizer are suitable for training mlp-mixer?

    What learning rate/scheduler/optimizer are suitable for training mlp-mixer?

    Thanks for your codes!

    I find it is very important to set suitable lr/scheduler/optimizer for training res-mlp models. In my experiments with a small dataset, the classification performance is very poor when I train models with lr=1e-3 or 1e-4, weight-decay=05e-4, scheduler=WarmupCosineLrScheduler, optim='sgd'. The results increase remarkably when lr=5e-3, weight-decay=0.2, scheduler=WarmupCosineLrScheduler, optim='lamb'.

    While the results are still much lower than CNN models with comparable params. trained from scratch. Could you provide any suggestions for training res-mlp?

    opened by QiushiYang 0
Releases(0.0.6)
Owner
Phil Wang
Working with Attention.
Phil Wang
Some pre-commit hooks for OpenMMLab projects

pre-commit-hooks Some pre-commit hooks for OpenMMLab projects. Using pre-commit-hooks with pre-commit Add this to your .pre-commit-config.yaml - rep

OpenMMLab 16 Nov 29, 2022
Artifacts for paper "MMO: Meta Multi-Objectivization for Software Configuration Tuning"

MMO: Meta Multi-Objectivization for Software Configuration Tuning This repository contains the data and code for the following paper that is currently

0 Nov 17, 2021
A Python library for common tasks on 3D point clouds

Point Cloud Utils (pcu) - A Python library for common tasks on 3D point clouds Point Cloud Utils (pcu) is a utility library providing the following fu

Francis Williams 622 Dec 27, 2022
NudeNet: Neural Nets for Nudity Classification, Detection and selective censoring

NudeNet: Neural Nets for Nudity Classification, Detection and selective censoring Uncensored version of the following image can be found at https://i.

notAI.tech 1.1k Dec 29, 2022
A curated list of Machine Learning and Deep Learning tutorials in Jupyter Notebook format ready to run in Google Colaboratory

Awesome Machine Learning Jupyter Notebooks for Google Colaboratory A curated list of Machine Learning and Deep Learning tutorials in Jupyter Notebook

Carlos Toxtli 245 Jan 01, 2023
The repo contains the code to train and evaluate a system which extracts relations and explanations from dialogue.

The repo contains the code to train and evaluate a system which extracts relations and explanations from dialogue. How do I cite D-REX? For now, cite

Alon Albalak 6 Mar 31, 2022
ECCV2020 paper: Fashion Captioning: Towards Generating Accurate Descriptions with Semantic Rewards. Code and Data.

This repo contains some of the codes for the following paper Fashion Captioning: Towards Generating Accurate Descriptions with Semantic Rewards. Code

Xuewen Yang 56 Dec 08, 2022
Multi-Object Tracking in Satellite Videos with Graph-Based Multi-Task Modeling

TGraM Multi-Object Tracking in Satellite Videos with Graph-Based Multi-Task Modeling, Qibin He, Xian Sun, Zhiyuan Yan, Beibei Li, Kun Fu Abstract Rece

Qibin He 6 Nov 25, 2022
Referring Video Object Segmentation

Awesome-Referring-Video-Object-Segmentation Welcome to starts ⭐ & comments 💹 & sharing 😀 !! - 2021.12.12: Recent papers (from 2021) - welcome to ad

Explorer 57 Dec 11, 2022
Deep Learning Models for Causal Inference

Extensive tutorials for learning how to build deep learning models for causal inference using selection on observables in Tensorflow 2.

Bernard J Koch 151 Dec 31, 2022
Evaluating AlexNet features at various depths

Linear Separability Evaluation This repo provides the scripts to test a learned AlexNet's feature representation performance at the five different con

Yuki M. Asano 32 Dec 30, 2022
HMLLDB is a collection of LLDB commands to assist in the debugging of iOS apps.

HMLLDB is a collection of LLDB commands to assist in the debugging of iOS apps. 中文介绍 Features Non-intrusive. Your iOS project does not need to be modi

mao2020 47 Oct 22, 2022
A pyparsing-based library for parsing SOQL statements

CONTRIBUTORS WANTED!! Installation pip install python-soql-parser or, with poetry poetry add python-soql-parser Usage from python_soql_parser import p

Kicksaw 0 Jun 07, 2022
Code for "On the Effects of Batch and Weight Normalization in Generative Adversarial Networks"

Note: this repo has been discontinued, please check code for newer version of the paper here Weight Normalized GAN Code for the paper "On the Effects

Sitao Xiang 182 Sep 06, 2021
Distilled coarse part of LoFTR adapted for compatibility with TensorRT and embedded divices

Coarse LoFTR TRT Google Colab demo notebook This project provides a deep learning model for the Local Feature Matching for two images that can be used

Kirill 46 Dec 24, 2022
Autonomous racing with the Anki Overdrive

Anki Autonomous Racing Autonomous racing with the Anki Overdrive. Using the Overdrive-Python API (https://github.com/xerodotc/overdrive-python) develo

3 Dec 11, 2022
This reporistory contains the test-dev data of the paper "xGQA: Cross-lingual Visual Question Answering".

This reporistory contains the test-dev data of the paper "xGQA: Cross-lingual Visual Question Answering".

AdapterHub 18 Dec 09, 2022
Hidden-Fold Networks (HFN): Random Recurrent Residuals Using Sparse Supermasks

Hidden-Fold Networks (HFN): Random Recurrent Residuals Using Sparse Supermasks by Ángel López García-Arias, Masanori Hashimoto, Masato Motomura, and J

Ángel López García-Arias 4 May 19, 2022
The object detection pipeline is based on Ultralytics YOLOv5

AYOLOv2 The main goal of this repository is to rewrite the object detection pipeline with a better code structure for better portability and adaptabil

153 Dec 22, 2022
Code for the paper: Sketch Your Own GAN

Sketch Your Own GAN Project | Paper | Youtube Our method takes in one or a few hand-drawn sketches and customizes an off-the-shelf GAN to match the in

677 Dec 28, 2022