A TensorFlow 2.x implementation of Masked Autoencoders Are Scalable Vision Learners

Overview

Masked Autoencoders Are Scalable Vision Learners

Open In Colab

A TensorFlow implementation of Masked Autoencoders Are Scalable Vision Learners [1]. Our implementation of the proposed method is available in mae-pretraining.ipynb notebook. It includes evaluation with linear probing as well. Furthermore, the notebook can be fully executed on Google Colab. Our main objective is to present the core idea of the proposed method in a minimal and readable manner. We have also prepared a blog for getting started with Masked Autoencoder easily.


With just 100 epochs of pre-training and a fairly lightweight and asymmetric Autoencoder architecture we achieve 49.33%% accuracy with linear probing on the CIFAR-10 dataset. Our training logs and encoder weights are released in Weights and Logs. For comparison, we took the encoder architecture and trained it from scratch (refer to regular-classification.ipynb) in a fully supervised manner. This gave us ~76% test top-1 accuracy.

We note that with further hyperparameter tuning and more epochs of pre-training, we can achieve a better performance with linear-probing. Below we present some more results:

Config Masking
proportion
LP
performance
Encoder weights
& logs
Encoder & decoder layers: 3 & 1
Batch size: 256
0.6 44.25% Link
Do 0.75 46.84% Link
Encoder & decoder layers: 6 & 2
Batch size: 256
0.75 48.16% Link
Encoder & decoder layers: 9 & 3
Batch size: 256
Weight deacy: 1e-5
0.75 49.33% Link

LP denotes linear-probing. Config is mostly based on what we define in the hyperparameters section of this notebook: mae-pretraining.ipynb.

Acknowledgements

References

[1] Masked Autoencoders Are Scalable Vision Learners; He et al.; arXiv 2021; https://arxiv.org/abs/2111.06377.

You might also like...
A repository that shares tuning results of trained models generated by TensorFlow / Keras. Post-training quantization (Weight Quantization, Integer Quantization, Full Integer Quantization, Float16 Quantization), Quantization-aware training. TensorFlow Lite. OpenVINO. CoreML. TensorFlow.js. TF-TRT. MediaPipe. ONNX. [.tflite,.h5,.pb,saved_model,tfjs,tftrt,mlmodel,.xml/.bin, .onnx] Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax
Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax

Clockwork VAEs in JAX/Flax Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax, ported

Official implementation of the paper
Official implementation of the paper "AAVAE: Augmentation-AugmentedVariational Autoencoders"

AAVAE Official implementation of the paper "AAVAE: Augmentation-AugmentedVariational Autoencoders" Abstract Recent methods for self-supervised learnin

VIMPAC: Video Pre-Training via Masked Token Prediction and Contrastive Learning

This is a release of our VIMPAC paper to illustrate the implementations. The pretrained checkpoints and scripts will be soon open-sourced in HuggingFace transformers.

EMNLP 2021 - Frustratingly Simple Pretraining Alternatives to Masked Language Modeling

Frustratingly Simple Pretraining Alternatives to Masked Language Modeling This is the official implementation for "Frustratingly Simple Pretraining Al

The official code for PRIMER: Pyramid-based Masked Sentence Pre-training for Multi-document Summarization

PRIMER The official code for PRIMER: Pyramid-based Masked Sentence Pre-training for Multi-document Summarization. PRIMER is a pre-trained model for mu

SimMIM: A Simple Framework for Masked Image Modeling
SimMIM: A Simple Framework for Masked Image Modeling

SimMIM By Zhenda Xie*, Zheng Zhang*, Yue Cao*, Yutong Lin, Jianmin Bao, Zhuliang Yao, Qi Dai and Han Hu*. This repo is the official implementation of

SeMask: Semantically Masked Transformers for Semantic Segmentation.
SeMask: Semantically Masked Transformers for Semantic Segmentation.

SeMask: Semantically Masked Transformers Jitesh Jain, Anukriti Singh, Nikita Orlov, Zilong Huang, Jiachen Li, Steven Walton, Humphrey Shi This repo co

FocusFace: Multi-task Contrastive Learning for Masked Face Recognition
FocusFace: Multi-task Contrastive Learning for Masked Face Recognition

FocusFace This is the official repository of "FocusFace: Multi-task Contrastive Learning for Masked Face Recognition" accepted at IEEE International C

Comments
  • Excellent work (`mae.ipynb`)!

    Excellent work (`mae.ipynb`)!

    @ariG23498 this is fantastic stuff. Super clean, readable, and coherent with the original implementation. A couple of suggestions that would likely make things even better:

    • Since you have already implemented masking visualization utilities how about making them part of the PatchEncoder itself? That way you could let it accept a test image, apply random masking, and plot it just like the way you are doing in the earlier cells. This way I believe the notebook will be cleaner.
    • AdamW (tfa.optimizers.adamw) is a better choice when it comes to training Transformer-based models.
    • Are we taking the loss on the correct component? I remember you mentioning it being dealt with differently.

    After these points are addressed I will take a crack at porting the training loop to TPUs along with other performance monitoring callbacks.

    opened by sayakpaul 7
  • Unshuffle the patches?

    Unshuffle the patches?

    Your code helps me a lot! However, I still have some questions. In the paper, the authors say they unshuffle the full list before applying the deocder. In the MaskedAutoencoder class of your implementation, decoder_inputs = tf.concat([encoder_outputs, masked_embeddings], axis=1)
    no unshuffling is used. I wonder if you can tell me the purpose of doing so? Thanks a lot!

    opened by changtaoli 2
  • Could you also share the weight of the pretrained decoder?

    Could you also share the weight of the pretrained decoder?

    Hi,

    Thanks for your excellent implementation! I found that you have shared the weights of the encoder, but if we want to replicate the reconstruction, the pretrained decoder is still needed. So, could you also share the weight of the pretrained decoder?

    Best Regards, Hongxin

    opened by hongxin001 1
  • Issue with the plotting utility `show_masked_image`

    Issue with the plotting utility `show_masked_image`

    Should be:

    def show_masked_image(self, patches):
            # Utility function that helps visualize maksed images.
            _, unmask_indices = self.get_random_indices()
            unmasked_patches = tf.gather(patches, unmask_indices, axis=1, batch_dims=1)
    
            # Necessary for plotting.
            ids = tf.argsort(unmask_indices)
            sorted_unmask_indices = tf.sort(unmask_indices)
            unmasked_patches = tf.gather(unmasked_patches, ids, batch_dims=1)
    
            # Select a random index for visualization.
            idx = np.random.choice(len(sorted_unmask_indices))
            print(f"Index selected: {idx}.")
    
            n = int(np.sqrt(NUM_PATCHES))
            unmask_index = sorted_unmask_indices[idx]
            unmasked_patch = unmasked_patches[idx]
    
            plt.figure(figsize=(4, 4))
    
            count = 0
            for i in range(NUM_PATCHES):
                ax = plt.subplot(n, n, i + 1)
    
                if count < unmask_index.shape[0] and unmask_index[count].numpy() == i:
                    patch = unmasked_patch[count]
                    patch_img = tf.reshape(patch, (PATCH_SIZE, PATCH_SIZE, 3))
                    plt.imshow(patch_img)
                    plt.axis("off")
                    count = count + 1
                else:
                    patch_img = tf.zeros((PATCH_SIZE, PATCH_SIZE, 3))
                    plt.imshow(patch_img)
                    plt.axis("off")
            plt.show()
    
            # Return the random index to validate the image outside the method.
            return idx
    
    opened by ariG23498 1
Releases(v1.0.0)
Owner
Aritra Roy Gosthipaty
Learning with a learning rate of 1e-10.
Aritra Roy Gosthipaty
商品推荐系统

商品top50推荐系统 问题建模 本项目的数据集给出了15万左右的用户以及12万左右的商品, 以及对应的经过脱敏处理的用户特征和经过预处理的商品特征,旨在为用户推荐50个其可能购买的商品。 推荐系统架构方案 本项目采用传统的召回+排序的方案。

107 Dec 29, 2022
PyTorch implementations of algorithms for density estimation

pytorch-flows A PyTorch implementations of Masked Autoregressive Flow and some other invertible transformations from Glow: Generative Flow with Invert

Ilya Kostrikov 546 Dec 05, 2022
Pytorch implementation of MaskGIT: Masked Generative Image Transformer

Pytorch implementation of MaskGIT: Masked Generative Image Transformer

Dominic Rampas 247 Dec 16, 2022
African language Speech Recognition - Speech-to-Text

Swahili-Speech-To-Text Table of Contents Swahili-Speech-To-Text Overview Scenario Approach Project Structure data: models: notebooks: scripts tests: l

2 Jan 05, 2023
MPViT:Multi-Path Vision Transformer for Dense Prediction

MPViT : Multi-Path Vision Transformer for Dense Prediction This repository inlcu

Youngwan Lee 272 Dec 20, 2022
[CVPR 2022] Deep Equilibrium Optical Flow Estimation

Deep Equilibrium Optical Flow Estimation This is the official repo for the paper Deep Equilibrium Optical Flow Estimation (CVPR 2022), by Shaojie Bai*

CMU Locus Lab 136 Dec 18, 2022
A flexible and extensible framework for gait recognition.

A flexible and extensible framework for gait recognition. You can focus on designing your own models and comparing with state-of-the-arts easily with the help of OpenGait.

Shiqi Yu 335 Dec 22, 2022
Various operations like path tracking, counting, etc by using yolov5

Object-tracing-with-YOLOv5 Various operations like path tracking, counting, etc by using yolov5

Pawan Valluri 5 Nov 28, 2022
Toontown: Galaxy, a new Toontown game based on Disney's Toontown Online

Toontown: Galaxy The official archive repo for Toontown: Galaxy, a new Toontown

1 Feb 15, 2022
CCNet: Criss-Cross Attention for Semantic Segmentation (TPAMI 2020 & ICCV 2019).

CCNet: Criss-Cross Attention for Semantic Segmentation Paper Links: Our most recent TPAMI version with improvements and extensions (Earlier ICCV versi

Zilong Huang 1.3k Dec 27, 2022
Code for paper: Group-CAM: Group Score-Weighted Visual Explanations for Deep Convolutional Networks

Group-CAM By Zhang, Qinglong and Rao, Lu and Yang, Yubin [State Key Laboratory for Novel Software Technology at Nanjing University] This repo is the o

zhql 98 Nov 16, 2022
This repository holds code and data for our PETS'22 article 'From "Onion Not Found" to Guard Discovery'.

From "Onion Not Found" to Guard Discovery (PETS'22) This repository holds the code and data for our PETS'22 paper titled 'From "Onion Not Found" to Gu

Lennart Oldenburg 3 May 04, 2022
Enhancing Aspect-Based Sentiment Analysis with Supervised Contrastive Learning.

Enhancing Aspect-Based Sentiment Analysis with Supervised Contrastive Learning. Enhancing Aspect-Based Sentiment Analysis with Supervised Contrastive

<a href=[email protected](SZ)"> 7 Dec 16, 2021
Optimize Trading Strategies Using Freqtrade

Optimize trading strategy using Freqtrade Short demo on building, testing and optimizing a trading strategy using Freqtrade. The DevBootstrap YouTube

DevBootstrap 139 Jan 01, 2023
Prometheus Exporter for data scraped from datenplattform.darmstadt.de

darmstadt-opendata-exporter Scrapes data from https://datenplattform.darmstadt.de and presents it in the Prometheus Exposition format. Pull requests w

Martin Weinelt 2 Apr 12, 2022
Boosting Monocular Depth Estimation Models to High-Resolution via Content-Adaptive Multi-Resolution Merging

Boosting Monocular Depth Estimation Models to High-Resolution via Content-Adaptive Multi-Resolution Merging This repository contains an implementation

Computational Photography Lab @ SFU 1.1k Jan 02, 2023
chen2020iros: Learning an Overlap-based Observation Model for 3D LiDAR Localization.

Overlap-based 3D LiDAR Monte Carlo Localization This repo contains the code for our IROS2020 paper: Learning an Overlap-based Observation Model for 3D

Photogrammetry & Robotics Bonn 219 Dec 15, 2022
Learning Versatile Neural Architectures by Propagating Network Codes

Learning Versatile Neural Architectures by Propagating Network Codes Mingyu Ding, Yuqi Huo, Haoyu Lu, Linjie Yang, Zhe Wang, Zhiwu Lu, Jingdong Wang,

Mingyu Ding 36 Dec 06, 2022
Train an imgs.ai model on your own dataset

imgs.ai is a fast, dataset-agnostic, deep visual search engine for digital art history based on neural network embeddings.

Fabian Offert 5 Dec 21, 2021
Python wrappers to the C++ library SymEngine, a fast C++ symbolic manipulation library.

SymEngine Python Wrappers Python wrappers to the C++ library SymEngine, a fast C++ symbolic manipulation library. Installation Pip See License section

136 Dec 28, 2022