Official PyTorch code for CVPR 2020 paper "Deep Active Learning for Biased Datasets via Fisher Kernel Self-Supervision"

Overview

Deep Active Learning for Biased Datasets via Fisher Kernel Self-Supervision

https://arxiv.org/abs/2003.00393

Abstract

Active learning (AL) aims to minimize labeling efforts for data-demanding deep neural networks (DNNs) by selecting the most representative data points for annotation. However, currently used methods are ill-equipped to deal with biased data. The main motivation of this paper is to consider a realistic setting for pool-based semi-supervised AL, where the unlabeled collection of train data is biased. We theoretically derive an optimal acquisition function for AL in this setting. It can be formulated as distribution shift minimization between unlabeled train data and weakly-labeled validation dataset. To implement such acquisition function, we propose a low-complexity method for feature density matching using Fisher kernel (FK) self-supervision as well as several novel pseudo-label estimators. Our FK-based method outperforms state-of-the-art methods on MNIST, SVHN, and ImageNet classification while requiring only 1/10th of processing. The conducted experiments show at least 40% drop in labeling efforts for the biased class-imbalanced data compared to existing methods.

BibTex Citation

If you like our paper or code, please cite its CVPR2020 preprint using the following BibTex:

@article{gudovskiy2020al,
  title={Deep Active Learning for Biased Datasets via Fisher Kernel Self-Supervision},
  author={Gudovskiy, Denis and Hodgkinson, Alec and Yamaguchi, Takuya and Tsukizawa, Sotaro},
  journal={arXiv:2003.00393},
  year={2020}
}

Installation

  • Install v1.1+ PyTorch by selecting your environment on the website and running the appropriate command.
  • Clone this repository: code has been tested on Python 3+.
  • Install DALI for ImageNet only: tested on v0.11.0.
  • Optionally install Kornia for MC-based pseudo-label estimation metrics. However, due to strict Python 3.6+ requirement for this lib, by default, we provide our simple rotation function. Use Kornia to experiment with other sampling strategies.

Datasets

Data and temporary files like descriptors, checkpoints and index files are saved into ./local_data/{dataset} folder. For example, MNIST scripts are located in ./mnist and its data is saved into ./local_data/MNIST folder, correspondingly. In order to get statistically significant results, we execute multiple runs of the same configuration with randomized weights and training dataset splits and save results to ./local_data/{dataset}/runN folders. We suggest to check that you have enough space for large-scale datasets.

MNIST, SVHN

Datasets will be automatically downloaded and converted to PyTorch after the first run of AL.

ImageNet

Due to large size, ImageNet has to be manually downloaded and preprocessed using these scripts.

Code Organization

  • Scripts are located in ./{dataset} folder.
  • Main parts of the framework are contained in only few files: "unsup.py", "gen_descr.py", "main_descr.py" as well as execution script "run.py".
  • Dataset loaders are located in ./{dataset}/custom_datasets and DNN models in ./{dataset}/custom_models
  • The "unsup.py" is a script to train initial model by unsupervised pretraining using rotation method and to produce all-random weights initial model.
  • The "gen_descr.py" generates descriptor database files in ./local_data/{dataset}/runN/descr.
  • The "main_descr.py" performs AL feature matching, adds new data to training dataset and retrains model with new augmented data. Its checkpoints are saved into ./local_data/{dataset}/runN/checkpoint.
  • The run.py" can read these checkpoint files and perform AL iteration with retraining.
  • The run_plot.py" generates performance curves that can be found in the paper.
  • To make confusion matrices and t-SNE plots, use extra "visualize_tsne.py" script for MNIST only.
  • VAAL code can be found in ./vaal folder, which is adopted version of official repo.

Running Active Learning Experiments

  • Install minimal required packages from requirements.txt.
  • The command interface for all methods is combined into "run.py" script. It can run multiple algorithms and data configurations.
  • The script parameters may differ depending on the dataset and, hence, it is better to use "python3 run.py --help" command.
  • First, you have to set configuration in cfg = list() according to its format and execute "run.py" script with "--initial" flag to generate initial random and unsupervised pretrained models.
  • Second, the same script should be run without "--initial".
  • Third, after all AL steps are executed, "run_plot.py" should be used to reproduce performance curves.
  • All these steps require basic understanding of the AL terminology.
  • Use the default configurations to reproduce paper results.
  • To speed up or parallelize multiple runs, use --run-start, --run-stop parameters to limit number of runs saved in ./local_data/{dataset}/runN folders. The default setting is 10 runs for MNIST, 5 for SVHN and 1 for ImageNet.
pip3 install -U -r requirements.txt
python3 run.py --gpu 0 --initial # generate initial models
python3 run.py --gpu 0 --unsupervised 0 # AL with the initial all-random parameters model
python3 run.py --gpu 0 --unsupervised 1 # AL with the initial model pretrained using unsupervised rotation method

Reference Results

MNIST

MNIST LeNet test accuracy: (a) no class imbalance, (b) 100x class imbalance, and (c) ablation study of pseudo-labeling and unsupervised pretraining (100x class imbalance). Our method decreases labeling by 40% compared to prior works for biased data.

SVHN and ImageNet

SVHN ResNet-10 test (top) and ImageNet ResNet-18 val (bottom) accuracy: (a,c) no class imbalance and (b,d) with 100x class imbalance.

MNIST Visualizations

Confusion matrix (top) and t-SNE (bottom) of MNIST test data at AL iteration b=3 with 100x class imbalance for: (a) varR with E=1, K=128, (b) R_{z,g}, S=hat{p}(y,z), L=80 (ours), and (c) R_{z,g}, S=y, L=80. Dots and balls represent correspondingly correctly and incorrectly classified images for t-SNE visualizations. The underrepresented classes {5,8,9} have on average 36% accuracy for prior work (a), while our method (b) increases their accuracy to 75%. The ablation configuration (c) shows 89% theoretical limit of our method.

Owner
Denis
Machine and Deep Learning Researcher
Denis
Code for Neurips2021 Paper "Topology-Imbalance Learning for Semi-Supervised Node Classification".

Topology-Imbalance Learning for Semi-Supervised Node Classification Introduction Code for NeurIPS 2021 paper "Topology-Imbalance Learning for Semi-Sup

Victor Chen 40 Nov 23, 2022
J.A.R.V.I.S is an AI virtual assistant made in python.

J.A.R.V.I.S is an AI virtual assistant made in python. Running JARVIS Without Python To run JARVIS without python: 1. Head over to our installation pa

somePythonProgrammer 16 Dec 29, 2022
An implementation of shampoo

shampoo.pytorch An implementation of shampoo, proposed in Shampoo : Preconditioned Stochastic Tensor Optimization by Vineet Gupta, Tomer Koren and Yor

Ryuichiro Hataya 69 Sep 10, 2022
D2Go is a toolkit for efficient deep learning

D2Go D2Go is a production ready software system from FacebookResearch, which supports end-to-end model training and deployment for mobile platforms. W

Facebook Research 744 Jan 04, 2023
Implementation of popular SOTA self-supervised learning algorithms as Fastai Callbacks.

Self Supervised Learning with Fastai Implementation of popular SOTA self-supervised learning algorithms as Fastai Callbacks. Install pip install self-

Kerem Turgutlu 276 Dec 23, 2022
A comprehensive list of published machine learning applications to cosmology

ml-in-cosmology This github attempts to maintain a comprehensive list of published machine learning applications to cosmology, organized by subject ma

George Stein 290 Dec 29, 2022
Unofficial PyTorch implementation of Attention Free Transformer (AFT) layers by Apple Inc.

aft-pytorch Unofficial PyTorch implementation of Attention Free Transformer's layers by Zhai, et al. [abs, pdf] from Apple Inc. Installation You can i

Rishabh Anand 184 Dec 12, 2022
LocUNet is a deep learning method to localize a UE based solely on the reported signal strengths from a set of BSs.

LocUNet LocUNet is a deep learning method to localize a UE based solely on the reported signal strengths from a set of BSs. The method utilizes accura

4 Oct 05, 2022
DLFlow is a deep learning framework.

DLFlow是一套深度学习pipeline,它结合了Spark的大规模特征处理能力和Tensorflow模型构建能力。利用DLFlow可以快速处理原始特征、训练模型并进行大规模分布式预测,十分适合离线环境下的生产任务。利用DLFlow,用户只需专注于模型开发,而无需关心原始特征处理、pipeline构建、生产部署等工作。

DiDi 152 Oct 27, 2022
Geometric Algebra package for JAX

JAXGA - JAX Geometric Algebra GitHub | Docs JAXGA is a Geometric Algebra package on top of JAX. It can handle high dimensional algebras by storing onl

Robin Kahlow 36 Dec 22, 2022
DvD-TD3: Diversity via Determinants for TD3 version

DvD-TD3: Diversity via Determinants for TD3 version The implementation of paper Effective Diversity in Population Based Reinforcement Learning. Instal

3 Feb 11, 2022
Multi-atlas segmentation (MAS) is a promising framework for medical image segmentation

Multi-atlas segmentation (MAS) is a promising framework for medical image segmentation. Generally, MAS methods register multiple atlases, i.e., medical images with corresponding labels, to a target i

NanYoMy 13 Oct 09, 2022
Code for Quantifying Ignorance in Individual-Level Causal-Effect Estimates under Hidden Confounding

🍐 quince Code for Quantifying Ignorance in Individual-Level Causal-Effect Estimates under Hidden Confounding 🍐 Installation $ git clone

Andrew Jesson 19 Jun 23, 2022
Pytorch reimplementation of PSM-Net: "Pyramid Stereo Matching Network"

This is a Pytorch Lightning version PSMNet which is based on JiaRenChang/PSMNet. use python main.py to start training. PSM-Net Pytorch reimplementatio

XIAOTIAN LIU 1 Nov 25, 2021
Leveraging Instance-, Image- and Dataset-Level Information for Weakly Supervised Instance Segmentation

Leveraging Instance-, Image- and Dataset-Level Information for Weakly Supervised Instance Segmentation This paper has been accepted and early accessed

Yun Liu 39 Sep 20, 2022
ROS-UGV-Control-Interface - Control interface which can be used in any UGV

ROS-UGV-Control-Interface Cam Closed: Cam Opened:

Ahmet Fatih Akcan 1 Nov 04, 2022
Streaming Anomaly Detection Framework in Python (Outlier Detection for Streaming Data)

Python Streaming Anomaly Detection (PySAD) PySAD is an open-source python framework for anomaly detection on streaming multivariate data. Documentatio

Selim Firat Yilmaz 181 Dec 18, 2022
An official reimplementation of the method described in the INTERSPEECH 2021 paper - Speech Resynthesis from Discrete Disentangled Self-Supervised Representations.

Speech Resynthesis from Discrete Disentangled Self-Supervised Representations Implementation of the method described in the Speech Resynthesis from Di

Facebook Research 253 Jan 06, 2023
The repository is for safe reinforcement learning baselines.

Safe-Reinforcement-Learning-Baseline The repository is for Safe Reinforcement Learning (RL) research, in which we investigate various safe RL baseline

172 Dec 19, 2022
Fast sparse deep learning on CPUs

SPARSEDNN **If you want to use this repo, please send me an email: [email pro

Ziheng Wang 44 Nov 30, 2022