Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch

Overview

Segformer - Pytorch

Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch.

Install

$ pip install segformer-pytorch

Usage

For example, MiT-B0

import torch
from segformer_pytorch import Segformer

model = Segformer(
    patch_size = 4,                 # patch size
    dims = (32, 64, 160, 256),      # dimensions of each stage
    heads = (1, 2, 5, 8),           # heads of each stage
    ff_expansion = (8, 8, 4, 4),    # feedforward expansion factor of each stage
    reduction_ratio = (8, 4, 2, 1), # reduction ratio of each stage for efficient attention
    num_layers = 2,                 # num layers of each stage
    decoder_dim = 256,              # decoder dimension
    num_classes = 4                 # number of segmentation classes
)

x = torch.randn(1, 3, 256, 256)
pred = model(x) # (1, 4, 64, 64)  # output is (H/4, W/4) map of the number of segmentation classes

Make sure the keywords are at most a tuple of 4, as this repository is hard-coded to give the MiT 4 stages as done in the paper.

Citations

@misc{xie2021segformer,
    title   = {SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers}, 
    author  = {Enze Xie and Wenhai Wang and Zhiding Yu and Anima Anandkumar and Jose M. Alvarez and Ping Luo},
    year    = {2021},
    eprint  = {2105.15203},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
You might also like...
Implementation of gMLP, an all-MLP replacement for Transformers, in Pytorch
Implementation of gMLP, an all-MLP replacement for Transformers, in Pytorch

Implementation of gMLP, an all-MLP replacement for Transformers, in Pytorch

Pytorch implementation of MLP-Mixer with loading pre-trained models.

MLP-Mixer-Pytorch PyTorch implementation of MLP-Mixer: An all-MLP Architecture for Vision with the function of loading official ImageNet pre-trained p

PyTorch code for our paper "Attention in Attention Network for Image Super-Resolution"

Under construction... Attention in Attention Network for Image Super-Resolution (A2N) This repository is an PyTorch implementation of the paper "Atten

MLP-Like Vision Permutator for Visual Recognition (PyTorch)
MLP-Like Vision Permutator for Visual Recognition (PyTorch)

Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition (arxiv) This is a Pytorch implementation of our paper. We present Vision

Pytorch implementation of
Pytorch implementation of "Attention-Based Recurrent Neural Network Models for Joint Intent Detection and Slot Filling"

RNN-for-Joint-NLU Pytorch implementation of "Attention-Based Recurrent Neural Network Models for Joint Intent Detection and Slot Filling"

Unofficial Implementation of MLP-Mixer in TensorFlow
Unofficial Implementation of MLP-Mixer in TensorFlow

mlp-mixer-tf Unofficial Implementation of MLP-Mixer [abs, pdf] in TensorFlow. Note: This project may have some bugs in it. I'm still learning how to i

Implementation of
Implementation of "A MLP-like Architecture for Dense Prediction"

A MLP-like Architecture for Dense Prediction (arXiv) Updates (22/07/2021) Initial release. Model Zoo We provide CycleMLP models pretrained on ImageNet

Unofficial Implementation of MLP-Mixer, Image Classification Model
Unofficial Implementation of MLP-Mixer, Image Classification Model

MLP-Mixer Unoffical Implementation of MLP-Mixer, easy to use with terminal. Train and test easly. https://arxiv.org/abs/2105.01601 MLP-Mixer is an arc

MLP-Numpy - A simple modular implementation of Multi Layer Perceptron in pure Numpy.

MLP-Numpy A simple modular implementation of Multi Layer Perceptron in pure Numpy. I used the Iris dataset from scikit-learn library for the experimen

Comments
  • Something is wrong with your implementation.

    Something is wrong with your implementation.

    Hello!

    First of all, I really like the repo. The implementation is clean and so much easier to understand than the official repo. But after doing some digging, I realized that the number of parameters and layers (especially conv2d) is quite different from the official implementation. This is the case for all variants I have tested (B0 and B5).

    Check out the README in my repo here, and you'll see what I mean. I also included images of the execution graphs of the two different implementations in the 'src' folder, which could help to debug.

    I don't quite have time to dig into the source of the problem, but I just thought I'd share my observations with you.

    opened by camlaedtke 0
  • Models weights + model output HxW

    Models weights + model output HxW

    Hi,

    Could you please add the models weights so we can start training from them?

    Also, why you choose to train models with an output of size (H/4,W/4) and not the original (HxW) size?

    Great job for the paper, very interesting :)

    opened by isega24 2
  • The model configurations for all the SegFormer B0 ~ B5

    The model configurations for all the SegFormer B0 ~ B5

    Hello How are you? Thanks for contributing to this project. Is the model configuration in README MiT-B0 correctly? That's because the total number of params for the model is 36M. Could u provide all the model configurations for SegFormer B0 ~ B5?

    opened by rose-jinyang 5
  • a question about kv reshape in Efficient Self-Attention

    a question about kv reshape in Efficient Self-Attention

    Thanks for sharing your work, your code is so elegant, and inspired me a lot. Here is a question about the implementation of Efficient Self-Attention

    It seems you use a "mean op" to reshape k,v. and the official implementation uses a (learnable) linear mapping to reshape k,v

    may I ask, whether this difference significantly matters in your experiment ?

    in your code:

    k, v = map(lambda t: reduce(t, 'b c (h r1) (w r2) -> b c h w', 'mean', r1 = r, r2 = r), (k, v))
    

    the original implementation uses:

    self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
    self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
    self.norm = nn.LayerNorm(dim)
    
    x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
    x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
    x_ = self.norm(x_)
    kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    k, v = kv[0], kv[1]
    
    opened by masszhou 1
Releases(0.0.6)
Owner
Phil Wang
Working with Attention
Phil Wang
A repository for benchmarking neural vocoders by their quality and speed.

License The majority of VocBench is licensed under CC-BY-NC, however portions of the project are available under separate license terms: Wavenet, Para

Meta Research 177 Dec 12, 2022
On Nonlinear Latent Transformations for GAN-based Image Editing - PyTorch implementation

On Nonlinear Latent Transformations for GAN-based Image Editing - PyTorch implementation On Nonlinear Latent Transformations for GAN-based Image Editi

Valentin Khrulkov 22 Oct 24, 2022
An algorithmic trading bot that learns and adapts to new data and evolving markets using Financial Python Programming and Machine Learning.

ALgorithmic_Trading_with_ML An algorithmic trading bot that learns and adapts to new data and evolving markets using Financial Python Programming and

1 Mar 14, 2022
Code accompanying our paper Feature Learning in Infinite-Width Neural Networks

Empirical Experiments in "Feature Learning in Infinite-width Neural Networks" This repo contains code to replicate our experiments (Word2Vec, MAML) in

Edward Hu 37 Dec 14, 2022
D-NeRF: Neural Radiance Fields for Dynamic Scenes

D-NeRF: Neural Radiance Fields for Dynamic Scenes [Project] [Paper] D-NeRF is a method for synthesizing novel views, at an arbitrary point in time, of

Albert Pumarola 291 Jan 02, 2023
An efficient PyTorch implementation of the evaluation metrics in recommender systems.

recsys_metrics An efficient PyTorch implementation of the evaluation metrics in recommender systems. Overview • Installation • How to use • Benchmark

Xingdong Zuo 12 Dec 02, 2022
Asynchronous Advantage Actor-Critic in PyTorch

Asynchronous Advantage Actor-Critic in PyTorch This is PyTorch implementation of A3C as described in Asynchronous Methods for Deep Reinforcement Learn

Reiji Hatsugai 38 Dec 12, 2022
This is a template for the Non-autoregressive Deep Learning-Based TTS model (in PyTorch).

Non-autoregressive Deep Learning-Based TTS Template This is a template for the Non-autoregressive TTS model. It contains Data Preprocessing Pipeline D

Keon Lee 13 Dec 05, 2022
Repository for "Exploring Sparsity in Image Super-Resolution for Efficient Inference", CVPR 2021

SMSR Reposity for "Exploring Sparsity in Image Super-Resolution for Efficient Inference" [arXiv] Highlights Locate and skip redundant computation in S

Longguang Wang 225 Dec 26, 2022
Reaction SMILES-AA mapping via language modelling

rxn-aa-mapper Reactions SMILES-AA sequence mapping setup conda env create -f conda.yml conda activate rxn_aa_mapper In the following we consider on ex

16 Dec 13, 2022
Model Zoo of BDD100K Dataset

Model Zoo of BDD100K Dataset

ETH VIS Group 200 Dec 27, 2022
a simple, efficient, and intuitive text editor

Oxygen beta a simple, efficient, and intuitive text editor Overview oxygen is a simple, efficient, and intuitive text editor designed as more featured

Aarush Gupta 1 Feb 23, 2022
ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation

ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation This repository contains the source code of our paper, ESPNet (acc

Sachin Mehta 515 Dec 13, 2022
PyTorch implementation of CDistNet: Perceiving Multi-Domain Character Distance for Robust Text Recognition

PyTorch implementation of CDistNet: Perceiving Multi-Domain Character Distance for Robust Text Recognition The unofficial code of CDistNet. Now, we ha

25 Jul 20, 2022
PyTorch code for SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised DA

PyTorch Code for SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation Viraj Prabhu, Shivam Khare, Deeks

Viraj Prabhu 46 Dec 24, 2022
Neural Message Passing for Computer Vision

Neural Message Passing for Quantum Chemistry Implementation of different models of Neural Networks on graphs as explained in the article proposed by G

Pau Riba 310 Nov 07, 2022
Incorporating Transformer and LSTM to Kalman Filter with EM algorithm

Deep learning based state estimation: incorporating Transformer and LSTM to Kalman Filter with EM algorithm Overview Kalman Filter requires the true p

zshicode 57 Dec 27, 2022
A visualization tool to show a TensorFlow's graph like TensorBoard

tfgraphviz tfgraphviz is a module to visualize a TensorFlow's data flow graph like TensorBoard using Graphviz. tfgraphviz enables to provide a visuali

44 Nov 09, 2022
PyTorch implementation of the Pose Residual Network (PRN)

Pose Residual Network This repository contains a PyTorch implementation of the Pose Residual Network (PRN) presented in our ECCV 2018 paper: Muhammed

Salih Karagoz 289 Nov 28, 2022
Data pipelines for both TensorFlow and PyTorch!

rapidnlp-datasets Data pipelines for both TensorFlow and PyTorch ! If you want to load public datasets, try: tensorflow/datasets huggingface/datasets

1 Dec 08, 2021