Overview

gMLP - Pytorch

Implementation of gMLP, an all-MLP replacement for Transformers, in Pytorch

Install

$ pip install g-mlp-pytorch

Usage

For masked language modelling

import torch
from g_mlp_pytorch import gMLP

model = gMLP(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 256
)

x = torch.randint(0, 20000, (1, 256))
emb = model(x) # (1, 256, 512)

For image classification

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)

You can also add a small amount of single-headed attention to boost performance, as described in the paper (referred to as aMLP), by passing one extra keyword, attn_dim. This applies to both gMLPVision and gMLP.

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    attn_dim = 64
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)
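
The same keyword should work for the language model variant as well. A minimal sketch, mirroring the masked language modelling example above:

import torch
from g_mlp_pytorch import gMLP

model = gMLP(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 256,
    attn_dim = 64
)

x = torch.randint(0, 20000, (1, 256))
emb = model(x) # (1, 256, 512)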

Citations

@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Custom image sizes?

    Custom image sizes?

    Hi, thanks for your great (and very fast) contribution! I was wondering if you could help me figure out how to apply this to a different input size? It's not really an image, but rather a 2D tensor of 4096x100.

    I saw that I can change the number of channels, so I could just set channels to 1. But I see that, firstly, your implementation is for square images, and secondly, it requires that the image size be divisible by the patch size.

    Since you've written this implementation, perhaps you could help me adapt it for my needs? (and maybe other users for their cases)

    Maybe I could pad the length to 128 so both dimensions would be divisible by 16, for example? But then where do I set different h, w?

    Thanks.

    opened by danarte 3
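
    As a rough sketch of the padding and patching step being discussed (assuming channels = 1 and using einops, which the library already depends on; feeding the resulting non-square token grid through gMLPVision would still require changes to its square-image assumption):

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    x = torch.randn(1, 1, 4096, 100) # hypothetical input: (batch, channels, height, width)
    patch_size = 16

    # pad the width from 100 up to the next multiple of the patch size (112)
    pad_w = (patch_size - x.shape[-1] % patch_size) % patch_size
    x = F.pad(x, (0, pad_w)) # (1, 1, 4096, 112)

    # flatten into a sequence of patch tokens, similar to what gMLPVision does internally
    patches = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size)
    print(patches.shape) # (1, 1792, 256) -> 256 x 7 patches, each of dim 1 * 16 * 16
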
  • Parameter count doesn't line up with the paper

    Parameter count doesn't line up with the paper

    Just a note (and correct me if I misunderstood the paper) -

    The parameter count for the Tiny gMLP doesn't line up with the param count from the paper for 30 layers, 128 dim and 6 ff_mult. That's probably due to the doubling of parameters here - https://github.com/lucidrains/g-mlp-pytorch/blob/main/g_mlp_pytorch/g_mlp_pytorch.py#L111

    Halving this back to dim_ff, and halving the respective dims on all 3 lines here - https://github.com/lucidrains/g-mlp-pytorch/blob/main/g_mlp_pytorch/g_mlp_pytorch.py#L64-L66 -

    then the param count is roughly 5.5M params.

    opened by titu1994 2
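
    A quick way to inspect the parameter count being discussed (a sketch; it assumes the ff_mult keyword is exposed by the gMLP constructor, and note that the token embedding is included in the total):

    import torch
    from g_mlp_pytorch import gMLP

    # rough Tiny-like config from the discussion: 30 layers, dim 128, ff_mult 6
    model = gMLP(num_tokens = 20000, dim = 128, depth = 30, seq_len = 256, ff_mult = 6)

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'{num_params / 1e6:.1f}M parameters')
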
  • Add Support for Stochastic Depth

    Add Support for Stochastic Depth

    This PR adds support for stochastic depth, which is used in the paper for the vision experiments. I went ahead and added it to gMLP as well for completeness.

    I tried my best to match your style. Let me know if there are any problems, or if you want me to refactor anything.

    opened by mlw214 2
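
    For readers unfamiliar with the technique, stochastic depth randomly skips a residual branch during training. A minimal sketch of the idea (not the PR's actual implementation; the class name and survival_prob parameter are illustrative):

    import torch
    from torch import nn

    class StochasticDepth(nn.Module):
        # wraps a residual branch and drops it with probability (1 - survival_prob) during training
        def __init__(self, block, survival_prob = 0.9):
            super().__init__()
            self.block = block
            self.survival_prob = survival_prob

        def forward(self, x):
            if self.training and torch.rand(()).item() > self.survival_prob:
                return x # skip the branch entirely for this step
            return x + self.block(x) # normal residual path
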
  • Don't you think this is more legible?

    Don't you think this is more legible?

    class SpatialGatingUnit(nn.Module):
        def __init__(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3):
            super().__init__()
            dim_out = dim // 2
            self.causal = causal

            self.norm = nn.LayerNorm(dim_out)
            # self.proj = nn.Conv1d(dim_seq, dim_seq, 1)

            self.dim_seq = dim_seq
            self.w_ = nn.Parameter(torch.zeros(dim_seq, dim_seq), requires_grad = True)  # replaces the Conv1d weight
            self.b_ = nn.Parameter(torch.ones(dim_seq), requires_grad = True)            # replaces the Conv1d bias

            self.act = act

            init_eps /= dim_seq
            # nn.init.uniform_(self.proj.weight, -init_eps, init_eps)
            # nn.init.constant_(self.proj.bias, 1.)

        def forward(self, x, gate_res = None): # x -> bsz, len, hidden * 6
            device, n = x.device, x.shape[1]

            res, gate = x.chunk(2, dim = -1)
            gate = self.norm(gate)

            weight, bias = self.w_, self.b_ # weight -> len, len    bias -> len

            if self.causal:
                weight, bias = weight[:n, :n], bias[:n]
                mask = torch.ones(weight.shape[:2], device = device).triu_(1).bool()
                weight = weight.masked_fill(mask, 0.)

            gate = torch.matmul(weight, gate) + bias[None, :self.dim_seq, None] # WZ + b

            # gate = F.conv1d(gate, weight, bias) # WZ + b

            if exists(gate_res):
                gate = gate + gate_res

            return self.act(gate) * res

    opened by ZIZUN 0
  • Potentially missing the highway pass

    Potentially missing the highway pass

    Hello,

    Maybe I missed it, but would you mind pointing out where the highway pass of the gMLP block is in the code? Based on the paper, there is a highway path (an addition) between the input and the output. I couldn't find it in the gMLPBlock code.

    Thank you

    opened by Vincent-Li-9701 1
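
    In this repository the skip connection typically lives in a wrapper around the block rather than inside the block itself. A minimal sketch of such a wrapper (the name Residual is illustrative; check g_mlp_pytorch.py for where the actual addition happens):

    from torch import nn

    class Residual(nn.Module):
        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            # the highway / skip path: the block's output is added back to its input
            return self.fn(x) + x
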
Owner
Phil Wang
Working with Attention. It's all we need.