Prototypical Networks for Few-shot Learning in PyTorch

Overview

A simple alternative implementation of Prototypical Networks for Few-shot Learning (paper, code) in PyTorch.

Prototypical Networks

As shown in the reference paper, Prototypical Networks are trained to embed sample features in a vector space. At each episode (iteration), a number of samples from a subset of classes are selected and sent through the model; for each class c in the subset, the features of n_support samples are used to estimate the prototype for that class (their barycentre in the vector space), so that the distances between the remaining n_query samples and their class barycentre can be minimized.

(Figure: Prototypical Networks)
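
As a rough illustration of this idea (a minimal sketch, not the code used in this repo; the function names are made up for the example), prototypes and query-to-prototype distances can be computed like this:

import torch

def compute_prototypes(support_embeddings, n_classes, n_support):
    # support_embeddings: (n_classes * n_support, emb_dim), grouped by class
    # the prototype of a class is the mean (barycentre) of its support embeddings
    return support_embeddings.view(n_classes, n_support, -1).mean(dim=1)

def squared_distances(query_embeddings, prototypes):
    # query_embeddings: (n_queries, emb_dim), prototypes: (n_classes, emb_dim)
    # returns a (n_queries, n_classes) matrix of squared Euclidean distances
    return torch.cdist(query_embeddings, prototypes) ** 2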

T-SNE

After training, you can compute the t-SNE of the features generated by the model (not done in this repo; more info about t-SNE here); below is a sample as shown in the paper.

(Figure: t-SNE from the reference paper)
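
If you want to produce a similar plot from your own trained model (again, not part of this repo), here is a minimal sketch using scikit-learn, assuming embeddings and labels are NumPy arrays you have already collected from the model:

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# embeddings: (n_samples, emb_dim) array of model outputs (assumed precomputed)
# labels:     (n_samples,) array of class ids
points = TSNE(n_components=2, perplexity=30).fit_transform(embeddings)
plt.scatter(points[:, 0], points[:, 1], c=labels, s=5, cmap='tab20')
plt.show()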

Omniglot Dataset

Kudos to @ludc for his contribution: https://github.com/pytorch/vision/pull/46. We will switch to the official dataset once it is added to torchvision, provided this doesn't require big changes to the code.

Dataset splits

We implemented the Vinyals splitting method as in [Matching Networks for One Shot Learning]. This should be the same method used in the reference paper (in fact, the split files were downloaded from the "official" repo). We then apply the rotations described there, as sketched below. This way, results obtained by running this code should be comparable with those reported in the reference paper.
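
For reference, the rotation step treats each character rotated by 0, 90, 180 and 270 degrees as a distinct class; a minimal sketch of the idea (not the repo's actual loader code, names are illustrative):

import numpy as np

def rotated_class_variants(image, class_name):
    # Each rotation of a character image becomes its own class,
    # e.g. "Alphabet/character01/rot090".
    return {
        f"{class_name}/rot{angle:03d}": np.rot90(image, k=angle // 90)
        for angle in (0, 90, 180, 270)
    }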

Prototypical Batch Sampler

As described in its PyDoc, this class is used to generate the indexes of each batch for a prototypical training algorithm.

In particular, the object is instantiated by passing the list of labels for the dataset; the sampler then infers the total number of classes and creates a set of indexes for each class in the dataset. At each episode the sampler selects n_classes random classes and returns (n_support + n_query) sample indexes for each of the selected classes.
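
The sampling logic can be sketched roughly as follows (a simplified illustration, not the exact implementation in this repo):

import numpy as np

class EpisodeSamplerSketch:
    # Yields, per episode, the indexes of n_classes * (n_support + n_query) samples.

    def __init__(self, labels, n_classes, n_samples, iterations):
        self.labels = np.asarray(labels)
        self.classes = np.unique(self.labels)   # total number of classes is inferred here
        self.n_classes = n_classes              # random classes per episode
        self.n_samples = n_samples              # n_support + n_query per class
        self.iterations = iterations            # episodes per epoch

    def __iter__(self):
        for _ in range(self.iterations):
            batch = []
            episode_classes = np.random.choice(self.classes, self.n_classes, replace=False)
            for c in episode_classes:
                class_idxs = np.flatnonzero(self.labels == c)
                batch.extend(np.random.choice(class_idxs, self.n_samples, replace=False))
            yield batch

    def __len__(self):
        return self.iterations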

Prototypical Loss

The loss is computed as in the cited paper, mostly inspired by this code by one of its authors.

In prototypical_loss.py both a loss function and a PyTorch-style loss class are implemented.

The function takes as input the model output for a batch, the samples' ground truths, and the number n_support of samples to be used as support. Episode classes are inferred from the target list, n_support samples are randomly extracted for each class, and their class barycentres are computed, as well as the distance of each remaining sample's embedding from each class barycentre. The probability of each query sample belonging to each episode class is then computed from these distances, and the loss is finally obtained from the prediction probabilities of the query samples, as usual in classification problems.
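
Putting these steps together, here is a condensed sketch of the computation (a simplified version for illustration, not the exact code in prototypical_loss.py):

import torch
import torch.nn.functional as F

def prototypical_loss_sketch(embeddings, targets, n_support):
    # embeddings: (n_classes * (n_support + n_query), emb_dim) model outputs
    # targets:    class label of each embedding
    classes = torch.unique(targets)
    n_query = (targets == classes[0]).sum().item() - n_support

    prototypes, query_embs = [], []
    for c in classes:
        idxs = torch.nonzero(targets == c, as_tuple=True)[0]
        prototypes.append(embeddings[idxs[:n_support]].mean(dim=0))  # class barycentre
        query_embs.append(embeddings[idxs[n_support:]])
    prototypes = torch.stack(prototypes)              # (n_classes, emb_dim)
    query_embs = torch.cat(query_embs)                # (n_classes * n_query, emb_dim)

    dists = torch.cdist(query_embs, prototypes) ** 2  # squared Euclidean distances
    log_p_y = F.log_softmax(-dists, dim=1)            # class probabilities from distances

    # queries are stacked class by class, so the target index is the class position
    target_idxs = torch.arange(len(classes)).repeat_interleave(n_query)
    loss = F.nll_loss(log_p_y, target_idxs)
    acc = (log_p_y.argmax(dim=1) == target_idxs).float().mean()
    return loss, acc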

Training

Please note that the training code is here just for demonstration purposes.

To train the Protonet on this task, cd into this repo's src root folder and execute:

$ python train.py

The script takes the following command line options:

  • dataset_root: the root directory where the dataset is stored, defaults to '../dataset'

  • nepochs: number of epochs to train for, defaults to 100

  • learning_rate: learning rate for the model, defaults to 0.001

  • lr_scheduler_step: StepLR learning rate scheduler step, defaults to 20

  • lr_scheduler_gamma: StepLR learning rate scheduler gamma, defaults to 0.5

  • iterations: number of episodes per epoch, defaults to 100

  • classes_per_it_tr: number of random classes per episode for training, defaults to 60

  • num_support_tr: number of samples per class to use as support for training, defaults to 5

  • num_query_tr: number of samples per class to use as query for training, defaults to 5

  • classes_per_it_val: number of random classes per episode for validation, defaults to 5

  • num_support_val: number of samples per class to use as support for validation, defaults to 5

  • num_query_val: number of samples per class to use as query for validation, defaults to 15

  • manual_seed: seed for manual random initialization, defaults to 7

  • cuda: enables cuda (store True)

Running the command without arguments will train the model with the default hyperparameter values (producing the results reported in the Performances section below).

Performances

We are trying to reproduce the reference paper performances; we'll update our best results here.

Model             1-shot (5-way Acc.)   5-shot (5-way Acc.)   1-shot (20-way Acc.)   5-shot (20-way Acc.)
Reference Paper   98.8%                 99.7%                 96.0%                  98.9%
This repo         98.5%**               99.6%*                95.1%°                 98.6%°°

* achieved using default parameters (using --cuda option)

** achieved running python train.py --cuda -nsTr 1 -nsVa 1

° achieved running python train.py --cuda -nsTr 1 -nsVa 1 -cVa 20

°° achieved running python train.py --cuda -nsTr 5 -nsVa 5 -cVa 20

.bib citation

Cite the paper as follows (copy-pasted from arXiv for you):

@article{DBLP:journals/corr/SnellSZ17,
  author    = {Jake Snell and
               Kevin Swersky and
               Richard S. Zemel},
  title     = {Prototypical Networks for Few-shot Learning},
  journal   = {CoRR},
  volume    = {abs/1703.05175},
  year      = {2017},
  url       = {http://arxiv.org/abs/1703.05175},
  archivePrefix = {arXiv},
  eprint    = {1703.05175},
  timestamp = {Wed, 07 Jun 2017 14:41:38 +0200},
  biburl    = {http://dblp.org/rec/bib/journals/corr/SnellSZ17},
  bibsource = {dblp computer science bibliography, http://dblp.org}
}

License

This project is licensed under the MIT License

Copyright (c) 2018 Daniele E. Ciriello, Orobix Srl (www.orobix.com).
