PyTorch META-DATASET (Few-shot classification benchmark)

Overview

PyTorch META-DATASET (Few-shot classification benchmark)

This repo contains a PyTorch implementation of meta-dataset and a unified implementation of some few-shot methods. This repo may be useful to you if you:

  • want some pre-trained ImageNet models in PyTorch for META-DATASET;
  • want to benchmark your method on META-DATASET (but do not want to mix your PyTorch code with the original TensorFlow implementation);
  • are looking for a codebase to visualize few-shot episodes.

Benefits over original code:

  1. This repo can be properly seeded, allowing to repeat the same random series of episodes if needed;
  2. Data shuffling is performed without using a buffer, hence reducing the memory consumption;
  3. Better results can be obtained using this repo thanks to an enhanced way of resizing images. More details in the paper.

Note that this code also includes the original implementation for comparison (using the PyTorch workaround proposed by the authors). If you wish to use the original implementation, set the option loader_version: 'tf' in base.yaml (by default set to pytorch).

Yet to do:

  1. Add more methods
  2. Test for the multi-source setting

Table of contents

1. Setting up

Please carefully follow the instructions below to get started.

1.1 Requirements

The present code was developped and tested in Python 3.8. The list of requirements is provided in requirements.txt:

pip install -r requirements.txt

1.2 Data

To download the META-DATASET, please follow the details instructions provided at meta-dataset to obtain the .tfrecords converted data. Once done, make sure all converted dataset are in a single folder, and execute the following script to produce index files:

bash scripts/make_records/make_index_files.sh <path_to_converted_data>

This may take a few minutes. Once all this is done, set the path variable in config/base.yaml to your data folder.

1.3 Download pre-trained models

We provide trained Resnet-18 and WRN-2810 models on the training split of ILSVRC_2012 at checkpoints. All non-episodic baselines use the same checkpoint, stored in the standard folder. The results (averaged over 600 episodes) obtained with the provided Resnet-18 are summarized below:

Inductive methods Architecture ILSVRC Omniglot Aircraft Birds Textures Quick Draw Fungi VGG Flower Traffic Signs MSCOCO Mean
Finetune Resnet-18 59.8 60.5 63.5 80.6 80.9 61.5 45.2 91.1 55.1 41.8 64.0
ProtoNet Resnet-18 48.2 46.7 44.6 53.8 70.3 45.1 38.5 82.4 42.2 38.0 51.0
SimpleShot Resnet-18 60.0 54.2 55.9 78.6 77.8 57.4 49.2 90.3 49.6 44.2 61.7
Transductive methods Architecture ILSVRC Omniglot Aircraft Birds Textures Quick Draw Fungi VGG Flower Traffic Signs MSCOCO Mean
BD-CSPN Resnet-18 60.5 54.4 55.2 80.9 77.9 57.3 50.0 91.7 47.8 43.9 62.0
TIM-GD Resnet-18 63.6 65.6 66.4 85.6 84.7 65.8 57.5 95.6 65.2 50.9 70.1

See Sect. 1.4 and 1.5 to reproduce these results.

1.4 Train models from scratch (optional)

In order to train you model from scratch, execute scripts/train.sh script:

bash scripts/train.sh <method> <architecture> <dataset>

method is to be chosen among all method specific config files in config/, architecture in ['resnet18', 'wideres2810'] and dataset among all datasets (as named by the META-DATASET converted folders). Note that the hierarchy of arguments passed to src/train.py and src/eval.py is the following: base_config < method_config < opts arguments.

Mutiprocessing : This code supports distributed training. To leverage this feature, set the gpus option accordingly (for instance gpus: [0, 1, 2, 3]).

1.5 Test your models

Once trained (or once pre-trained models downloaded), you can evaluate your model on the test split of each dataset by running:

bash scripts/test.sh <method> <architecture> <base_dataset> <test_dataset>

Results will be saved in results/ / where corresponds to a unique hash number of the config (you can only get the same result folder iff all hyperparameters are the same).

2. Visualization of results

2.1 Training metrics

During training, training loss and validation accuracy are recorded and saved as .npy files in the checkpoint folder. Then, you can use the src/plot.py to plot these metrics (even during training).

Example 1: Plot the metrics of the standard (=non episodic) resnet-18 on ImageNet:

python src/plot.py --folder checkpoints/ilsvrc_2012/ilsvrc_2012/resnet18/standard/

Example 2: Plot the metrics of all Resnet-18 trained on ImageNet

python src/plot.py --folder checkpoints/ilsvrc_2012/ilsvrc_2012/resnet18/

2.2 Inference metrics

For methods that perform test-time optimization (for instance MAML, TIM, Finetune, ...), method specific metrics are plotted in real-time (versus test iterations) and averaged over test epidodes, which can allow you to track unexpected behavior easily. Such metrics are implemented in src/metrics/, and the choice of which metric to plot is specificied through the eval_metrics option in the method .yaml config file. An example with TIM method is provided below.

2.3 Visualization of episodes

By setting the option visu: True at inference, you can visualize samples of episodes. An example of such visualization is given below:

The samples will be saved in results/. All relevant optons can be found in the base.yaml file, in the EVAL-VISU section.

3. Incorporate your own method

This code was designed to allow easy incorporation of new methods.

Step 1: Add your method .py file to src/methods/ by following the template provided in src/methods/method.py.

Step 2: Add import in src/methods/__init__.py

Step 3: Add your method .yaml config file including the required options episodic_training and method (name of the class corresponding to your method). Also make sure that if your method performs test-time optimization, you also properly set the option iter that specifies the number of optimization steps performed at inference (this argument is also used to plot the inference metrics, see section 2.2).

4. Contributions

Contributions are more than welcome. In particular, if you want to add methods/pre-trained models, do make a pull-request.

5. Citation

If you find this repo useful for your research, please consider citing the following papers:

@misc{boudiaf2021mutualinformation,
      title={Mutual-Information Based Few-Shot Classification}, 
      author={Malik Boudiaf and Ziko Imtiaz Masud and Jérôme Rony and Jose Dolz and Ismail Ben Ayed and Pablo Piantanida},
      year={2021},
      eprint={2106.12252},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Additionally, do not hesitate to file issues if you encounter problems, or reach out directly to Malik Boudiaf ([email protected]).

6. Acknowledgments

I thank the authors of meta-dataset for releasing their code and the author of open-source TFRecord reader for open sourcing an awesome Pytorch-compatible TFRecordReader ! Also big thanks to @hkervadec for his thorough code review !

Owner
Malik Boudiaf
Malik Boudiaf
A facial recognition doorbell system using a Raspberry Pi

Facial Recognition Doorbell This project expands on the person-detecting doorbell system to allow it to identify faces, and announce names accordingly

rydercalmdown 22 Apr 15, 2022
Conversion between units used in magnetism

convmag Conversion between various units used in magnetism The conversions between base units available are: T - G : 1e4

0 Jul 15, 2021
Material del curso IIC2233 Programación Avanzada 📚

Contenidos Los contenidos se organizan según la semana del semestre en que nos encontremos, y según la semana que se destina para su estudio. Los cont

IIC2233 @ UC 72 Dec 23, 2022
Brain tumor detection using CNN (InceptionResNetV2 Model)

Brain-Tumor-Detection Building a detection model using a convolutional neural network in Tensorflow & Keras. Used brain MRI images. InceptionResNetV2

1 Feb 13, 2022
Autoregressive Predictive Coding: An unsupervised autoregressive model for speech representation learning

Autoregressive Predictive Coding This repository contains the official implementation (in PyTorch) of Autoregressive Predictive Coding (APC) proposed

iamyuanchung 173 Dec 18, 2022
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty

AugMix Introduction We propose AugMix, a data processing technique that mixes augmented images and enforces consistent embeddings of the augmented ima

Google Research 876 Dec 17, 2022
Deep Learning Visuals contains 215 unique images divided in 23 categories

Deep Learning Visuals contains 215 unique images divided in 23 categories (some images may appear in more than one category). All the images were originally published in my book "Deep Learning with P

Daniel Voigt Godoy 1.3k Dec 28, 2022
This package implements the algorithms introduced in Smucler, Sapienza, and Rotnitzky (2020) to compute optimal adjustment sets in causal graphical models.

optimaladj: A library for computing optimal adjustment sets in causal graphical models This package implements the algorithms introduced in Smucler, S

Facundo Sapienza 6 Aug 04, 2022
An implementation of the Contrast Predictive Coding (CPC) method to train audio features in an unsupervised fashion.

CPC_audio This code implements the Contrast Predictive Coding algorithm on audio data, as described in the paper Unsupervised Pretraining Transfers we

8 Nov 14, 2022
Official Implementation of DE-CondDETR and DELA-CondDETR in "Towards Data-Efficient Detection Transformers"

DE-DETRs By Wen Wang, Jing Zhang, Yang Cao, Yongliang Shen, and Dacheng Tao This repository is an official implementation of DE-CondDETR and DELA-Cond

Wen Wang 41 Dec 12, 2022
🥇 LG-AI-Challenge 2022 1위 솔루션 입니다.

LG-AI-Challenge-for-Plant-Classification Dacon에서 진행된 농업 환경 변화에 따른 작물 병해 진단 AI 경진대회 에 대한 코드입니다. (colab directory에 코드가 잘 정리 되어있습니다.) Requirements python

siwooyong 10 Jun 30, 2022
[CVPR 2022] CoTTA Code for our CVPR 2022 paper Continual Test-Time Domain Adaptation

CoTTA Code for our CVPR 2022 paper Continual Test-Time Domain Adaptation Prerequisite Please create and activate the following conda envrionment. To r

Qin Wang 87 Jan 08, 2023
Image restoration with neural networks but without learning.

Warning! The optimization may not converge on some GPUs. We've personally experienced issues on Tesla V100 and P40 GPUs. When running the code, make s

Dmitry Ulyanov 7.4k Jan 01, 2023
Deep Halftoning with Reversible Binary Pattern

Deep Halftoning with Reversible Binary Pattern ICCV Paper | Project Website | BibTex Overview Existing halftoning algorithms usually drop colors and f

Menghan Xia 17 Nov 22, 2022
Preprocessed Datasets for our Multimodal NER paper

Unified Multimodal Transformer (UMT) for Multimodal Named Entity Recognition (MNER) Two MNER Datasets and Codes for our ACL'2020 paper: Improving Mult

76 Dec 21, 2022
An LSTM for time-series classification

Update 10-April-2017 And now it works with Python3 and Tensorflow 1.1.0 Update 02-Jan-2017 I updated this repo. Now it works with Tensorflow 0.12. In

Rob Romijnders 391 Dec 27, 2022
Sum-Product Probabilistic Language

Sum-Product Probabilistic Language SPPL is a probabilistic programming language that delivers exact solutions to a broad range of probabilistic infere

MIT Probabilistic Computing Project 57 Nov 17, 2022
ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectives

Status: Under development (expect bug fixes and huge updates) ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectiv

37 Dec 28, 2022
📚 A collection of all the Deep Learning Metrics that I came across which are not accuracy/loss.

📚 A collection of all the Deep Learning Metrics that I came across which are not accuracy/loss.

Rahul Vigneswaran 1 Jan 17, 2022
Deep learning algorithms for muon momentum estimation in the CMS Trigger System

Deep learning algorithms for muon momentum estimation in the CMS Trigger System The Compact Muon Solenoid (CMS) is a general-purpose detector at the L

anuragB 2 Oct 06, 2021