Implementation of gMLP, an all-MLP replacement for Transformers, in Pytorch

Overview

gMLP - Pytorch

Install

$ pip install g-mlp-pytorch

Usage

For masked language modelling

import torch
from g_mlp_pytorch import gMLP

model = gMLP(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 256
)

x = torch.randint(0, 20000, (1, 256))
emb = model(x) # (1, 256, 512)

For image classification

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)

You can also add a small amount of single-head attention to boost performance, a variant the paper calls aMLP, by passing one extra keyword, attn_dim. This applies to both gMLPVision and gMLP.

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    attn_dim = 64
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)

Citations

@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Custom image sizes?

    Hi, thanks for your great (and very fast) contribution! I was wondering if you could help me figure out how to apply this to a different image size. My input is not really an image, but rather a 2D tensor of size 4096 x 100.

    I saw that I can change the number of channels, so I could just set channels to 1. But I see that, first, your implementation is for square images, and second, it requires the image size to be divisible by the patch size.

    Since you've written this implementation, perhaps you could help me adapt it to my needs (and maybe help other users with their cases)?

    Maybe I could pad the width to 128 so that both dimensions are divisible by 16, for example? But then where do I set different h and w?

    Thanks.

    opened by danarte 3
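
    The padding idea raised in this issue comes down to rounding each side up to the next multiple of the patch size. The sketch below only does that arithmetic. The helper names `pad_to_multiple` and `patch_grid` are hypothetical and not part of g-mlp-pytorch, and it does not assume that gMLPVision accepts non-square inputs.

```python
def pad_to_multiple(size: int, multiple: int) -> int:
    """Return the smallest value >= size that is divisible by multiple."""
    if size <= 0 or multiple <= 0:
        raise ValueError("size and multiple must be positive")
    # Ceiling division written with integer arithmetic only.
    return -(-size // multiple) * multiple


def patch_grid(height: int, width: int, patch_size: int):
    """Return the padded height, padded width, and resulting patch count."""
    padded_h = pad_to_multiple(height, patch_size)
    padded_w = pad_to_multiple(width, patch_size)
    num_patches = (padded_h // patch_size) * (padded_w // patch_size)
    return padded_h, padded_w, num_patches


# The 4096 x 100 input from the question, with 16 x 16 patches.
padded_h, padded_w, num_patches = patch_grid(4096, 100, 16)
print(padded_h, padded_w, num_patches)
```

    The smallest valid width here is 112 rather than 128, which keeps the patch sequence shorter. Since the spatial projection grows with the square of the sequence length, less padding means fewer parameters.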
  • Parameter count doesn't line up with paper

    Just a note (and correct me if I misunderstood the paper) -

    The parameter count for the Tiny gMLP doesn't line up with the count from the paper for 30 layers, 128 dim and 6 ff_mult. That's probably due to the doubling of parameters here - https://github.com/lucidrains/g-mlp-pytorch/blob/main/g_mlp_pytorch/g_mlp_pytorch.py#L111

    This would need to be halved back to dim_ff, and all 3 lines here would need to halve their respective dims as well - https://github.com/lucidrains/g-mlp-pytorch/blob/main/g_mlp_pytorch/g_mlp_pytorch.py#L64-L66

    With those changes the count is roughly 5.5M params.

    opened by titu1994 2
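
    The arithmetic behind this issue can be checked by hand. The sketch below counts parameters for one block under an assumed layout: a pre-norm, a channel projection from dim to dim_ff, a LayerNorm and a spatial projection over the gated half, and a channel projection back to dim. The layout, the helper names, and the sequence length of 196 (a 224 x 224 image with 16 x 16 patches) are assumptions for illustration, not values taken from the repository.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a dense layer."""
    return in_features * out_features + out_features


def layer_norm_params(features: int) -> int:
    """Scale and shift vectors of a LayerNorm."""
    return 2 * features


def gmlp_block_params(dim: int, dim_ff: int, seq_len: int) -> int:
    """Parameter count of one block under the assumed layout."""
    half = dim_ff // 2  # the gating unit splits the channels in two
    return (
        layer_norm_params(dim)               # norm before the block
        + linear_params(dim, dim_ff)         # channel projection in
        + layer_norm_params(half)            # norm on the gate half
        + linear_params(seq_len, seq_len)    # spatial projection
        + linear_params(half, dim)           # channel projection out
    )


DEPTH, DIM, FF_MULT, SEQ_LEN = 30, 128, 6, 196  # SEQ_LEN is an assumption

halved = DEPTH * gmlp_block_params(DIM, DIM * FF_MULT, SEQ_LEN)
doubled = DEPTH * gmlp_block_params(DIM, 2 * DIM * FF_MULT, SEQ_LEN)
print(f"projection to dim_ff:     {halved:,}")
print(f"projection to 2 * dim_ff: {doubled:,}")
```

    Under these assumptions the blocks alone come to about 5.6M parameters with the narrower projection and about 10.1M with the doubled one, excluding the patch embedding and classification head. The first figure is in the same range as the roughly 5.5M quoted in the issue.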
  • Add Support for Stochastic Depth

    This PR adds support for stochastic depth, which is used in the paper for the vision experiments. I went ahead and added it to gMLP as well for completeness.

    I tried my best to match your style. Let me know if there are any problems, or if you want me to refactor anything.

    opened by mlw214 2
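
    Stochastic depth randomly skips whole layers during training so that each step runs a shallower network. The following is a minimal, framework-free sketch of that selection step. The function name `dropout_layers`, the `prob_survival` parameter, and the rule of always keeping at least one layer are assumptions for illustration and may not match what the PR implements.

```python
import random


def dropout_layers(layers, prob_survival, rng=None):
    """Pick the layers to run for one training step.

    Each layer is kept independently with probability prob_survival.
    At least one layer is always kept so the network never becomes empty.
    """
    layers = list(layers)
    if prob_survival >= 1.0:
        return layers  # nothing is dropped

    if rng is None:
        rng = random.Random()

    kept = [layer for layer in layers if rng.random() < prob_survival]
    if not kept:
        kept = [rng.choice(layers)]  # fall back to a single random layer
    return kept


# Stand-in "layers": plain functions applied in sequence.
layers = [lambda x: x + 1, lambda x: x * 2, lambda x: x - 3]
active = dropout_layers(layers, prob_survival=0.8, rng=random.Random(0))

x = 5
for layer in active:
    x = layer(x)
```

    At evaluation time all layers would normally run, which corresponds to calling the function with a survival probability of 1.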
  • Don't you think this is more legible?

    class SpatialGatingUnit(nn.Module):
        def __init__(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3):
            super().__init__()
            dim_out = dim // 2
            self.causal = causal

            self.norm = nn.LayerNorm(dim_out)
            #self.proj = nn.Conv1d(dim_seq, dim_seq, 1)

            self.dim_seq = dim_seq
            # explicit spatial projection: weight starts at zero, bias at one
            self.w_ = nn.Parameter(torch.zeros(dim_seq, dim_seq), requires_grad=True)   ####
            self.b_ = nn.Parameter(torch.ones(dim_seq), requires_grad=True)  ####

            self.act = act

            init_eps /= dim_seq
            #nn.init.uniform_(self.proj.weight, -init_eps, init_eps)
            #nn.init.constant_(self.proj.bias, 1.)

        def forward(self, x, gate_res = None): # x -> bsz, len, hidden*6
            device, n = x.device, x.shape[1]

            res, gate = x.chunk(2, dim = -1)
            gate = self.norm(gate)

            weight, bias = self.w_, self.b_ # weight -> len, len     bias -> len

            if self.causal:
                # the weight is 2D here, so the mask is applied directly
                weight, bias = weight[:n, :n], bias[:n]
                mask = torch.ones(weight.shape[:2], device = device).triu_(1).bool()
                weight = weight.masked_fill(mask, 0.)

            gate = torch.matmul(weight, gate) + bias[None, :self.dim_seq, None]   # WZ + b

            #gate = F.conv1d(gate, weight, bias)   # WZ + b

            if exists(gate_res):
                gate = gate + gate_res

            return self.act(gate) * res

    opened by ZIZUN 0
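
    The gating step discussed in this issue, multiplying one half of the channels by W Z + b computed over the sequence from the other half, can be traced on a tiny example without any framework. This is a plain-Python sketch that leaves out the LayerNorm and activation. The function name `spatial_gating` is hypothetical, and the zero-weight, unit-bias values mirror the initialisation shown in the issue.

```python
def spatial_gating(x, weight, bias):
    """Apply res * (W @ gate + b) to a sequence of token vectors.

    x      : n token vectors, each of even length d
    weight : n x n matrix mixing information across tokens
    bias   : n values, one per output token
    """
    n = len(x)
    half = len(x[0]) // 2
    res = [row[:half] for row in x]    # first half passes through
    gate = [row[half:] for row in x]   # second half is mixed across tokens

    out = []
    for i in range(n):
        mixed = [
            sum(weight[i][j] * gate[j][k] for j in range(n)) + bias[i]
            for k in range(half)
        ]
        out.append([m * r for m, r in zip(mixed, res[i])])
    return out


tokens = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]

# Zero weights and unit bias: the gate is all ones, so the unit is an identity on res.
identity_out = spatial_gating(tokens, [[0.0, 0.0], [0.0, 0.0]], [1.0, 1.0])

# A lower-triangular weight only lets each token see itself and earlier tokens.
causal_out = spatial_gating(tokens, [[1.0, 0.0], [1.0, 1.0]], [0.0, 0.0])
```

    Masking the upper triangle of the weight, as the causal branch does, is what prevents a token from reading later positions.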
  • Potentially missing the highway path

    Hello,

    Maybe I missed it, but would you mind pointing out where the highway path of the gMLP block is in the code? Based on the paper, there is a highway path (addition) between the input and the output. I couldn't find it in the gMLPBlock code.

    Thank you

    opened by Vincent-Li-9701 1
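
    A skip connection like the one asked about here is often implemented as a wrapper around a block instead of inside the block itself, which can make it easy to miss when reading the block's code. The sketch below shows that pattern in plain Python. The `Residual` class is illustrative only, and whether this repository places the addition in such a wrapper should be checked against its source.

```python
class Residual:
    """Wrap a function so its input is added back to its output."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        # The wrapped block never sees the addition; it happens out here.
        return self.fn(x) + x


def block(x):
    """Stand-in for a gMLP block; any function of x works."""
    return 2 * x


wrapped = Residual(block)
result = wrapped(3)  # block(3) + 3
```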
Owner
Phil Wang
Working with Attention. It's all we need.