Efficient Sharpness-aware Minimization for Improved Training of Neural Networks

Overview


Code for the paper “Efficient Sharpness-aware Minimization for Improved Training of Neural Networks” (ESAM).

Requirements

The code is implemented in PyTorch and has been tested with the following environment:

  • python = 3.8.8
  • torch = 1.8.0
  • torchvision = 0.9.0

What is in this repository

Code for training with ESAM on the CIFAR-10 and CIFAR-100 datasets.

How to use it

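import torch
from utils.layer_dp_sam import ESAM

# Base optimizer wrapped by ESAM; args.* come from the training script's command-line arguments
base_optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
# The first argument is the model's parameters (the batch-level "paras" list is attached later, before each step)
optimizer = ESAM(model.parameters(), base_optimizer, rho=args.rho, weight_dropout=args.weight_dropout, adaptive=args.isASAM, nograd_cutoff=args.nograd_cutoff, opt_dropout=args.opt_dropout, temperature=args.temperature)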
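ESAM builds on the two-step SAM update [1], whose neighborhood radius is the `rho` argument above. The sketch below shows only that vanilla SAM step, not ESAM's implementation: it runs in plain Python (no PyTorch) on a toy quadratic loss, and the names `loss`, `grad`, `sam_step` and the values of `A`, `lr`, `rho` are hypothetical choices for illustration.

```python
import math

# Hypothetical toy loss L(w) = 0.5 * sum(a_i * w_i^2), with one flat and one sharp direction
A = [1.0, 10.0]

def loss(w):
    return 0.5 * sum(a * x * x for a, x in zip(A, w))

def grad(w):
    # Analytic gradient of the toy loss
    return [a * x for a, x in zip(A, w)]

def sam_step(w, lr=0.05, rho=0.05):
    """One vanilla SAM update: perturb toward higher loss, take the gradient there, step from w."""
    g = grad(w)
    norm = math.sqrt(sum(x * x for x in g)) + 1e-12
    # Ascent perturbation of L2 norm rho along the current gradient
    eps = [rho * x / norm for x in g]
    # Gradient evaluated at the perturbed weights w + eps
    g_adv = grad([x + e for x, e in zip(w, eps)])
    # Plain SGD step applied to the original (unperturbed) weights
    return [x - lr * gx for x, gx in zip(w, g_adv)]

w = [1.0, 1.0]
for _ in range(100):
    w = sam_step(w)
print(f"final loss: {loss(w):.4f}")
```

Each SAM step costs two gradient evaluations instead of one; reducing that overhead is the motivation for ESAM's SWP and SDS.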

--beta: the SWP (Stochastic Weight Perturbation) hyperparameter

--gamma: the SDS (Sharpness-sensitive Data Selection) hyperparameter
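As we read the ESAM paper, SWP perturbs only a random subset of the weights (here assumed to keep each weight with probability beta), and SDS computes the sharpness-aware gradient only on the gamma fraction of the batch whose loss rises most under the perturbation. The helpers below (`swp_mask`, `masked_perturbation`, `sds_select`) are hypothetical names, not functions from `utils.layer_dp_sam`; they sketch the two selections on plain Python lists, and the exact masking, scaling, and rounding used in the repository may differ.

```python
import random

def swp_mask(num_params, beta, rng):
    """SWP sketch (assumption): include each parameter in the perturbation with probability beta."""
    return [1.0 if rng.random() < beta else 0.0 for _ in range(num_params)]

def masked_perturbation(grad, mask, rho):
    """Perturb only the masked-in coordinates, rescaled to L2 norm rho (assumed SAM-style scaling)."""
    g = [gi * mi for gi, mi in zip(grad, mask)]
    norm = sum(x * x for x in g) ** 0.5
    if norm == 0.0:
        return [0.0] * len(g)  # nothing selected: no perturbation
    return [rho * x / norm for x in g]

def sds_select(loss_before, loss_after, gamma):
    """SDS sketch: indices of the gamma fraction of samples with the largest loss increase."""
    n = len(loss_before)
    k = max(1, int(round(gamma * n)))
    sharpness = [la - lb for lb, la in zip(loss_before, loss_after)]
    ranked = sorted(range(n), key=lambda i: sharpness[i], reverse=True)
    return sorted(ranked[:k])

rng = random.Random(0)
mask = swp_mask(6, beta=0.5, rng=rng)
eps = masked_perturbation([0.3, -0.1, 0.8, 0.0, 0.5, -0.4], mask, rho=0.05)
keep = sds_select([0.9, 0.2, 0.5, 0.4], [1.4, 0.25, 0.9, 0.45], gamma=0.5)
print(keep)  # [0, 2]
```

Both selections shrink the second (perturbed) pass: SWP restricts which weights are perturbed, and SDS restricts which samples enter the sharpness gradient.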

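During training, loss_fct must be constructed with reduction="none" so that it returns instance-wise losses. defined_backward is the function ESAM calls to run the backward pass; supplying it from the training script lets the same code handle DDP and mixed-precision (fp16) training.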

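import torch
from apex import amp  # only needed when args.fp16 is set

loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

def defined_backward(loss):
    # Backward pass called by ESAM on the loss it computes internally
    if args.fp16:
        # optimizer0 should be the optimizer registered with amp.initialize in the training script
        with amp.scale_loss(loss, optimizer0) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

# ESAM reads the current batch, loss function, model and backward function from optimizer.paras
paras = [inputs, targets, loss_fct, model, defined_backward]
optimizer.paras = paras
optimizer.step()
predictions_logits, loss = optimizer.returnthings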
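With reduction="none", CrossEntropyLoss returns one loss value per sample instead of the batch mean, and these per-instance values are what a data-selection step such as SDS ranks. The framework-free sketch below uses hypothetical helper names (`cross_entropy_per_instance`, `cross_entropy_mean`) and is not part of this repository; it computes the same per-sample cross-entropy with a numerically stable log-sum-exp and shows that averaging it recovers the default reduction="mean" value.

```python
import math

def cross_entropy_per_instance(logits, targets):
    """Per-sample cross-entropy, mirroring CrossEntropyLoss(reduction="none") on raw logits."""
    losses = []
    for row, t in zip(logits, targets):
        m = max(row)  # subtract the max logit for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        losses.append(log_sum_exp - row[t])
    return losses

def cross_entropy_mean(logits, targets):
    """The default reduction="mean": the average of the per-instance losses."""
    losses = cross_entropy_per_instance(logits, targets)
    return sum(losses) / len(losses)

batch_logits = [[2.0, 0.0], [0.0, 2.0]]
batch_targets = [0, 0]
print(cross_entropy_per_instance(batch_logits, batch_targets))
print(cross_entropy_mean(batch_logits, batch_targets))
```

The mean alone hides which samples are poorly fit; keeping the per-sample vector is what allows the loss before and after perturbation to be compared sample by sample.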

Example

bash run.sh

Reference Code

[1] SAM

Owner
Angusdu