Details about the wide minima density hypothesis and metrics to compute the width of a minimum

wide-minima-density-hypothesis

This repo presents the wide minima density hypothesis as proposed in our paper "Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule" (arXiv:2003.03977).

Key contributions:

  • A hypothesis about minima density
  • A SOTA LR schedule that exploits the hypothesis and outperforms standard baseline schedules
  • Reduced wall-clock training time and GPU compute hours with our LR schedule (pretraining BERT-Large in 33% fewer training steps)
  • SOTA BLEU score on IWSLT'14 (DE-EN)

Prerequisites:

  • CUDA, cuDNN
  • Python 3.6+
  • PyTorch 1.4.0

Knee LR Schedule

Based on the density of wide vs. narrow minima, we propose the Knee LR schedule, which pushes generalization further by exploiting the structure of the loss landscape. It is an explore-exploit schedule: the explore phase holds a high LR for a significant part of training, so that the optimizer explores the landscape and lands in a wide minimum with high probability; the exploit phase then decays the LR linearly to zero. The only hyperparameter to tune is the number of explore epochs/steps. We show that allocating 50% of the training budget to the explore phase is enough to land in a wider minimum and generalize better, which removes the need to tune this hyperparameter.

  • Note that many experiments require warmup, an initial phase of training that lasts a fixed number of steps and is usually needed for Adam-based optimizers or large-batch training. Warmup is complementary to the Knee schedule and can be combined with it.
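To make the shape of the schedule concrete, here is a minimal sketch of the LR as a function of the training step. It is not the code in knee_lr_schedule; the function name knee_lr is hypothetical, and it assumes that warmup comes first (linear ramp from 0 to peak_lr), is followed by explore_steps at peak_lr, and that the remaining total_steps - warmup_steps - explore_steps steps decay linearly to zero.

```python
def knee_lr(step: int, peak_lr: float, warmup_steps: int,
            explore_steps: int, total_steps: int) -> float:
    """Learning rate at `step` under a Knee-style explore-exploit schedule.

    Hypothetical helper. Assumed phase order: warmup, then explore, then
    linear decay to zero at step == total_steps.
    """
    decay_steps = total_steps - warmup_steps - explore_steps
    if decay_steps <= 0:
        raise ValueError("total_steps must exceed warmup_steps + explore_steps")
    if warmup_steps > 0 and step < warmup_steps:
        # Warmup: ramp linearly from 0 towards peak_lr.
        return peak_lr * step / warmup_steps
    if step < warmup_steps + explore_steps:
        # Explore: hold the LR at its peak to help escape narrow minima.
        return peak_lr
    # Exploit: decay linearly from peak_lr, reaching 0 at total_steps.
    steps_into_decay = step - warmup_steps - explore_steps
    return max(0.0, peak_lr * (1.0 - steps_into_decay / decay_steps))


# Example: 100 steps, 10 warmup, 45 explore, 45 decay.
schedule = [knee_lr(s, 0.1, 10, 45, 100) for s in range(101)]
```

If you only need the multiplicative factor, torch.optim.lr_scheduler.LambdaLR multiplies the optimizer's initial LR by whatever its lr_lambda returns, so a function like this divided by peak_lr could be wrapped in it; this is an alternative to KneeLRScheduler, not how the repo implements it.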

To use the Knee Schedule, import the scheduler into your training file:

>>> from knee_lr_schedule import KneeLRScheduler
>>> scheduler = KneeLRScheduler(optimizer, peak_lr, warmup_steps, explore_steps, total_steps)

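To use it during training, step the scheduler once after every optimizer step:

>>> model.train()
>>> for inputs, targets in trainloader:
...     optimizer.zero_grad()              # clear gradients from the previous step
...     output = model(inputs)
...     loss = criterion(output, targets)
...     loss.backward()
...     optimizer.step()
...     scheduler.step()                   # advance the Knee schedule by one step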

Details about args:

  • optimizer: the optimizer used to train the model (SGD/Adam)
  • peak_lr: the peak learning rate used in the explore phase to escape narrow minima
  • warmup_steps: number of warmup steps (usually needed for Adam optimizers or large-batch training). Default value: 0
  • explore_steps: total number of steps in the explore phase
  • total_steps: total number of training steps (the training budget)
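Since all arguments are in steps, an epoch-based budget has to be converted first. The sketch below applies the 50% explore rule from above; the helper name knee_budget is hypothetical, and it assumes the explore fraction is taken of the whole budget with warmup counted separately inside it. The example of 391 steps per epoch corresponds to CIFAR-10's 50,000 training images at batch size 128 and is only illustrative.

```python
from typing import Tuple


def knee_budget(epochs: int, steps_per_epoch: int, warmup_steps: int = 0,
                explore_fraction: float = 0.5) -> Tuple[int, int]:
    """Return (explore_steps, total_steps) for a step-based Knee schedule.

    Hypothetical helper: explore_fraction is applied to the whole budget.
    """
    total_steps = epochs * steps_per_epoch
    explore_steps = int(explore_fraction * total_steps)
    # The exploit (decay) phase needs at least one step left over.
    if warmup_steps + explore_steps >= total_steps:
        raise ValueError("warmup and explore phases leave no steps for decay")
    return explore_steps, total_steps


explore_steps, total_steps = knee_budget(epochs=200, steps_per_epoch=391)
```

The resulting values can then be passed as explore_steps and total_steps to KneeLRScheduler.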

Measuring width of a minima

Keskar et al. (2016) (https://arxiv.org/abs/1609.04836) argue that wide minima generalize much better than sharp minima. Their method computes sharpness with L-BFGS-B, a computationally expensive quasi-Newton method that is hard to scale. We instead use projected gradient ascent, a first-order method that is easy to implement and use. Here is a simple way to compute the width of the minimum your model finds during training:

>>> from minima_width_compute import ComputeKeskarSharpness
>>> cks = ComputeKeskarSharpness(model_final_ckpt, optimizer, criterion, trainloader, epsilon, lr, max_steps)
>>> width = cks.compute_sharpness()

Details about args:

  • model_final_ckpt: the model, with the checkpoint saved after the final training step loaded
  • optimizer: optimizer used for projected gradient ascent (SGD, Adam)
  • criterion: criterion for computing the loss (e.g. torch.nn.CrossEntropyLoss)
  • trainloader: iterator over the training dataset (torch.utils.data.DataLoader)
  • epsilon: determines the size of the local neighbourhood around the minimum within which its width is computed (default: 1e-4)
  • lr: learning rate the optimizer uses for projected gradient ascent (default: 0.001)
  • max_steps: maximum number of projected gradient ascent steps used to compute the width (default: 1000). Setting it too low may prevent projected gradient ascent from converging to the maximum.

The default values above were chosen by tuning and monitoring the loss of projected gradient ascent on CIFAR-10 with ResNet-18 and SGD with momentum, as described in our paper. They may differ for other optimizers, datasets, or models, so tune them to ensure convergence.
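For intuition, the sketch below computes the full-space sharpness metric of Keskar et al., 100 * (max over y in C_eps of f(x + y) - f(x)) / (1 + f(x)), where C_eps is the box |y_i| <= eps * (|x_i| + 1), and approximates the inner maximization with projected gradient ascent: take an ascent step, then clip back into the box. It is our own pure-Python sketch on a toy loss with an explicit gradient, not the repo's ComputeKeskarSharpness; the function name keskar_sharpness and its seed parameter are hypothetical, while epsilon, lr and max_steps mirror the arguments listed above.

```python
import random
from typing import Callable, List

Vector = List[float]


def keskar_sharpness(loss: Callable[[Vector], float],
                     grad: Callable[[Vector], Vector],
                     x: Vector, epsilon: float = 1e-4, lr: float = 1e-3,
                     max_steps: int = 1000, seed: int = 0) -> float:
    """Keskar-style sharpness of the point x via projected gradient ascent.

    Hypothetical helper; a larger value means a sharper (narrower) minimum.
    """
    rng = random.Random(seed)
    # Per-coordinate box radius: eps * (|x_i| + 1).
    bounds = [epsilon * (abs(xi) + 1.0) for xi in x]
    # Random start inside the box: at an exact minimum the gradient is zero.
    y = [rng.uniform(-b, b) for b in bounds]
    f_x = loss(x)
    best = f_x
    for _ in range(max_steps):
        g = grad([xi + yi for xi, yi in zip(x, y)])
        # Gradient ascent step, then projection onto the box by clipping.
        y = [min(max(yi + lr * gi, -b), b) for yi, gi, b in zip(y, g, bounds)]
        best = max(best, loss([xi + yi for xi, yi in zip(x, y)]))
    return 100.0 * (best - f_x) / (1.0 + f_x)


# Toy minimum at the origin: f(w) = w0^2 + 3 * w1^2.
toy_loss = lambda w: w[0] ** 2 + 3 * w[1] ** 2
toy_grad = lambda w: [2 * w[0], 6 * w[1]]
sharpness = keskar_sharpness(toy_loss, toy_grad, [0.0, 0.0],
                             epsilon=1e-2, lr=10.0, max_steps=100)
```

On this quadratic the maximum over the box sits at a corner, so the sketch returns about 100 * (1e-4 + 3e-4) = 0.04; scaling the loss up makes the minimum sharper and the value larger. For a real network the same idea runs over all model parameters, which is why ComputeKeskarSharpness takes an optimizer, criterion and trainloader.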

  • Acknowledgements: We would like to thank Harshay Shah (https://github.com/harshays) for helpful discussions on computing the width of minima.

Citation

Please cite our paper in your publications if you use our work:

@article{iyer2020wideminima,
  title={Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule},
  author={Iyer, Nikhil and Thejas, V and Kwatra, Nipun and Ramjee, Ramachandran and Sivathanu, Muthian},
  journal={arXiv preprint arXiv:2003.03977},
  year={2020}
}
  • Note: This work was done during an internship at Microsoft Research India
Owner: Nikhil Iyer. Studied at BITS-Pilani Hyderabad Campus. AI Research @ Jio-Haptik. Ex Microsoft Research India