Code to use Augmented Shapiro Wilks Stopping, as well as code for the paper "Statistically Signifigant Stopping of Neural Network Training"

Related tags

Text Data & NLPASWS
Overview

This codebase is being actively maintained, please create and issue if you have issues using it

Basics

All data files are included under losses and each folder. The main Augmented Shapiro-Wilk Stopping criterion is implemented in analysis.py, along with several helper functions and wrappers. The other comparison heuristics are also included in analysis.py, along with their wrappers. grapher.py contains all the code for generating the graphs used in the paper, and earlystopping_calculator.py includes code for generating tables and calculating some statistics from the data. hyperparameter_search.py contains all the code used to execute the grid-search on the ASWS method, along with the grid-search for the other heuristics.

Installing

If you would like to try our code, just run pip3 install git+https://github.com/justinkterry/ASWS

Example

If you wanted to try to determine the ASWS stopping point of a model, you can do so using the analysis.py file. If at anypoint during model training you wanted to perform the stop criterion test, you can do

from ASWS.analysis import aswt_stopping

test_acc = [] # for storing model accuracies
for i in training_epochs:

    model.train()
    test_accuracy = model.evaluate(test_set)
    test_acc.append(test_accuracy)
    gamma = 0.5 # fill hyperparameters as desired
    num_data = 20
    slack_prop=0.1
    count = 20

    if len(test_acc) > count:
        aswt_stop_criterion = aswt_stopping(test_acc, gamma, count, num_data, slack_prop=slack_prop)

        if aswt_stop_criterion:
            print("Stop Training")

and if you already have finished training the model and wanted to determine the ASWS stopping point, you would need a CSV with columns Epoch, Training Loss, Training Acc, Test Loss, Test Acc. You could then use the following example

from ASWS.analysis import get_aswt_stopping_point_of_model, read_file

_, _, _, test_acc = read_file("modelaccuracy.csv")
gamma = 0.5 # fill hyperparameters as desired
num_data = 20
slack_prop=0.1
count = 20

stop_epoch, stop_accuracy = get_aswt_stopping_point_of_model(test_acc, gamma=gamma, num_data=num_data, count=count, slack_prop=slack_prop)

pytorch-training

The pytorch-training folder contains the driver file for training each model, along with the model files which contain each network definition. The main.py file can be run out of the box for the models listed in the paper. The model to train is specified via the --model argument. All learning rate schedulers listed in the paper are available (via --schedule step etc.) and the ASWS learning rate scheduler is available via --schedule ASWT . The corresponding ASWS hyperparameters are passed in at the command line (for example --gamma 0.5).

Example

In order to recreate the GoogLeNet ASWT 1 scheduler from the paper, you can use the following command

python3 main.py --model GoogLeNet --schedule ASWT --gamma 0.76 --num_data 19 --slack_prop 0.05 --lr 0.1

Owner
Justin Terry
CS PhD student at UMD. I work in deep reinforcement learning.
Justin Terry
Nystromformer: A Nystrom-based Algorithm for Approximating Self-Attention

Nystromformer: A Nystrom-based Algorithm for Approximating Self-Attention April 6, 2021 We extended segment-means to compute landmarks without requiri

Zhanpeng Zeng 322 Jan 01, 2023
CredData is a set of files including credentials in open source projects

CredData is a set of files including credentials in open source projects. CredData includes suspicious lines with manual review results and more information such as credential types for each suspicio

Samsung 19 Sep 07, 2022
Flaxformer: transformer architectures in JAX/Flax

Flaxformer: transformer architectures in JAX/Flax Flaxformer is a transformer library for primarily NLP and multimodal research at Google. It is used

Google 114 Dec 29, 2022
SciBERT is a BERT model trained on scientific text.

SciBERT is a BERT model trained on scientific text.

AI2 1.2k Dec 24, 2022
🚀 RocketQA, dense retrieval for information retrieval and question answering, including both Chinese and English state-of-the-art models.

In recent years, the dense retrievers based on pre-trained language models have achieved remarkable progress. To facilitate more developers using cutt

475 Jan 04, 2023
ACL22 paper: Imputing Out-of-Vocabulary Embeddings with LOVE Makes Language Models Robust with Little Cost

Imputing Out-of-Vocabulary Embeddings with LOVE Makes Language Models Robust with Little Cost LOVE is accpeted by ACL22 main conference as a long pape

Lihu Chen 32 Jan 03, 2023
Spokestack is a library that allows a user to easily incorporate a voice interface into any Python application with a focus on embedded systems.

Welcome to Spokestack Python! This library is intended for developing voice interfaces in Python. This can include anything from Raspberry Pi applicat

Spokestack 133 Sep 20, 2022
AI-powered literature discovery and review engine for medical/scientific papers

AI-powered literature discovery and review engine for medical/scientific papers paperai is an AI-powered literature discovery and review engine for me

NeuML 819 Dec 30, 2022
GraphNLI: A Graph-based Natural Language Inference Model for Polarity Prediction in Online Debates

GraphNLI: A Graph-based Natural Language Inference Model for Polarity Prediction in Online Debates Vibhor Agarwal, Sagar Joglekar, Anthony P. Young an

Vibhor Agarwal 2 Jun 30, 2022
Seonghwan Kim 24 Sep 11, 2022
🐍 A hyper-fast Python module for reading/writing JSON data using Rust's serde-json.

A hyper-fast, safe Python module to read and write JSON data. Works as a drop-in replacement for Python's built-in json module. This is alpha software

Matthias 479 Jan 01, 2023
A highly sophisticated sequence-to-sequence model for code generation

CoderX A proof-of-concept AI system by Graham Neubig (June 30, 2021). About CoderX CoderX is a retrieval-based code generation AI system reminiscent o

Graham Neubig 39 Aug 03, 2021
Deep Learning for Natural Language Processing - Lectures 2021

This repository contains slides for the course "20-00-0947: Deep Learning for Natural Language Processing" (Technical University of Darmstadt, Summer term 2021).

0 Feb 21, 2022
KR-FinBert And KR-FinBert-SC

KR-FinBert & KR-FinBert-SC Much progress has been made in the NLP (Natural Language Processing) field, with numerous studies showing that domain adapt

5 Jul 29, 2022
An easy to use Natural Language Processing library and framework for predicting, training, fine-tuning, and serving up state-of-the-art NLP models.

Welcome to AdaptNLP A high level framework and library for running, training, and deploying state-of-the-art Natural Language Processing (NLP) models

Novetta 407 Jan 03, 2023
Text Normalization(文本正则化)

Text Normalization(文本正则化) 任务描述:通过机器学习算法将英文文本的“手写”形式转换成“口语“形式,例如“6ft”转换成“six feet”等 实验结果 XGBoost + bag-of-words: 0.99159 XGBoost+Weights+rules:0.99002

Jason_Zhang 0 Feb 26, 2022
💬 Open source machine learning framework to automate text- and voice-based conversations: NLU, dialogue management, connect to Slack, Facebook, and more - Create chatbots and voice assistants

Rasa Open Source Rasa is an open source machine learning framework to automate text-and voice-based conversations. With Rasa, you can build contextual

Rasa 15.3k Dec 30, 2022
Mednlp - Medical natural language parsing and utility library

Medical natural language parsing and utility library A natural language medical

Paul Landes 3 Aug 24, 2022
A Facebook Messenger Chatbot using NLP

A Facebook Messenger Chatbot using NLP This project is about creating a messenger chatbot using basic NLP techniques and models like Logistic Regressi

6 Nov 20, 2022
Prithivida 690 Jan 04, 2023