Learning to Self-Train for Semi-Supervised Few-Shot Classification

This repository contains the TensorFlow implementation of the NeurIPS 2019 paper "Learning to Self-Train for Semi-Supervised Few-Shot Classification".

Check the few-shot classification leaderboard.

Installation

To run this repository, we recommend installing Python 2.7 or 3.5 and TensorFlow 1.3.0 with Anaconda.

You can download Anaconda and read the installation instructions on the official website: https://www.anaconda.com/download/

Create a new environment and install TensorFlow in it:

conda create --name lst-tf python=2.7
conda activate lst-tf
conda install tensorflow-gpu=1.3.0

Install other requirements:

pip install scipy tqdm opencv-python pillow matplotlib

Clone this repository:

git clone https://github.com/xinzheli1217/learning-to-self-train.git 
cd learning-to-self-train
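Before running any experiment, it can help to confirm that the active environment matches the recommended versions. The sketch below is a minimal check that assumes only the Python 2.7/3.5 and TensorFlow 1.3.x requirements stated above; the helper names `python_ok` and `tensorflow_ok` are hypothetical and not part of this repository. It avoids f-strings so it also runs under Python 2.7.

```python
from __future__ import print_function

import sys

# Versions recommended in the installation instructions (assumed to be the only supported ones).
SUPPORTED_PYTHON = {(2, 7), (3, 5)}
SUPPORTED_TF_PREFIX = ("1", "3")


def python_ok(version_info=sys.version_info):
    """Return True if the interpreter is Python 2.7 or 3.5."""
    return tuple(version_info[:2]) in SUPPORTED_PYTHON


def tensorflow_ok(tf_version):
    """Return True if a version string such as '1.3.0' belongs to the 1.3 series."""
    return tuple(tf_version.split(".")[:2]) == SUPPORTED_TF_PREFIX


if __name__ == "__main__":
    print("Python", sys.version.split()[0], "OK" if python_ok() else "UNSUPPORTED")
    try:
        import tensorflow as tf
        print("TensorFlow", tf.__version__,
              "OK" if tensorflow_ok(tf.__version__) else "UNSUPPORTED")
    except ImportError:
        print("TensorFlow is not installed in this environment")
```

Running the file inside the `lst-tf` environment prints one status line for Python and one for TensorFlow.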

Project Architecture

.
├── data_generator              # dataset generator 
|   └── meta_data_generator.py  # data generator for meta-train phase
├── models                      # tensorflow model files 
|   ├── models.py               # resnet12 CNN class
|   └── meta_model_LST.py       # semi-supervised meta-train model class
├── trainer                     # tensorflow trainer files  
|   └── meta_LST.py             # semi-supervised meta-train trainer class
├── utils                       # a series of tools used in this repo
|   └── misc.py                 # miscellaneous tool functions
| 
├── data                        # the folder containing datasets for experiments
├── pretrain_weights_dir        # the folder containing MTL pre-training weights
├── weights_saving_dir          # the folder containing meta-training weights
├── test_output_dir             # the folder containing meta-testing files
├── filenames_and_labels        # the folder containing image file paths and labels for experiments
|
├── exp_train.py                # the python file with main function and parameter settings for meta-training
└── exp_test.py                 # the python file with main function and parameter settings for meta-testing
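The folders in the tree above are expected relative to the repository root. If the ones that hold downloads, weights, and outputs are missing after cloning, a small sketch like the following can create them. It assumes only the folder names listed in the tree; `ensure_dirs` and `REQUIRED_DIRS` are hypothetical names.

```python
import os

# Top-level folders from the project tree that hold datasets, weights, and test outputs.
REQUIRED_DIRS = [
    "data",
    "pretrain_weights_dir",
    "weights_saving_dir",
    "test_output_dir",
]


def ensure_dirs(root=".", names=REQUIRED_DIRS):
    """Create each folder under `root` if it does not exist; return the names created."""
    created = []
    for name in names:
        path = os.path.join(root, name)
        if not os.path.isdir(path):
            # os.makedirs without exist_ok keeps this compatible with Python 2.7.
            os.makedirs(path)
            created.append(name)
    return created
```

Calling `ensure_dirs()` from the repository root creates whichever of these folders are missing and leaves existing ones untouched.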

Running Experiments

First, download our processed images for miniImagenet [Download Page] or tieredImagenet [Download Page] and move the unzipped folder to ./data. Then download the pre-trained models for miniImagenet [Download Page] or tieredImagenet [Download Page] and move the unzipped folder to ./pretrain_weights_dir.
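After unzipping, a quick check that both destination folders are non-empty can catch a misplaced archive. This sketch assumes only the `./data` and `./pretrain_weights_dir` locations named above; it does not know the subfolder names inside the archives, so it only checks that each folder contains something. `missing_downloads` is a hypothetical helper.

```python
import os

# Destinations named in the download instructions above.
DOWNLOAD_TARGETS = ["data", "pretrain_weights_dir"]


def missing_downloads(root=".", targets=DOWNLOAD_TARGETS):
    """Return the target folders that are absent or empty under `root`."""
    missing = []
    for name in targets:
        path = os.path.join(root, name)
        if not os.path.isdir(path) or not os.listdir(path):
            missing.append(name)
    return missing
```

An empty list from `missing_downloads()` means both folders exist and contain at least one entry.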

Training from Pre-Trained Models

Run the semi-supervised meta-train phase (e.g., miniImageNet, 1-shot):

python exp_train.py --shot_num=1 --dataset='miniImagenet' --pretrain_class_num=64 --nb_ul_samples=10 --metatrain_iterations=15000 --exp_name='LST_mini_1_shot'

Run the semi-supervised meta-test phase (e.g., miniImageNet, 1-shot):

python exp_test.py --shot_num=1 --dataset='miniImagenet' --pretrain_class_num=64 --use_distractors=False --nb_ul_samples=100 --unfiles_num=10 --test_iter=15000 --recurrent_stage_nums=6 --nums_in_folders=30 --hard_selection=20 --exp_name='LST_mini_1_shot' 
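The train and test commands share several flags (`--shot_num`, `--dataset`, `--pretrain_class_num`, `--exp_name`), and the test run must reuse the training experiment name to find its weights. A small launcher can keep these consistent. The sketch below uses only the flags and values from the two commands above; `build_command`, `SHARED`, `TRAIN`, and `TEST` are hypothetical names, and running it directly only prints the commands.

```python
from __future__ import print_function

import sys

# Flags shared by exp_train.py and exp_test.py in the 1-shot miniImagenet example.
SHARED = {
    "shot_num": 1,
    "dataset": "miniImagenet",
    "pretrain_class_num": 64,
    "exp_name": "LST_mini_1_shot",
}

# Flags specific to each phase, copied from the example commands.
TRAIN = {"nb_ul_samples": 10, "metatrain_iterations": 15000}
TEST = {
    "use_distractors": False,
    "nb_ul_samples": 100,
    "unfiles_num": 10,
    "test_iter": 15000,
    "recurrent_stage_nums": 6,
    "nums_in_folders": 30,
    "hard_selection": 20,
}


def build_command(script, *flag_dicts):
    """Merge flag dicts (later ones override earlier ones) into an argv list."""
    flags = {}
    for d in flag_dicts:
        flags.update(d)
    return [sys.executable, script] + ["--%s=%s" % (k, flags[k]) for k in sorted(flags)]


if __name__ == "__main__":
    for cmd in (build_command("exp_train.py", SHARED, TRAIN),
                build_command("exp_test.py", SHARED, TEST)):
        print(" ".join(cmd))
```

To actually launch a phase, the returned list can be passed to `subprocess.check_call`. No shell quoting is needed because each flag is a single argv entry.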

Hyperparameters and Options

The main hyperparameters used in the experiments can be edited in exp_train.py and exp_test.py for the meta-train and meta-test phases, respectively. There are two kinds: (1) common hyperparameters shared by meta-train and meta-test, and (2) test-specific hyperparameters used for the recurrent self-training process in meta-test.

  • Common hyperparameters:

    • way_num number of classes
    • shot_num number of examples per class
    • dataset dataset used in the experiment (miniImagenet or tieredImagenet)
    • pretrain_class_num number of meta-train classes
    • exp_name name for the current experiment
    • meta_batch_size number of tasks sampled per meta-update in meta-train phase
    • base_lr step size alpha for inner gradient update
    • meta_lr the meta learning rate for the SS (scaling and shifting) parameters and the initial model parameters
    • min_meta_lr the minimum meta learning rate for all meta-parameters
    • swn_lr the meta learning rate for the SWN (soft weighting network)
    • nb_ul_samples number of unlabeled examples per class
    • re_train_epoch_num number of re-training inner gradient updates
    • train_base_epoch_num total number of inner gradient updates (meta-train only)
    • test_base_epoch_num total number of inner gradient updates (meta-test only)
  • Test-specific hyperparameters:

    • use_distractors whether to use distractor classes during meta-test
    • num_dis number of distracting classes used for meta-testing
    • unfiles_num number of unlabeled sample files used in the experiment (each file contains 10 unlabeled samples per class)
    • recurrent_stage_nums number of recurrent stages used during meta-test
    • local_update_num number of inner gradient updates used in each recurrent stage
    • nums_in_folders number of unlabeled samples (per class) used in each recurrent stage
    • hard_selection number of samples (per class) remaining after hard selection
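Several test-specific values interact: each unlabeled file holds 10 samples per class, and hard selection keeps a subset of the samples drawn for a recurrent stage. The sketch below turns these relationships into sanity checks. The constraints are our reading of the descriptions above (for example, that `hard_selection` should not exceed `nums_in_folders`), not checks the repository performs, and `check_test_config` is a hypothetical helper.

```python
SAMPLES_PER_FILE = 10  # unlabeled samples per class in each file, per the unfiles_num description


def check_test_config(cfg):
    """Return a list of human-readable problems; an empty list means none were found.

    Assumed constraints: nb_ul_samples and nums_in_folders cannot exceed the
    unfiles_num * 10 samples per class on disk, and hard_selection cannot keep
    more samples than a stage draws (nums_in_folders).
    """
    problems = []
    available = cfg["unfiles_num"] * SAMPLES_PER_FILE
    if cfg["nb_ul_samples"] > available:
        problems.append("nb_ul_samples=%d exceeds the %d samples per class in %d files"
                        % (cfg["nb_ul_samples"], available, cfg["unfiles_num"]))
    if cfg["nums_in_folders"] > available:
        problems.append("nums_in_folders=%d exceeds the %d available samples per class"
                        % (cfg["nums_in_folders"], available))
    if cfg["hard_selection"] > cfg["nums_in_folders"]:
        problems.append("hard_selection=%d keeps more samples than nums_in_folders=%d"
                        % (cfg["hard_selection"], cfg["nums_in_folders"]))
    return problems
```

With the values from the example meta-test command (`unfiles_num=10`, `nb_ul_samples=100`, `nums_in_folders=30`, `hard_selection=20`), the function returns an empty list.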

If you want to change other settings, please see the comments and descriptions in exp_train.py and exp_test.py.
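To make the `hard_selection` option listed above more concrete: in self-training, a common form of hard selection keeps, for each pseudo-label class, only the unlabeled samples the model is most confident about. The NumPy sketch below illustrates that idea, assuming selection ranks samples by their predicted-class probability and keeps the top `k` per class. It is not the code in this repository, and `hard_select` is a hypothetical name.

```python
import numpy as np


def hard_select(probs, k):
    """Keep the k most confident unlabeled samples for each pseudo-label class.

    probs: array of shape (num_samples, way_num) holding softmax predictions.
    Returns (indices, pseudo_labels) for the kept samples, sorted by index.
    """
    pseudo = probs.argmax(axis=1)  # pseudo-label = most likely class
    conf = probs.max(axis=1)       # confidence assigned to that pseudo-label
    keep = []
    for c in range(probs.shape[1]):
        members = np.where(pseudo == c)[0]
        # Rank this class's samples by confidence (highest first) and keep up to k.
        ranked = members[np.argsort(-conf[members])]
        keep.extend(ranked[:k].tolist())
    keep = np.array(sorted(keep), dtype=int)
    return keep, pseudo[keep]
```

Here `k` plays the role of `hard_selection`. Classes with fewer than `k` confident candidates simply keep all of them.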

Performance

Accuracy (%)   miniImageNet   tieredImageNet   miniImageNet (w/D)   tieredImageNet (w/D)
1-shot         70.1 ± 1.9     77.7 ± 1.6       64.1 ± 1.9           73.5 ± 1.6
5-shot         78.7 ± 0.8     85.2 ± 0.8       77.4 ± 1.8           83.4 ± 0.8

w/D: with distractor classes during meta-test.
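This README does not say how the ± values in the table are computed. Few-shot papers commonly report mean accuracy over many test episodes together with a 95% confidence interval whose half-width is 1.96 × std / √n. Assuming that convention, the sketch below computes both numbers from a list of per-episode accuracies. `mean_ci95` is a hypothetical helper, and the use of the sample (n - 1) standard deviation is our choice, not something the source states.

```python
import math


def mean_ci95(accuracies):
    """Return (mean, half-width of a normal-approximation 95% confidence interval)."""
    n = len(accuracies)
    mean = sum(accuracies) / float(n)
    # Sample variance with n - 1 in the denominator (requires at least two episodes).
    var = sum((a - mean) ** 2 for a in accuracies) / float(n - 1)
    return mean, 1.96 * math.sqrt(var) / math.sqrt(n)
```

Feeding it the per-episode accuracies from a meta-test run gives a "mean ± half-width" pair in the same format as the table.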

Citation

Please cite our paper if it is helpful to your work:

@inproceedings{li2019lst,
  title={Learning to Self-Train for Semi-Supervised Few-Shot Classification},
  author = {Li, Xinzhe and Sun, Qianru and Liu, Yaoyao and Zhou, Qin and Zheng, Shibao and Chua, Tat-Seng and Schiele, Bernt},
  booktitle={NeurIPS},
  year={2019}
}

Acknowledgements

Our implementation uses source code from the following repositories and users:
