A state-of-the-art semi-supervised method for image recognition

Overview

Mean teachers are better role models

Paper ---- NIPS 2017 poster ---- NIPS 2017 spotlight slides ---- Blog post

By Antti Tarvainen, Harri Valpola (The Curious AI Company)

Approach

Mean Teacher is a simple method for semi-supervised learning. It consists of the following steps:

  1. Take a supervised architecture and make a copy of it. Let's call the original model the student and the new one the teacher.
  2. At each training step, use the same minibatch as inputs to both the student and the teacher but add random augmentation or noise to the inputs separately.
  3. Add an additional consistency cost between the student and teacher outputs (after softmax).
  4. Let the optimizer update the student weights normally.
  5. Let the teacher weights be an exponential moving average (EMA) of the student weights. That is, after each training step, update the teacher weights a little bit toward the student weights.

Our contribution is the last step. Laine and Aila [paper] used shared parameters between the student and the teacher, or used a temporal ensemble of teacher predictions. In comparison, Mean Teacher is more accurate and applicable to large datasets.

Mean Teacher model

Mean Teacher works well with modern architectures. Combining Mean Teacher with ResNets, we improved the state of the art in semi-supervised learning on the ImageNet and CIFAR-10 datasets.

ImageNet using 10% of the labels top-5 validation error
Variational Auto-Encoder [paper] 35.42 ± 0.90
Mean Teacher ResNet-152 9.11 ± 0.12
All labels, state of the art [paper] 3.79
CIFAR-10 using 4000 labels test error
CT-GAN [paper] 9.98 ± 0.21
Mean Teacher ResNet-26 6.28 ± 0.15
All labels, state of the art [paper] 2.86

Implementation

There are two implementations, one for TensorFlow and one for PyTorch. The PyTorch version is probably easier to adapt to your needs, since it follows typical PyTorch idioms, and there's a natural place to add your model and dataset. Let me know if anything needs clarification.

Regarding the results in the paper, the experiments using a traditional ConvNet architecture were run with the TensorFlow version. The experiments using residual networks were run with the PyTorch version.

Tips for choosing hyperparameters and other tuning

Mean Teacher introduces two new hyperparameters: EMA decay rate and consistency cost weight. The optimal value for each of these depends on the dataset, the model, and the composition of the minibatches. You will also need to choose how to interleave unlabeled samples and labeled samples in minibatches.

Here are some rules of thumb to get you started:

  • If you are working on a new dataset, it may be easiest to start with only labeled data and do pure supervised training. Then when you are happy with the architecture and hyperparameters, add mean teacher. The same network should work well, although you may want to tune down regularization such as weight decay that you have used with small data.
  • Mean Teacher needs some noise in the model to work optimally. In practice, the best noise is probably random input augmentations. Use whatever relevant augmentations you can think of: the algorithm will train the model to be invariant to them.
  • It's useful to dedicate a portion of each minibatch for labeled examples. Then the supervised training signal is strong enough early on to train quickly and prevent getting stuck into uncertainty. In the PyTorch examples we have a quarter or a half of the minibatch for the labeled examples and the rest for the unlabeled. (See TwoStreamBatchSampler in Pytorch code.)
  • For EMA decay rate 0.999 seems to be a good starting point.
  • You can use either MSE or KL-divergence as the consistency cost function. For KL-divergence, a good consistency cost weight is often between 1.0 and 10.0. For MSE, it seems to be between the number of classes and the number of classes squared. On small datasets we saw MSE getting better results, but KL always worked pretty well too.
  • It may help to ramp up the consistency cost in the beginning over the first few epochs until the teacher network starts giving good predictions.
  • An additional trick we used in the PyTorch examples: Have two seperate logit layers at the top level. Use one for classification of labeled examples and one for predicting the teacher output. And then have an additional cost between the logits of these two predictions. The intent is the same as with the consistency cost rampup: in the beginning the teacher output may be wrong, so loosen the link between the classification prediction and the consistency cost. (See the --logit-distance-cost argument in the PyTorch implementation.)
Owner
Curious AI
Deep good. Unsupervised better.
Curious AI
Simulation environments for the CrazyFlie quadrotor: Used for Reinforcement Learning and Sim-to-Real Transfer

Phoenix-Drone-Simulation An OpenAI Gym environment based on PyBullet for learning to control the CrazyFlie quadrotor: Can be used for Reinforcement Le

Sven Gronauer 8 Dec 07, 2022
[AI6122] Text Data Management & Processing

[AI6122] Text Data Management & Processing is an elective course of MSAI, SCSE, NTU, Singapore. The repository corresponds to the AI6122 of Semester 1, AY2021-2022, starting from 08/2021. The instruc

HT. Li 1 Jan 17, 2022
Target Propagation via Regularized Inversion

Target Propagation via Regularized Inversion The present code implements an ideal formulation of target propagation using regularized inverses compute

Vincent Roulet 0 Dec 02, 2021
SustainBench: Benchmarks for Monitoring the Sustainable Development Goals with Machine Learning

Datasets | Website | Raw Data | OpenReview SustainBench: Benchmarks for Monitoring the Sustainable Development Goals with Machine Learning Christopher

67 Dec 17, 2022
🏅 Top 5% in 제2회 연구개발특구 인공지능 경진대회 AI SPARK 챌린지

AI_SPARK_CHALLENG_Object_Detection 제2회 연구개발특구 인공지능 경진대회 AI SPARK 챌린지 🏅 Top 5% in mAP(0.75) (443명 중 13등, mAP: 0.98116) 대회 설명 Edge 환경에서의 가축 Object Dete

3 Sep 19, 2022
chainladder - Property and Casualty Loss Reserving in Python

chainladder (python) chainladder - Property and Casualty Loss Reserving in Python This package gets inspiration from the popular R ChainLadder package

Casualty Actuarial Society 130 Dec 07, 2022
Generative Models as a Data Source for Multiview Representation Learning

GenRep Project Page | Paper Generative Models as a Data Source for Multiview Representation Learning Ali Jahanian, Xavier Puig, Yonglong Tian, Phillip

Ali 81 Dec 03, 2022
Semantic Segmentation Architectures Implemented in PyTorch

pytorch-semseg Semantic Segmentation Algorithms Implemented in PyTorch This repository aims at mirroring popular semantic segmentation architectures i

Meet Shah 3.3k Dec 29, 2022
StarGAN2 for practice

StarGAN2 for practice This version of StarGAN2 (coined as 'Post-modern Style Transfer') is intended mostly for fellow artists, who rarely look at scie

vadim epstein 87 Sep 24, 2022
A Sign Language detection project using Mediapipe landmark detection and Tensorflow LSTM's

sign-language-detection A Sign Language detection project using Mediapipe landmark detection and Tensorflow LSTM. The project is built for a vocabular

Hashim 4 Feb 06, 2022
diablo2 resurrected loot filter

Only For Chinese and Traditional Chinese The filter only for Chinese and Traditional Chinese, i didn't change it for other language.Maybe you could mo

elmagnifico 249 Dec 04, 2022
[WWW 2022] Zero-Shot Stance Detection via Contrastive Learning

PT-HCL for Zero-Shot Stance Detection The code of this repository is constantly being updated... Please look forward to it! Introduction This reposito

Akuchi 12 Dec 21, 2022
Human Action Controller - A human action controller running on different platforms.

Human Action Controller (HAC) Goal A human action controller running on different platforms. Fun Easy-to-use Accurate Anywhere Fun Examples Mouse Cont

27 Jul 20, 2022
Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"

FLASH - Pytorch Implementation of the Transformer variant proposed in the paper Transformer Quality in Linear Time Install $ pip install FLASH-pytorch

Phil Wang 209 Dec 28, 2022
Official code repository for the work: "The Implicit Values of A Good Hand Shake: Handheld Multi-Frame Neural Depth Refinement"

Handheld Multi-Frame Neural Depth Refinement This is the official code repository for the work: The Implicit Values of A Good Hand Shake: Handheld Mul

55 Dec 14, 2022
Diverse Object-Scene Compositions For Zero-Shot Action Recognition

Diverse Object-Scene Compositions For Zero-Shot Action Recognition This repository contains the source code for the use of object-scene compositions f

7 Sep 21, 2022
Extract MNIST handwritten digits dataset binary file into bmp images

MNIST-dataset-extractor Extract MNIST handwritten digits dataset binary file into bmp images More info at http://yann.lecun.com/exdb/mnist/ Dependenci

Omar Mostafa 6 May 24, 2021
AMTML-KD: Adaptive Multi-teacher Multi-level Knowledge Distillation

AMTML-KD: Adaptive Multi-teacher Multi-level Knowledge Distillation

Frank Liu 26 Oct 13, 2022
Computations and statistics on manifolds with geometric structures.

Geomstats Code Continuous Integration Code coverage (numpy) Code coverage (autograd, tensorflow, pytorch) Documentation Community NEWS: Geomstats is r

875 Dec 31, 2022