Minimal implementation of PAWS (https://arxiv.org/abs/2104.13963) in TensorFlow.

Overview

PAWS-TF 🐾

Implementation of Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples (PAWS) in TensorFlow (2.4.1).

PAWS introduces a simple way to combine a very small fraction of labeled data with a comparatively larger corpus of unlabeled data during pre-training. With its approach, it sets the state-of-the-art in semi-supervised learning (as of May 2021) beating methods like SimCLRV2, Meta Pseudo Labels that too with fewer parameters and a smaller pre-training schedule. For details, I recommend checking out the original paper as well as this blog post by the authors.

This repository implements and includes all the major bits proposed in PAWS in TensorFlow. The only major difference is that the pre-training and subsequent fine-tuning weren't run for the original number of epochs (600 and 30 respectively) to save compute. I have reused the utility components for PAWS loss from the original implementation.

Dataset ⌗

The current code works with CIFAR10 and uses 4000 labeled samples (8%) during pre-training (along with the unlabeled samples).

Features

  • Multi-crop augmentation strategy (originally introduced in SwAV)
  • Class stratified sampler (common in few-shot classification problems)
  • WarmUpCosine learning rate schedule (which is typical for self-supervised and semi-supervised pre-training)
  • LARS optimizer (comes from TensorFlow Model Garden)

The trunk portion (all, except the last classification layer) of a WideResNet-28-2 is used inside the encoder for CIFAR10. All the experimental configurations were followed from the Appendix C of the paper.

Setup and code structure 💻

A GCP VM (n1-standard-8) with a single V100 GPU was used for executing the code.

  • paws_train.py runs the pre-training as introduced in PAWS.
  • fine_tune.py runs the fine-tuning part as suggested in Appendix C. Note that this is only required for CIFAR10.
  • nn_eval.py runs the soft nearest neighbor classification on CIFAR10 test set.

Pre-training and fine-tuning total take 1.4 hours to complete. All the logs are available in misc/logs.txt. Additionally, the indices that were used to sample the labeled examples from the CIFAR10 training set are available here.

Results 📊

Pre-training

PAWS minimizes the cross-entropy loss (as well as maximizes mean-entropy) during pre-training. This is what the training plot indicates too:

To evaluate the effectivity of the pre-training, PAWS performs soft nearest neighbor classification to report the top-1 accuracy score on a given test set.

Top-1 Accuracy

This repository gets to 73.46% top-1 accuracy on the CIFAR10 test set. Again, note that I only pre-trained for 50 epochs (as opposed to 600) and fine-tuned for 10 epochs (as opposed to 30). With the original schedule this score should be around 96.0%.

In the following PCA projection plot, we see that the embeddings of images (computed after fine-tuning) of PAWS are starting to be well separated:

Notebooks 📘

There are two Colab Notebooks:

Misc ⺟

  • Model weights are available here for reproducibility.
  • With mixed-precision training, the performance can further be improved. I am open to accepting contributions that would implement mixed-precision training in the current code.

Acknowledgements

  • Huge amount of thanks to Mahmoud Assran (first author of PAWS) for patiently resolving my doubts.
  • ML-GDE program for providing GCP credit support.

Paper Citation

@misc{assran2021semisupervised,
      title={Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples}, 
      author={Mahmoud Assran and Mathilde Caron and Ishan Misra and Piotr Bojanowski and Armand Joulin and Nicolas Ballas and Michael Rabbat},
      year={2021},
      eprint={2104.13963},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
You might also like...
Unofficial implementation of Alias-Free Generative Adversarial Networks. (https://arxiv.org/abs/2106.12423) in PyTorch
Unofficial implementation of Alias-Free Generative Adversarial Networks. (https://arxiv.org/abs/2106.12423) in PyTorch

alias-free-gan-pytorch Unofficial implementation of Alias-Free Generative Adversarial Networks. (https://arxiv.org/abs/2106.12423) This implementation

Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286
Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286

Pytorch-DPPO Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286 Using PPO with clip loss (from https

PyTorch implementation of Asymmetric Siamese (https://arxiv.org/abs/2204.00613)
PyTorch implementation of Asymmetric Siamese (https://arxiv.org/abs/2204.00613)

Asym-Siam: On the Importance of Asymmetry for Siamese Representation Learning This is a PyTorch implementation of the Asym-Siam paper, CVPR 2022: @inp

This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).
This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).

Predicting Patient Outcomes with Graph Representation Learning This repository contains the code used for Predicting Patient Outcomes with Graph Repre

https://arxiv.org/abs/2102.11005
https://arxiv.org/abs/2102.11005

LogME LogME: Practical Assessment of Pre-trained Models for Transfer Learning How to use Just feed the features f and labels y to the function, and yo

Supplementary code for the paper
Supplementary code for the paper "Meta-Solver for Neural Ordinary Differential Equations" https://arxiv.org/abs/2103.08561

Meta-Solver for Neural Ordinary Differential Equations Towards robust neural ODEs using parametrized solvers. Main idea Each Runge-Kutta (RK) solver w

Code for paper "A Critical Assessment of State-of-the-Art in Entity Alignment" (https://arxiv.org/abs/2010.16314)

A Critical Assessment of State-of-the-Art in Entity Alignment This repository contains the source code for the paper A Critical Assessment of State-of

Code for the paper: Learning Adversarially Robust Representations via Worst-Case Mutual Information Maximization (https://arxiv.org/abs/2002.11798)

Representation Robustness Evaluations Our implementation is based on code from MadryLab's robustness package and Devon Hjelm's Deep InfoMax. For all t

ISTR: End-to-End Instance Segmentation with Transformers (https://arxiv.org/abs/2105.00637)

This is the project page for the paper: ISTR: End-to-End Instance Segmentation via Transformers, Jie Hu, Liujuan Cao, Yao Lu, ShengChuan Zhang, Yan Wa

Releases(v1.0.0)
Owner
Sayak Paul
Trying to learn how machines learn.
Sayak Paul
Code of TVT: Transferable Vision Transformer for Unsupervised Domain Adaptation

TVT Code of TVT: Transferable Vision Transformer for Unsupervised Domain Adaptation Datasets: Digit: MNIST, SVHN, USPS Object: Office, Office-Home, Vi

37 Dec 15, 2022
Shuffle Attention for MobileNetV3

SA-MobileNetV3 Shuffle Attention for MobileNetV3 Train Run the following command for train model on your own dataset: python train.py --dataset mnist

Sajjad Aemmi 36 Dec 28, 2022
A tool for making map images from OpenTTD save games

OpenTTD Surveyor A tool for making map images from OpenTTD save games. This is not part of the main OpenTTD codebase, nor is it ever intended to be pa

Aidan Randle-Conde 9 Feb 15, 2022
The implementation our EMNLP 2021 paper "Enhanced Language Representation with Label Knowledge for Span Extraction".

LEAR The implementation our EMNLP 2021 paper "Enhanced Language Representation with Label Knowledge for Span Extraction". See below for an overview of

杨攀 93 Jan 07, 2023
N-Person-Check-Checker-Splitter - A calculator app use to divide checks

N-Person-Check-Checker-Splitter This is my from-scratch programmed calculator ap

2 Feb 15, 2022
GndNet: Fast ground plane estimation and point cloud segmentation for autonomous vehicles using deep neural networks.

GndNet: Fast Ground plane Estimation and Point Cloud Segmentation for Autonomous Vehicles. Authors: Anshul Paigwar, Ozgur Erkent, David Sierra Gonzale

Anshul Paigwar 114 Dec 29, 2022
Graph-based community clustering approach to extract protein domains from a predicted aligned error matrix

Using a predicted aligned error matrix corresponding to an AlphaFold2 model , returns a series of lists of residue indices, where each list corresponds to a set of residues clustering together into a

Tristan Croll 24 Nov 23, 2022
Raindrop strategy for Irregular time series

Graph-Guided Network For Irregularly Sampled Multivariate Time Series Overview This repository contains processed datasets and implementation code for

Zitnik Lab @ Harvard 74 Jan 03, 2023
🎁 3,000,000+ Unsplash images made available for research and machine learning

The Unsplash Dataset The Unsplash Dataset is made up of over 250,000+ contributing global photographers and data sourced from hundreds of millions of

Unsplash 2k Jan 03, 2023
AI that generate music

PianoGPT ai that generate music try it here https://share.streamlit.io/annasajkh/pianogpt/main/main.py or here https://huggingface.co/spaces/Annas/Pia

Annas 28 Nov 27, 2022
Colab notebook and additional materials for Python-driven analysis of redlining data in Philadelphia

RedliningExploration The Google Colaboratory file contained in this repository contains work inspired by a project on educational inequality in the Ph

Benjamin Warren 1 Jan 20, 2022
The code is an implementation of Feedback Convolutional Neural Network for Visual Localization and Segmentation.

Feedback Convolutional Neural Network for Visual Localization and Segmentation The code is an implementation of Feedback Convolutional Neural Network

19 Dec 04, 2022
TCNN Temporal convolutional neural network for real-time speech enhancement in the time domain

TCNN Pandey A, Wang D L. TCNN: Temporal convolutional neural network for real-time speech enhancement in the time domain[C]//ICASSP 2019-2019 IEEE Int

凌逆战 16 Dec 30, 2022
PyTorch-lightning implementation of the ESFW module proposed in our paper Edge-Selective Feature Weaving for Point Cloud Matching

Edge-Selective Feature Weaving for Point Cloud Matching This repository contains a PyTorch-lightning implementation of the ESFW module proposed in our

5 Feb 14, 2022
Code for the paper SphereRPN: Learning Spheres for High-Quality Region Proposals on 3D Point Clouds Object Detection, ICIP 2021.

SphereRPN Code for the paper SphereRPN: Learning Spheres for High-Quality Region Proposals on 3D Point Clouds Object Detection, ICIP 2021. Authors: Th

Thang Vu 15 Dec 02, 2022
MolRep: A Deep Representation Learning Library for Molecular Property Prediction

MolRep: A Deep Representation Learning Library for Molecular Property Prediction Summary MolRep is a Python package for fairly measuring algorithmic p

AI-Health @NSCC-gz 83 Dec 24, 2022
Code for the Lovász-Softmax loss (CVPR 2018)

The Lovász-Softmax loss: A tractable surrogate for the optimization of the intersection-over-union measure in neural networks Maxim Berman, Amal Ranne

Maxim Berman 1.3k Jan 04, 2023
Simultaneous Demand Prediction and Planning

Simultaneous Demand Prediction and Planning Dependencies Python packages: Pytorch, scikit-learn, Pandas, Numpy, PyYAML Data POI: data/poi Road network

Yizong Wang 1 Sep 01, 2022
Fast, flexible and easy to use probabilistic modelling in Python.

Please consider citing the JMLR-MLOSS Manuscript if you've used pomegranate in your academic work! pomegranate is a package for building probabilistic

Jacob Schreiber 3k Dec 29, 2022