MobileViT

Unofficial PyTorch implementation of MobileViT, based on the paper "MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer" (Mehta & Rastegari, 2021).


Table of Contents

  • Model Architecture
  • Usage
  • Citation
Model Architecture

[Figure: MobileViT architecture]

Usage

Training

```
python main.py

optional arguments:
  -h, --help            show this help message and exit
  --gpu_device GPU_DEVICE
                        Select specific GPU to run the model
  --batch-size N        Input batch size for training (default: 64)
  --epochs N            Number of epochs to train (default: 20)
  --num-class N         Number of classes to classify (default: 10)
  --lr LR               Learning rate (default: 0.01)
  --weight-decay WD     Weight decay (default: 1e-5)
  --model-path PATH     Path to save the model
```
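
For example, to train on a specific GPU with a larger batch size (the flags come from the help text above; the values and checkpoint path are illustrative):

```
python main.py --gpu_device 0 --batch-size 128 --epochs 20 --lr 0.01 --weight-decay 1e-5 --model-path ./checkpoints/mobilevit_s.pth
```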

Citation

```bibtex
@InProceedings{Sachin2021,
  title     = {MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
  author    = {Sachin Mehta and Mohammad Rastegari},
  booktitle = {},
  year      = {2021}
}
```

If this implementation has any problems, please let me know. Thank you.

Comments
  • Training settings


    I really appreciate your efforts in implementing this model in PyTorch. Here, I have one concern about the training settings. If what I understand is correct, you trained the model for fewer than 5 epochs.

    In addition, the hyper-parameters you adopted are different from those in the original article. For instance, in the original manuscript, the authors train MobileViT using the AdamW optimizer, label-smoothing cross-entropy, and a multi-scale sampler. The training phase also has a warmup stage.

    I also found that the classification accuracy reported here is much lower than that of the original version.

    I conjecture that the gap between accuracies is caused by the different training settings.

    opened by hkzhang91 6
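
    For reference, a minimal sketch of the training recipe described in the comment above (AdamW, label-smoothing cross-entropy, and a warmup stage), built from standard PyTorch APIs. The hyperparameter values are illustrative assumptions rather than the paper's or this repo's exact settings, and the multi-scale sampler is omitted:

    ```python
    import torch
    import torch.nn as nn
    from models import MobileViT_S  # this repo's model factory

    model = MobileViT_S()

    # AdamW optimizer, as the paper uses (values illustrative)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3, weight_decay=0.01)

    # Label-smoothing cross-entropy (built into PyTorch >= 1.10)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # Linear warmup for the first 5 epochs, then cosine annealing;
    # call scheduler.step() once per epoch during training
    warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=5)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=295)
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup, cosine], milestones=[5]
    )
    ```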
  • Loading pretrained weights failed


    ```python
    import torch
    import models

    model = models.MobileViT_S()
    PATH = "./MobileVit-S.pth.tar"
    weights = torch.load(PATH, map_location=lambda storage, loc: storage)
    model.load_state_dict(weights['state_dict'])
    model.eval()
    torch.save(model, './model.pt')
    ```
    
    • I tried to load the pretrained weights to test a demo, but the network structure does not seem to match the weights. Is there any solution?


    opened by hererookie 2
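
    A likely cause is that the checkpoint was saved from a model wrapped in nn.DataParallel, so every key in its state_dict carries a module. prefix; the load_mobilevit_weights snippet in the last comment below strips that prefix before calling load_state_dict.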
  • Model training hyperparameters


    A problem has been bothering me: the learning rate, optimizer, batch size, L2 regularization, label smoothing, and number of epochs are inconsistent with the paper. How should I modify the code?

    opened by Agino-ltp 1
  • Have you tested MobileViT on CIFAR-10?


    Thanks for your wonderful work!

    I plan to try MobileViT on small datasets, such as MNIST, and I will need to adjust the network structure. Before starting this work, I want to know whether MobileViT performs better than other networks on small datasets.

    I noticed "get_cifar10_dataset" in utils.py. Have you tested MobileViT on CIFAR-10? If you have, could you please share the accuracy and inference-time results?

    opened by Jerryme-xxm 1
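
    For anyone attempting such a test, here is a minimal torchvision sketch for loading CIFAR-10. This is generic torchvision usage, not the repo's get_cifar10_dataset (whose signature isn't shown here), and the resize to 256x256 is an assumption to match the input resolution MobileViT was designed for:

    ```python
    import torch
    from torchvision import datasets, transforms

    # Upsample CIFAR-10's 32x32 images to the 256x256 resolution
    # that MobileViT targets (assumption; adjust to your model config)
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
    ])

    train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
    ```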
  • Issues when loading MobileViT_S()


    I wanted to load the MobileViT_S() model and use the pretrained weights, but I got some errors in my code. To make it easier and to help others, I will share my solution (in case there is someone who is a beginner like me):

    ```python
    def load_mobilevit_weights(model_path):
        # Create an instance of the MobileViT model
        net = MobileViT_S()

        # Load the PyTorch state_dict from the checkpoint
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']

        # The layer names in the checkpoint carry a 'module.' prefix, so rename
        # the keys to match the MobileViT model architecture
        for key in list(state_dict.keys()):
            state_dict[key.replace('module.', '')] = state_dict.pop(key)

        # Once the keys are fixed, load the parameters into MobileViT
        net.load_state_dict(state_dict)

        return net

    net = load_mobilevit_weights("MobileViT_S_model_best.pth.tar")
    ```
    
    opened by Sehaba95 4
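
    Continuing the snippet above, a quick sanity check that the loaded model runs. This assumes the standard 256x256 input resolution from the paper; the input is random, so only the output shape is meaningful:

    ```python
    import torch

    net = load_mobilevit_weights("MobileViT_S_model_best.pth.tar")
    net.eval()

    # Forward a dummy batch of one 256x256 RGB image
    with torch.no_grad():
        out = net(torch.randn(1, 3, 256, 256))

    print(out.shape)  # expected: torch.Size([1, 1000]) if these are ImageNet-1k weights
    ```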
Releases

  • weight

Owner

Hong-Jia Chen
Master's student at National Chung Cheng University, Taiwan. Interested in Deep Learning and Computer Vision.