Implements MLP-Mixer: An all-MLP Architecture for Vision.

Overview

MLP-Mixer-CIFAR10

This repository implements MLP-Mixer as proposed in MLP-Mixer: An all-MLP Architecture for Vision. The paper introduces an all-MLP (multi-layer perceptron) architecture for computer vision tasks. Yannic Kilcher walks through the architecture in this video.

Experiments reported in this repository are on CIFAR-10.

What's included?

  • Distributed training with mixed-precision.
  • Visualization of the token-mixing MLP weights.
  • A TensorBoard callback to keep track of the learned linear projections of the image patches.

Notebooks

Note: These notebooks are runnable on Colab. If you don't have access to a tensor-core GPU, please disable the mixed-precision block while running the code.
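
A minimal sketch of how the mixed-precision block could be toggled, assuming TensorFlow 2.4 or later, where `tf.keras.mixed_precision.set_global_policy` is available. The helper names `policy_name` and `configure_precision` are hypothetical and not taken from the notebooks. TensorFlow is imported inside the function so that defining the helper has no dependency until it is called.

```python
def policy_name(use_mixed_precision: bool) -> str:
    """Return the Keras dtype policy name for the chosen precision mode."""
    return "mixed_float16" if use_mixed_precision else "float32"


def configure_precision(use_mixed_precision: bool) -> str:
    """Set the global Keras dtype policy (hypothetical helper)."""
    import tensorflow as tf  # lazy import: only needed when actually configuring

    name = policy_name(use_mixed_precision)
    tf.keras.mixed_precision.set_global_policy(name)
    return name
```

On a GPU without tensor cores, calling `configure_precision(False)` before building the model keeps everything in float32.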

Results

MLP-Mixer achieves competitive results. The figure below summarizes top-1 accuracies on the CIFAR-10 test set for a varying number of MLP blocks.


Notable hyperparameters are:

  • Image size: 72x72
  • Patch size: 9x9
  • Hidden dimension for patches: 64
  • Hidden dimension for channels: 128
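
As a quick check of the hyperparameters above, the number of patches follows from the image size and the patch size. The function name below is illustrative, and the sketch assumes square, non-overlapping patches.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Count non-overlapping square patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side


print(num_patches(72, 9))  # 8 patches per side -> 64
```

With a 72x72 image and 9x9 patches this gives 64 patches, the same value as the hidden dimension listed for patches.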

The table below reports the parameter counts for the different MLP-Mixer variants:


ResNet20 (0.571969 million parameters) achieves 78.14% accuracy under the exact same training configuration. Refer to this notebook for more details.

Models

You can reproduce the results reported above. The model files are available here.

Acknowledgements

Thanks to the ML-GDE Program for providing GCP credits.

Comments
  • Could patches number != MLP token mixing dimension?

    I tried to change the model to the B/16 MLP-Mixer configuration. In this setting, the number of patches (sequence length) != the token-mixing MLP dimension, and the code reports an error at "x = layers.Add()([x, token_mixing])" because the two operands have different shapes. For example, with B/16 settings: image 32*32, hidden dimension 768, P*P = 16*16, token-mixing MLP dimension = 384, channel MLP dimension = 3072. The number of patches (sequence length) is therefore 4 and the table has shape (4, 768). When the code runs x = layers.Add()([x, token_mixing]) in the token-mixing layer, x has shape [4, 768] and token_mixing has shape [384, 768].

    How can the MLP-Mixer paper use settings where "patch number (sequence length) != MLP token mixing dimension"?

    opened by LouiValley 2
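
A pure-Python sketch of the token-mixing step as formulated in the MLP-Mixer paper, where the MLP expands the patch axis from S to D_S and then projects back to S. It is not the repository's code: all function names are illustrative, layer normalization and biases are omitted, and ReLU stands in for GELU to keep the example short. The sizes are the ones quoted in the question above.

```python
import random


def dense(rows, weights):
    """Bias-free linear layer on the last axis: (n, d_in) x (d_in, d_out) -> (n, d_out)."""
    return [[sum(r[k] * weights[k][j] for k in range(len(weights)))
             for j in range(len(weights[0]))] for r in rows]


def transpose(matrix):
    """Swap the two axes of a list-of-lists matrix."""
    return [list(col) for col in zip(*matrix)]


def random_matrix(n_rows, n_cols, rng):
    """Matrix of small random values, used for inputs and weights."""
    return [[rng.uniform(-0.1, 0.1) for _ in range(n_cols)] for _ in range(n_rows)]


def shape(matrix):
    return (len(matrix), len(matrix[0]))


def token_mixing(x, w1, w2):
    """Token-mixing MLP with a skip connection.

    x: (S, C) table of patch embeddings, w1: (S, D_S), w2: (D_S, S).
    """
    xt = transpose(x)                               # (C, S): mix across patches
    h = dense(xt, w1)                               # (C, D_S)
    h = [[max(0.0, v) for v in row] for row in h]   # ReLU stand-in for GELU
    out = transpose(dense(h, w2))                   # (C, S) -> (S, C)
    # Residual add: both operands are (S, C), independent of D_S.
    return [[a + b for a, b in zip(rx, ro)] for rx, ro in zip(x, out)]


rng = random.Random(0)
S, C, D_S = 4, 768, 384  # patches, hidden dimension, token-mixing MLP dimension
x = random_matrix(S, C, rng)
w1 = random_matrix(S, D_S, rng)
w2 = random_matrix(D_S, S, rng)
y = token_mixing(x, w1, w2)
print(shape(y))  # (4, 768)
```

Because the second projection returns to S, the residual add sees (4, 768) on both sides whatever D_S is. A shape of [384, 768] for token_mixing would be expected if the block stops after the first projection, though that is a guess about the failing code.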
  • Why the accuracy drops after epoch 100/100 (accuracy drops from 91% to 71%)

    I trained the network (NUM_MIXER_LAYERS = 4).

    At epoch 100:

    Epoch 100/100

    1/44 [..............................] - ETA: 1s - loss: 0.2472 - accuracy: 0.9160
    3/44 [=>............................] - ETA: 1s - loss: 0.2424 - accuracy: 0.9162
    5/44 [==>...........................] - ETA: 1s - loss: 0.2431 - accuracy: 0.9155
    7/44 [===>..........................] - ETA: 1s - loss: 0.2424 - accuracy: 0.9154
    9/44 [=====>........................] - ETA: 1s - loss: 0.2419 - accuracy: 0.9155
    11/44 [======>.......................] - ETA: 1s - loss: 0.2423 - accuracy: 0.9150
    13/44 [=======>......................] - ETA: 1s - loss: 0.2426 - accuracy: 0.9145
    15/44 [=========>....................] - ETA: 1s - loss: 0.2430 - accuracy: 0.9142
    17/44 [==========>...................] - ETA: 1s - loss: 0.2433 - accuracy: 0.9140
    19/44 [===========>..................] - ETA: 1s - loss: 0.2435 - accuracy: 0.9138
    21/44 [=============>................] - ETA: 0s - loss: 0.2438 - accuracy: 0.9136
    23/44 [==============>...............] - ETA: 0s - loss: 0.2439 - accuracy: 0.9135
    25/44 [================>.............] - ETA: 0s - loss: 0.2440 - accuracy: 0.9134
    27/44 [=================>............] - ETA: 0s - loss: 0.2440 - accuracy: 0.9133
    29/44 [==================>...........] - ETA: 0s - loss: 0.2442 - accuracy: 0.9132
    31/44 [====================>.........] - ETA: 0s - loss: 0.2445 - accuracy: 0.9130
    33/44 [=====================>........] - ETA: 0s - loss: 0.2447 - accuracy: 0.9129
    35/44 [======================>.......] - ETA: 0s - loss: 0.2450 - accuracy: 0.9127
    37/44 [========================>.....] - ETA: 0s - loss: 0.2454 - accuracy: 0.9125
    39/44 [=========================>....] - ETA: 0s - loss: 0.2459 - accuracy: 0.9123
    41/44 [==========================>...] - ETA: 0s - loss: 0.2463 - accuracy: 0.9121
    43/44 [============================>.] - ETA: 0s - loss: 0.2469 - accuracy: 0.9119
    44/44 [==============================] - 2s 46ms/step - loss: 0.2474 - accuracy: 0.9117 - val_loss: 1.1145 - val_accuracy: 0.7226

    Then there still seems to be an extra training pass:

    1/313 [..............................] - ETA: 24:32 - loss: 0.5860 - accuracy: 0.8125
    8/313 [..............................] - ETA: 2s - loss: 1.2071 - accuracy: 0.6953
    .....
    313/313 [==============================] - ETA: 0s - loss: 1.0934 - accuracy: 0.7161
    313/313 [==============================] - 12s 22ms/step - loss: 1.0934 - accuracy: 0.7161
    Test accuracy: 71.61

    opened by LouiValley 1
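
The log in this question already contains both numbers being compared: 0.9117 is the training accuracy at the end of epoch 100, while the same line reports a val_accuracy of 0.7226, close to the final test accuracy of 71.61. The 313-step pass ends with "Test accuracy", which suggests it is an evaluation run rather than further training. The sketch below pulls the metrics out of such a progress line; the helper name `parse_metrics` is hypothetical.

```python
import re

# Final progress line of epoch 100, copied from the log in the question.
FINAL_EPOCH_LINE = (
    "44/44 [==============================] - 2s 46ms/step - loss: 0.2474 "
    "- accuracy: 0.9117 - val_loss: 1.1145 - val_accuracy: 0.7226"
)


def parse_metrics(line: str) -> dict:
    """Extract 'name: value' metric pairs from one Keras progress line."""
    pairs = re.findall(r"(\w+): (\d+\.\d+)", line)
    return {name: float(value) for name, value in pairs}


metrics = parse_metrics(FINAL_EPOCH_LINE)
gap = metrics["accuracy"] - metrics["val_accuracy"]
print(metrics)
print(f"train/validation accuracy gap: {gap:.4f}")
```

A gap of roughly 19 percentage points between training and validation accuracy points to overfitting rather than to a sudden drop after the last epoch.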
  • Consider either turning off auto-sharding or switching the auto_shard_policy to DATA

    When I try to run it on the server, it shows the following warning:

    Consider either turning off auto-sharding or switching the auto_shard_policy to DATA to shard this dataset. You can do this by creating a new tf.data.Options() object then setting options.experimental_distribute.auto_shard_policy = AutoShardPolicy.DATA before applying the options object to the dataset via dataset.with_options(options). 2021-11-21 11:59:20.861052: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.

    My TensorFlow version is 2.4.0. How can I fix this problem?

    opened by LouiValley 1
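
The warning quoted above spells out its own remedy. A minimal sketch of those steps, assuming a `tf.data.Dataset` is passed in; the helper name `shard_by_data` is hypothetical, and TensorFlow is imported lazily so the definition alone does not require it.

```python
def shard_by_data(dataset):
    """Return `dataset` with auto-sharding switched to the DATA policy."""
    import tensorflow as tf  # lazy import: only needed when the helper is called

    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.DATA
    )
    return dataset.with_options(options)
```

A typical use would be `train_ds = shard_by_data(train_ds)` before passing the dataset to `model.fit` under a distribution strategy.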
Owner
Sayak Paul