Automatic learning-rate scheduler

Overview

AutoLRS

This is the PyTorch implementation of the paper AutoLRS: Automatic Learning-Rate Schedule by Bayesian Optimization on the Fly, published at ICLR 2021.

A TensorFlow version will appear in this repo later.

What is AutoLRS?

Finding a good learning rate schedule for a DNN model is non-trivial. The goal of AutoLRS is to automatically tune the learning rate (LR) over the course of training without human involvement. AutoLRS divides the whole training process into a few training stages, each consisting of τ steps, and its job is to determine a constant LR for each stage. AutoLRS treats the validation loss as a black-box function of the LR and uses Bayesian optimization (BO) to search for the LR that minimizes the validation loss for each training stage.

Evaluating the validation loss for every LR that BO explores would require τ steps of training per LR. To reduce this cost, we apply each candidate LR for only τ’ steps (τ’ << τ) and fit an exponential time-series forecasting model to predict the loss after τ steps. In our default setting, τ’ = τ/10 and BO explores 10 LRs in each stage, so the search costs 10 × τ/10 = τ steps per stage, which equals the number of steps spent on actual training.
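The forecasting idea can be illustrated with a small pure-Python sketch. It is not the code in autolrs_server.py: the model form loss(t) ≈ a·exp(−b·t) + c, the grid scan over c, and the names fit_exponential and forecast_loss are all assumptions made for illustration, and the observed losses are synthetic.

```python
import math


def fit_exponential(steps, losses, n_grid=200):
    """Fit loss(t) ~ a * exp(-b * t) + c and return (a, b, c).

    Assumed model form, for illustration only. For each candidate asymptote c,
    log(loss - c) is linear in t, so a and b come from ordinary least squares.
    The candidate with the smallest squared error in loss space is kept.
    """
    n = len(steps)
    lowest = min(losses)
    span = max(losses) - lowest
    if n < 3 or span <= 0:
        raise ValueError("need at least 3 observations with varying loss")
    mean_t = sum(steps) / n
    var_t = sum((t - mean_t) ** 2 for t in steps)
    best = None
    for i in range(1, n_grid + 1):
        c = lowest - span * i / n_grid  # strictly below every observed loss
        logs = [math.log(y - c) for y in losses]
        mean_l = sum(logs) / n
        slope = sum((t - mean_t) * (v - mean_l) for t, v in zip(steps, logs)) / var_t
        a, b = math.exp(mean_l - slope * mean_t), -slope
        err = sum((a * math.exp(-b * t) + c - y) ** 2 for t, y in zip(steps, losses))
        if best is None or err < best[0]:
            best = (err, a, b, c)
    return best[1:]


def forecast_loss(steps, losses, horizon):
    """Predict the loss at step `horizon` from a short observed prefix."""
    a, b, c = fit_exponential(steps, losses)
    return a * math.exp(-b * horizon) + c


# Default ratio from the text: observe tau' = tau / 10 steps, forecast at tau.
tau, tau_prime = 1000, 100
steps = list(range(tau_prime))
losses = [2.0 * math.exp(-0.05 * t) + 0.5 for t in steps]  # synthetic curve
predicted = forecast_loss(steps, losses, tau)
```

In this picture, BO would receive the predicted value as the (approximate) objective for the candidate LR, instead of waiting for all τ steps to finish.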

AutoLRS does not depend on a pre-defined LR schedule, dataset, or a specified task and is compatible with almost all optimizers. The LR schedules auto-generated by AutoLRS lead to speedup over highly hand-tuned LR schedules for several state-of-the-art DNNs including ResNet-50, Transformer, and BERT.

Setup

$ pip install --user -r requirements.txt

How to use AutoLRS for your work?

autolrs_server.py is the brain of AutoLRS: it implements the search algorithm, including BO and the exponential forecasting model.

autolrs_callback.py implements a callback that you can plug into your PyTorch training loop. The callback receives commands from the server over a socket and, according to those commands, adjusts the learning rate and saves or restores the model parameters and optimizer states.
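The three actions the callback performs can be sketched as follows. This is a hedged illustration, not the actual protocol: the command names ("set_lr", "save", "restore"), the (name, argument) tuple format, and the helpers set_lr, Snapshot, and handle_command are hypothetical. The sketch relies only on the param_groups, state_dict(), and load_state_dict() interfaces that PyTorch modules and optimizers expose, and it leaves the socket transport out entirely.

```python
import copy


def set_lr(optimizer, lr):
    """Apply one LR to every parameter group of a PyTorch-style optimizer."""
    for group in optimizer.param_groups:
        group["lr"] = lr


class Snapshot:
    """In-memory copy of model and optimizer state for a later rollback."""

    def __init__(self, model, optimizer):
        # Deep copies so that later training steps cannot mutate the saved state.
        self.model_state = copy.deepcopy(model.state_dict())
        self.optimizer_state = copy.deepcopy(optimizer.state_dict())

    def restore(self, model, optimizer):
        model.load_state_dict(copy.deepcopy(self.model_state))
        optimizer.load_state_dict(copy.deepcopy(self.optimizer_state))


def handle_command(command, model, optimizer, snapshot=None):
    """Dispatch one hypothetical (name, argument) command.

    Returns the snapshot that is current after the command has run.
    """
    name, arg = command
    if name == "set_lr":
        set_lr(optimizer, float(arg))
    elif name == "save":
        snapshot = Snapshot(model, optimizer)
    elif name == "restore":
        if snapshot is None:
            raise RuntimeError("restore requested before any save")
        snapshot.restore(model, optimizer)
    else:
        raise ValueError(f"unknown command: {name!r}")
    return snapshot
```

With a real PyTorch optimizer, state_dict() also carries the parameter groups, so a restore brings back the LR that was active at save time; a new candidate LR would need to be set again after restoring.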

Notes

  • You need to pass two arguments, min_lr and max_lr, when launching autolrs_server.py to set the LR search interval. This interval can be found with an LR range test or simply set according to your experience. Do not set min_lr too small (for example, 1e-10); otherwise, BO will waste many search cycles exploring very small LR values.
  • AutoLRS currently does not search for the LR during warmup steps, since warmup does not have an explicit optimization objective such as minimizing the validation loss. Warmup usually takes very few steps, and its main purpose is to prevent deeper layers in a DNN from creating training instability, especially when training with a large batch size. You can manually add a warmup stage by setting warmup_step and warmup_lr when initializing the autolrs_callback.AutoLRS callback.
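To make the relationship between the warmup stage and the searched stages concrete, here is a minimal sketch of the resulting piecewise-constant schedule. It assumes that warmup uses a single constant LR (warmup_lr) for the first warmup_step steps, which the notes above do not confirm; the actual callback may ramp the LR instead. The function lr_at_step and the stage_lrs list are hypothetical names, not part of the AutoLRS API.

```python
def lr_at_step(step, stage_lrs, tau, warmup_step=0, warmup_lr=None):
    """Return the LR for a 0-based global training step.

    stage_lrs holds one constant LR per training stage of tau steps each,
    as it might look after the search has finished (illustrative values).
    """
    if step < warmup_step:
        # Assumption: a flat warmup LR; no search happens during warmup.
        return warmup_lr
    stage = (step - warmup_step) // tau
    # Past the last listed stage, keep using the final stage's LR.
    return stage_lrs[min(stage, len(stage_lrs) - 1)]


# Example: 10 warmup steps at 0.01, then stages of 100 steps each.
schedule = [lr_at_step(s, [0.1, 0.05], tau=100, warmup_step=10, warmup_lr=0.01)
            for s in range(210)]
```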

Example

We provide an example of using AutoLRS to train various DNNs on the CIFAR-10 dataset. The models are imported from kuangliu's great and simple pytorch-cifar repository.

Prerequisites: Python 3.6+, PyTorch 1.0+

Run the example

$ bash run.sh

Contact

You can contact us at [email protected]. We would love to hear your questions and feedback!

Owner
Yuchen Jin