Official repository for Automated Learning Rate Scheduler for Large-Batch Training (8th ICML Workshop on AutoML)

Related tags

Deep Learningautowu
Overview

Automated Learning Rate Scheduler for Large-Batch Training

The official repository for Automated Learning Rate Scheduler for Large-Batch Training (8th ICML Workshop on AutoML).

Overview

AutoWU is an automated LR scheduler which consists of two phases: warmup and decay. Learning rate (LR) is increased in an exponential rate until the loss starts to increase, and in the decay phase LR is decreased following the pre-specified type of the decay (either cosine or constant-then-cosine, in our experiments).

Transition from the warmup to the decay phase is done automatically by testing whether the minimum of the predicted loss curve is attained in the past or not with high probability, and the prediction is made via Gaussian Process regression.

Diagram summarizing AutoWU

How to use

Setup

pip install -r requirements.txt

Quick use

You can use AutoWU as other PyTorch schedulers, except that it takes loss as an argument (like ReduceLROnPlateau in PyTorch). The following code snippet demonstrates a typical usage of AutoWU.

from autowu import AutoWU

...

scheduler = AutoWU(optimizer,
                   len(train_loader),  # the number of steps in one epoch 
                   total_epochs,  # total number of epochs
                   immediate_cooldown=True,
                   cooldown_type='cosine',
                   device=device)

...

for _ in range(total_epochs):
    for inputs, targets in train_loader:
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step(loss)

The default decay phase schedule is ''cosine''. To use constant-then-cosine schedule rather than cosine, set immediate_cooldown=False and set cooldown_fraction to a desired value:

scheduler = AutoWU(optimizer,
                   len(train_loader),  # the number of steps in one epoch 
                   total_epochs,  # total number of epochs
                   immediate_cooldown=False,
                   cooldown_type='cosine',
                   cooldown_fraction=0.2,  # fraction of cosine decay at the end
                   device=device)

Reproduction of results

We provide an exemplar training script train.py which is based on Pytorch Image Models. The script supports training ResNet-50 and EfficientNet-B0 on ImageNet classification under the setting almost identical to the paper. We report the top-1 accuracy of ResNet-50 and EfficientNet-B0 on the validation set trained with batch sizes 4K (4096) and 16K (16384), along with the scores reported in our paper.

ResNet-50 This repo. Reported (paper)
4K 75.54% 75.70%
16K 74.87% 75.22%
EfficientNet-B0 This repo. Reported (paper)
4K 75.74% 75.81%
16K 75.66% 75.44%

You can use distributed.launch util to run the script. For instance, in case of ResNet-50 training with batch size 4096, execute the following line with variables set according to your environment:

python -m torch.distributed.launch \
--nproc_per_node=4 \
--nnodes=4 \
--node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
train.py \
--data-root $DATA_ROOT \
--amp \
--batch-size 256 

In addition, add --model efficientnet_b0 argument in case of EfficientNet-B0 training.

Citation

@inproceedings{
    kim2021automated,
    title={Automated Learning Rate Scheduler for Large-batch Training},
    author={Chiheon Kim and Saehoon Kim and Jongmin Kim and Donghoon Lee and Sungwoong Kim},
    booktitle={8th ICML Workshop on Automated Machine Learning (AutoML)},
    year={2021},
    url={https://openreview.net/forum?id=ljIl7KCNYZH}
}

License

This project is licensed under the terms of Apache License 2.0. Copyright 2021 Kakao Brain. All right reserved.

Owner
Kakao Brain
Kakao Brain Corp.
Kakao Brain
fklearn: Functional Machine Learning

fklearn: Functional Machine Learning fklearn uses functional programming principles to make it easier to solve real problems with Machine Learning. Th

nubank 1.4k Dec 07, 2022
Useful materials and tutorials for 110-1 NTU DBME5028 (Application of Deep Learning in Medical Imaging)

Useful materials and tutorials for 110-1 NTU DBME5028 (Application of Deep Learning in Medical Imaging)

7 Jun 22, 2022
Like ThreeJS but for Python and based on wgpu

pygfx A render engine, inspired by ThreeJS, but for Python and targeting Vulkan/Metal/DX12 (via wgpu). Introduction This is a Python render engine bui

139 Jan 07, 2023
TensorFlow Implementation of Unsupervised Cross-Domain Image Generation

Domain Transfer Network (DTN) TensorFlow implementation of Unsupervised Cross-Domain Image Generation. Requirements Python 2.7 TensorFlow 0.12 Pickle

Yunjey Choi 864 Dec 30, 2022
library for nonlinear optimization, wrapping many algorithms for global and local, constrained or unconstrained, optimization

NLopt is a library for nonlinear local and global optimization, for functions with and without gradient information. It is designed as a simple, unifi

Steven G. Johnson 1.4k Dec 25, 2022
ThunderSVM: A Fast SVM Library on GPUs and CPUs

What's new We have recently released ThunderGBM, a fast GBDT and Random Forest library on GPUs. add scikit-learn interface, see here Overview The miss

Xtra Computing Group 1.4k Dec 22, 2022
Source code for "Understanding Knowledge Integration in Language Models with Graph Convolutions"

Graph Convolution Simulator (GCS) Source code for "Understanding Knowledge Integration in Language Models with Graph Convolutions" Requirements: PyTor

yifan 10 Oct 18, 2022
PyTorch implementation of "Debiased Visual Question Answering from Feature and Sample Perspectives" (NeurIPS 2021)

D-VQA We provide the PyTorch implementation for Debiased Visual Question Answering from Feature and Sample Perspectives (NeurIPS 2021). Dependencies P

Zhiquan Wen 19 Dec 22, 2022
This repository contains the official implementation code of the paper Improving Multimodal Fusion with Hierarchical Mutual Information Maximization for Multimodal Sentiment Analysis, accepted at EMNLP 2021.

MultiModal-InfoMax This repository contains the official implementation code of the paper Improving Multimodal Fusion with Hierarchical Mutual Informa

Deep Cognition and Language Research (DeCLaRe) Lab 89 Dec 26, 2022
A python comtrade load library accelerated by go

Comtrade-GRPC Code for python used is mainly from dparrini/python-comtrade. Just patch the code in BinaryDatReader.parse for parsing a little more eff

Bo 1 Dec 27, 2021
Full Stack Deep Learning Labs

Full Stack Deep Learning Labs Welcome! Project developed during lab sessions of the Full Stack Deep Learning Bootcamp. We will build a handwriting rec

Full Stack Deep Learning 1.2k Dec 31, 2022
基于Flask开发后端、VUE开发前端框架,在WEB端部署YOLOv5目标检测模型

基于Flask开发后端、VUE开发前端框架,在WEB端部署YOLOv5目标检测模型

37 Jan 01, 2023
Malware Bypass Research using Reinforcement Learning

Malware Bypass Research using Reinforcement Learning

Bobby Filar 76 Dec 26, 2022
Using fully convolutional networks for semantic segmentation with caffe for the cityscapes dataset

Using fully convolutional networks for semantic segmentation (Shelhamer et al.) with caffe for the cityscapes dataset How to get started Download the

Simon Guist 27 Jun 06, 2022
Stock-Prediction - prediction of stock market movements using sentiment analysis and deep learning.

Stock-Prediction- In this project, we aim to enhance the prediction of stock market movements using sentiment analysis and deep learning. We divide th

5 Jan 25, 2022
Revisiting Weakly Supervised Pre-Training of Visual Perception Models

SWAG: Supervised Weakly from hashtAGs This repository contains SWAG models from the paper Revisiting Weakly Supervised Pre-Training of Visual Percepti

Meta Research 134 Jan 05, 2023
Offical implementation of Shunted Self-Attention via Multi-Scale Token Aggregation

Shunted Transformer This is the offical implementation of Shunted Self-Attention via Multi-Scale Token Aggregation by Sucheng Ren, Daquan Zhou, Shengf

156 Dec 27, 2022
一个多语言支持、易使用的 OCR 项目。An easy-to-use OCR project with multilingual support.

AgentOCR 简介 AgentOCR 是一个基于 PaddleOCR 和 ONNXRuntime 项目开发的一个使用简单、调用方便的 OCR 项目 本项目目前包含 Python Package 【AgentOCR】 和 OCR 标注软件 【AgentOCRLabeling】 使用指南 Pytho

AgentMaker 98 Nov 10, 2022
Frigate - NVR With Realtime Object Detection for IP Cameras

A complete and local NVR designed for HomeAssistant with AI object detection. Uses OpenCV and Tensorflow to perform realtime object detection locally for IP cameras.

Blake Blackshear 6.4k Dec 31, 2022
Custom implementation of Corrleation Module

Pytorch Correlation module this is a custom C++/Cuda implementation of Correlation module, used e.g. in FlowNetC This tutorial was used as a basis for

Clément Pinard 361 Dec 12, 2022