Repo for CReST: A Class-Rebalancing Self-Training Framework for Imbalanced Semi-Supervised Learning

Related tags

Deep Learningcrest
Overview

CReST in Tensorflow 2

Code for the paper: "CReST: A Class-Rebalancing Self-Training Framework for Imbalanced Semi-Supervised Learning" by Chen Wei, Kihyuk Sohn, Clayton Mellina, Alan Yuille and Fan Yang.

  • This is not an officially supported Google product.

Install dependencies

sudo apt install python3-dev python3-virtualenv python3-tk imagemagick
virtualenv -p python3 --system-site-packages env3
. env3/bin/activate
pip install -r requirements.txt
  • The code has been tested on Ubuntu 18.04 with CUDA 10.2.

Environment setting

. env3/bin/activate
export ML_DATA=/path/to/your/data
export ML_DIR=/path/to/your/code
export RESULT=/path/to/your/result
export PYTHONPATH=$PYTHONPATH:$ML_DIR

Datasets

Download or generate the datasets as follows:

  • CIFAR10 and CIFAR100: Follow the steps to download and generate balanced CIFAR10 and CIFAR100 datasets. Put it under ${ML_DATA}/cifar, for example, ${ML_DATA}/cifar/cifar10-test.tfrecord.
  • Long-tailed CIFAR10 and CIFAR100: Follow the steps to download the datasets prepared by Cui et al. Put it under ${ML_DATA}/cifar-lt, for example, ${ML_DATA}/cifar-lt/cifar-10-data-im-0.1.

Running experiment on Long-tailed CIFAR10, CIFAR100

Run MixMatch (paper) and FixMatch (paper):

  • Specify method to run via --method. It can be fixmatch or mixmatch.

  • Specify dataset via --dataset. It can be cifar10lt or cifar100lt.

  • Specify the class imbalanced ratio, i.e., the number of training samples from the most minority class over that from the most majority class, via --class_im_ratio.

  • Specify the percentage of labeled data via --percent_labeled.

  • Specify the number of generations for self-training via --num_generation.

  • Specify whether to use distribution alignment via --do_distalign.

  • Specify the initial distribution alignment temperature via --dalign_t.

  • Specify how distribution alignment is applied via --how_dalign. It can be constant or adaptive.

    python -m train_and_eval_loop \
      --model_dir=/tmp/model \
      --method=fixmatch \
      --dataset=cifar10lt \
      --input_shape=32,32,3 \
      --class_im_ratio=0.01 \
      --percent_labeled=0.1 \
      --fold=1 \
      --num_epoch=64 \
      --num_generation=6 \
      --sched_level=1 \
      --dalign_t=0.5 \
      --how_dalign=adaptive \
      --do_distalign=True

Results

The code reproduces main results of the paper. For all settings and methods, we run experiments on 5 different folds and report the mean and standard deviations. Note that the numbers may not exactly match those from the papers as there are extra randomness coming from the training.

Results on Long-tailed CIFAR10 with 10% labeled data (Table 1 in the paper).

gamma=50 gamma=100 gamma=200
FixMatch 79.4 (0.98) 66.2 (0.83) 59.9 (0.44)
CReST 83.7 (0.40) 75.4 (1.62) 63.9 (0.67)
CReST+ 84.5 (0.41) 77.7 (1.22) 67.5 (1.36)

Training with Multiple GPUs

  • Simply set CUDA_VISIBLE_DEVICES=0,1,2,3 or any number of GPUs.
  • Make sure that batch size is divisible by the number of GPUs.

Augmentation

  • One can concatenate different augmentation shortkeys to compose an augmentation sequence.
    • d: default augmentation, resize and shift.
    • h: horizontal flip.
    • ra: random augment with all augmentation ops.
    • rc: random augment with color augmentation ops only.
    • rg: random augment with geometric augmentation ops only.
    • c: cutout.
    • For example, dhrac applies shift, flip, random augment with all ops, followed by cutout.

Citing this work

@article{wei2021crest,
    title={CReST: A Class-Rebalancing Self-Training Framework for Imbalanced Semi-Supervised Learning},
    author={Chen Wei and Kihyuk Sohn and Clayton Mellina and Alan Yuille and Fan Yang},
    journal={arXiv preprint arXiv:2102.09559},
    year={2021},
}
Owner
Google Research
Google Research
Graph-based community clustering approach to extract protein domains from a predicted aligned error matrix

Using a predicted aligned error matrix corresponding to an AlphaFold2 model , returns a series of lists of residue indices, where each list corresponds to a set of residues clustering together into a

Tristan Croll 24 Nov 23, 2022
LBK 26 Dec 28, 2022
ObjectDrawer-ToolBox: a graphical image annotation tool to generate ground plane masks for a 3D object reconstruction system

ObjectDrawer-ToolBox is a graphical image annotation tool to generate ground plane masks for a 3D object reconstruction system, Object Drawer.

77 Jan 05, 2023
Source code of NeurIPS 2021 Paper ''Be Confident! Towards Trustworthy Graph Neural Networks via Confidence Calibration''

CaGCN This repo is for source code of NeurIPS 2021 paper "Be Confident! Towards Trustworthy Graph Neural Networks via Confidence Calibration". Paper L

6 Dec 19, 2022
MediaPipe is a an open-source framework from Google for building multimodal

MediaPipe is a an open-source framework from Google for building multimodal (eg. video, audio, any time series data), cross platform (i.e Android, iOS, web, edge devices) applied ML pipelines. It is

Bhavishya Pandit 3 Sep 30, 2022
Really awesome semantic segmentation

really-awesome-semantic-segmentation A list of all papers on Semantic Segmentation and the datasets they use. This site is maintained by Holger Caesar

Holger Caesar 400 Nov 28, 2022
Numerical Methods with Python, Numpy and Matplotlib

Numerical Bric-a-Brac Collections of numerical techniques with Python and standard computational packages (Numpy, SciPy, Numba, Matplotlib ...). Diffe

Vincent Bonnet 10 Dec 20, 2021
《A-CNN: Annularly Convolutional Neural Networks on Point Clouds》(2019)

A-CNN: Annularly Convolutional Neural Networks on Point Clouds Created by Artem Komarichev, Zichun Zhong, Jing Hua from Department of Computer Science

Artёm Komarichev 44 Feb 24, 2022
DM-ACME compatible implementation of the Arm26 environment from Mujoco

ACME-compatible implementation of Arm26 from Mujoco This repository contains a customized implementation of Mujoco's Arm26 model, that can be used wit

1 Dec 24, 2021
Fast SHAP value computation for interpreting tree-based models

FastTreeSHAP FastTreeSHAP package is built based on the paper Fast TreeSHAP: Accelerating SHAP Value Computation for Trees published in NeurIPS 2021 X

LinkedIn 369 Jan 04, 2023
A Gura parser implementation for Python

Gura Python parser This repository contains the implementation of a Gura (compliant with version 1.0.0) format parser in Python. Installation pip inst

Gura Config Lang 19 Jan 25, 2022
Fake News Detection Using Machine Learning Methods

Fake-News-Detection-Using-Machine-Learning-Methods Fake news is always a real and dangerous issue. However, with the presence and abundance of various

Achraf Safsafi 1 Jan 11, 2022
CIFAR-10 Photo Classification

Image-Classification CIFAR-10 Photo Classification CIFAR-10_Dataset_Classfication CIFAR-10 Photo Classification Dataset CIFAR is an acronym that stand

ADITYA SHAH 1 Jan 05, 2022
iNAS: Integral NAS for Device-Aware Salient Object Detection

iNAS: Integral NAS for Device-Aware Salient Object Detection Introduction Integral search design (jointly consider backbone/head structures, design/de

顾宇超 77 Dec 02, 2022
PyTorch Implementation of Google Brain's WaveGrad 2: Iterative Refinement for Text-to-Speech Synthesis

WaveGrad2 - PyTorch Implementation PyTorch Implementation of Google Brain's WaveGrad 2: Iterative Refinement for Text-to-Speech Synthesis. Status (202

Keon Lee 59 Dec 06, 2022
Sample code and notebooks for Vertex AI, the end-to-end machine learning platform on Google Cloud

Google Cloud Vertex AI Samples Welcome to the Google Cloud Vertex AI sample repository. Overview The repository contains notebooks and community conte

Google Cloud Platform 560 Dec 31, 2022
Spatial-Location-Constraint-Prototype-Loss-for-Open-Set-Recognition

Spatial Location Constraint Prototype Loss for Open Set Recognition Official PyTorch implementation of "Spatial Location Constraint Prototype Loss for

Xia Ziheng 12 Jun 24, 2022
Exploration-Exploitation Dilemma Solving Methods

Exploration-Exploitation Dilemma Solving Methods Medium article for this repo - HERE In ths repo I implemented two techniques for tackling mentioned t

Aman Mishra 6 Jan 25, 2022
Learning embeddings for classification, retrieval and ranking.

StarSpace StarSpace is a general-purpose neural model for efficient learning of entity embeddings for solving a wide variety of problems: Learning wor

Facebook Research 3.8k Dec 22, 2022
SMPL-X: A new joint 3D model of the human body, face and hands together

SMPL-X: A new joint 3D model of the human body, face and hands together [Paper Page] [Paper] [Supp. Mat.] Table of Contents License Description News I

Vassilis Choutas 1k Jan 09, 2023