A TensorFlow 2.x implementation of Masked Autoencoders Are Scalable Vision Learners

Overview

Masked Autoencoders Are Scalable Vision Learners

A TensorFlow implementation of Masked Autoencoders Are Scalable Vision Learners [1]. Our implementation of the proposed method is available in the mae-pretraining.ipynb notebook, which also includes evaluation with linear probing. The notebook can be executed end to end on Google Colab. Our main objective is to present the core idea of the proposed method in a minimal and readable manner. We have also prepared a blog post to help you get started with Masked Autoencoders.


With just 100 epochs of pre-training and a fairly lightweight, asymmetric autoencoder architecture, we achieve 49.33% accuracy with linear probing on the CIFAR-10 dataset. Our training logs and encoder weights are released in Weights and Logs. For comparison, we took the encoder architecture and trained it from scratch (refer to regular-classification.ipynb) in a fully supervised manner. This gave us ~76% test top-1 accuracy.
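
For readers unfamiliar with linear probing, the idea is to freeze the pre-trained encoder and train only a linear classifier on its output features. The following is a minimal NumPy sketch of that idea, not the code from the notebook: the "encoder" is a hypothetical fixed random projection, and the data, sizes, and learning rate are toy assumptions chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for a pre-trained encoder: a fixed projection
# that is never updated during probing.
W_enc = rng.normal(size=(8, 16))

def frozen_encoder(x):
    return np.tanh(x @ W_enc)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(probs, y):
    return -np.mean(np.log(probs[np.arange(len(y)), y] + 1e-12))

# Toy data (assumed): 200 samples, 8 input features, 3 classes.
x = rng.normal(size=(200, 8))
y = np.argmax(x @ rng.normal(size=(8, 3)), axis=1)

# Features are computed once because the encoder receives no gradient updates.
features = frozen_encoder(x)

# The linear classifier holds the only trainable parameters.
W = np.zeros((16, 3))
b = np.zeros(3)

W_enc_before = W_enc.copy()
initial_loss = cross_entropy(softmax(features @ W + b), y)

learning_rate = 0.1
for _ in range(300):
    probs = softmax(features @ W + b)
    # Gradient of the mean cross-entropy with respect to the logits.
    grad_logits = probs.copy()
    grad_logits[np.arange(len(y)), y] -= 1.0
    grad_logits /= len(y)
    W -= learning_rate * (features.T @ grad_logits)
    b -= learning_rate * grad_logits.sum(axis=0)

final_loss = cross_entropy(softmax(features @ W + b), y)
accuracy = np.mean(np.argmax(features @ W + b, axis=1) == y)
print(f"loss: {initial_loss:.3f} -> {final_loss:.3f}, train accuracy: {accuracy:.2f}")
```

Because only `W` and `b` change, the resulting accuracy reflects how linearly separable the frozen features are, which is why linear probing is used as a measure of representation quality.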

We note that with further hyperparameter tuning and more epochs of pre-training, better linear-probing performance should be achievable. Below we present some more results:

| Config | Masking proportion | LP performance | Encoder weights & logs |
| --- | --- | --- | --- |
| Encoder & decoder layers: 3 & 1; Batch size: 256 | 0.6 | 44.25% | Link |
| Same as above | 0.75 | 46.84% | Link |
| Encoder & decoder layers: 6 & 2; Batch size: 256 | 0.75 | 48.16% | Link |
| Encoder & decoder layers: 9 & 3; Batch size: 256; Weight decay: 1e-5 | 0.75 | 49.33% | Link |

LP denotes linear-probing. Config is mostly based on what we define in the hyperparameters section of this notebook: mae-pretraining.ipynb.
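
To make the masking proportion in the results above concrete, here is a small NumPy sketch of per-image random masking. It is an illustration under stated assumptions rather than the notebook's TensorFlow code: the function name `random_masking`, the patch count (64), and the patch dimension (108) are hypothetical values picked for the example.

```python
import numpy as np

def random_masking(patches, mask_proportion=0.75, seed=0):
    """Split patches into visible and masked sets, independently per image.

    patches: array of shape (batch, num_patches, patch_dim).
    Returns (visible_patches, mask_indices, unmask_indices).
    """
    rng = np.random.default_rng(seed)
    batch, num_patches, _ = patches.shape
    num_mask = int(mask_proportion * num_patches)

    # Argsort of uniform noise yields an independent random permutation per image.
    shuffled = np.argsort(rng.uniform(size=(batch, num_patches)), axis=-1)
    mask_indices = shuffled[:, :num_mask]
    unmask_indices = shuffled[:, num_mask:]

    # Gather only the visible patches; an encoder would process just these.
    visible = np.take_along_axis(patches, unmask_indices[:, :, None], axis=1)
    return visible, mask_indices, unmask_indices

# Assumed sizes: 2 images, 64 patches each, 108 values per patch.
patches = np.random.default_rng(1).normal(size=(2, 64, 108))
visible, mask_idx, unmask_idx = random_masking(patches, mask_proportion=0.75)
print(visible.shape)  # (2, 16, 108)
```

With a proportion of 0.75 and 64 patches, 48 patches are hidden and only 16 reach the encoder, which is what keeps the encoder pass comparatively cheap.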

Acknowledgements

References

[1] Masked Autoencoders Are Scalable Vision Learners; He et al.; arXiv 2021; https://arxiv.org/abs/2111.06377.

Comments
  • Excellent work (`mae.ipynb`)!

    @ariG23498 this is fantastic stuff. Super clean, readable, and coherent with the original implementation. A couple of suggestions that would likely make things even better:

    • Since you have already implemented masking visualization utilities how about making them part of the PatchEncoder itself? That way you could let it accept a test image, apply random masking, and plot it just like the way you are doing in the earlier cells. This way I believe the notebook will be cleaner.
    • AdamW (tfa.optimizers.AdamW) is a better choice when it comes to training Transformer-based models.
    • Are we taking the loss on the correct component? I remember you mentioning it being dealt with differently.

    After these points are addressed I will take a crack at porting the training loop to TPUs along with other performance monitoring callbacks.

    opened by sayakpaul 7
  • Unshuffle the patches?

    Your code helps me a lot! However, I still have some questions. In the paper, the authors say they unshuffle the full list before applying the decoder. In the MaskedAutoencoder class of your implementation, the decoder inputs are built with `decoder_inputs = tf.concat([encoder_outputs, masked_embeddings], axis=1)` and no unshuffling is used. Could you explain the reason for this? Thanks a lot!

    opened by changtaoli 2
  • Could you also share the weights of the pretrained decoder?

    Hi,

    Thanks for your excellent implementation! I found that you have shared the weights of the encoder, but the pretrained decoder is also needed to replicate the reconstruction. Could you share the weights of the pretrained decoder as well?

    Best Regards, Hongxin

    opened by hongxin001 1
  • Issue with the plotting utility `show_masked_image`

    Should be:

    def show_masked_image(self, patches):
            # Utility function that helps visualize masked images.
            _, unmask_indices = self.get_random_indices()
            unmasked_patches = tf.gather(patches, unmask_indices, axis=1, batch_dims=1)
    
            # Necessary for plotting.
            ids = tf.argsort(unmask_indices)
            sorted_unmask_indices = tf.sort(unmask_indices)
            unmasked_patches = tf.gather(unmasked_patches, ids, batch_dims=1)
    
            # Select a random index for visualization.
            idx = np.random.choice(len(sorted_unmask_indices))
            print(f"Index selected: {idx}.")
    
            n = int(np.sqrt(NUM_PATCHES))
            unmask_index = sorted_unmask_indices[idx]
            unmasked_patch = unmasked_patches[idx]
    
            plt.figure(figsize=(4, 4))
    
            count = 0
            for i in range(NUM_PATCHES):
                ax = plt.subplot(n, n, i + 1)
    
                if count < unmask_index.shape[0] and unmask_index[count].numpy() == i:
                    patch = unmasked_patch[count]
                    patch_img = tf.reshape(patch, (PATCH_SIZE, PATCH_SIZE, 3))
                    plt.imshow(patch_img)
                    plt.axis("off")
                    count = count + 1
                else:
                    patch_img = tf.zeros((PATCH_SIZE, PATCH_SIZE, 3))
                    plt.imshow(patch_img)
                    plt.axis("off")
            plt.show()
    
            # Return the random index to validate the image outside the method.
            return idx
    
    opened by ariG23498 1
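
The unshuffling step asked about in the comments above can be illustrated in isolation. The sketch below uses NumPy on a single unbatched sequence and tags each token with its original position so the ordering is easy to verify. All sizes and variable names other than `encoder_outputs`, `masked_embeddings`, and `decoder_inputs` are assumptions made for the example, and the sketch says nothing about why the notebook omits this step.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, num_mask, dim = 8, 6, 4  # assumed toy sizes

# Random permutation: the first `num_mask` entries are masked, the rest visible.
shuffled = rng.permutation(num_patches)
mask_indices, unmask_indices = shuffled[:num_mask], shuffled[num_mask:]

# Each token's values equal its original position, purely for checking order.
tokens = np.repeat(np.arange(num_patches)[:, None], dim, axis=1).astype(float)
encoder_outputs = tokens[unmask_indices]   # stand-in for encoded visible patches
masked_embeddings = tokens[mask_indices]   # stand-in for mask tokens

# Concatenation as quoted in the question: visible tokens first, then masked.
decoder_inputs = np.concatenate([encoder_outputs, masked_embeddings], axis=0)

# Unshuffle: invert the permutation that produced the concatenated order.
concat_order = np.concatenate([unmask_indices, mask_indices])
restore = np.argsort(concat_order)
unshuffled = decoder_inputs[restore]

print(unshuffled[:, 0])  # original positions, back in ascending order
```

Here `restore[k]` is the row of `decoder_inputs` that holds the token originally at position `k`, so indexing with `restore` puts every token back at its original position.
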
Releases(v1.0.0)
Owner
Aritra Roy Gosthipaty