Code to use Augmented Shapiro-Wilk Stopping, as well as code for the paper "Statistically Significant Stopping of Neural Network Training"

ASWS
Overview

This codebase is actively maintained; please create an issue if you have problems using it.

Basics

All data files are included in the losses folder and its subfolders. The main Augmented Shapiro-Wilk Stopping criterion is implemented in analysis.py, along with several helper functions and wrappers; the other comparison heuristics and their wrappers are also in analysis.py. grapher.py contains the code for generating the graphs used in the paper, and earlystopping_calculator.py contains code for generating tables and calculating some statistics from the data. hyperparameter_search.py contains the code used to run the grid search over the ASWS hyperparameters, along with the grid searches for the other heuristics.
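hyperparameter_search.py is only described as running a grid search. As a rough illustration of what an exhaustive grid search over ASWS-style hyperparameters involves, here is a minimal standard-library sketch. `grid_search`, the scoring function and the candidate values are hypothetical and not taken from the repository; in practice the score might be something like final accuracy at the stopping epoch, but the repository's actual objective is not described here.

```python
import itertools
from typing import Callable, Dict, Iterable, Optional, Tuple


def grid_search(
    evaluate: Callable[..., float],
    grid: Dict[str, Iterable],
) -> Tuple[Optional[Dict[str, object]], float]:
    """Try every combination of values in `grid` and return the best one.

    `evaluate` is a user-supplied (hypothetical) scoring function that takes
    the hyperparameters as keyword arguments; higher scores are better.
    Returns (None, -inf) if the grid contains no combinations.
    """
    names = list(grid)
    best_params: Optional[Dict[str, object]] = None
    best_score = float("-inf")
    # itertools.product enumerates the full Cartesian product of the value lists
    for values in itertools.product(*(grid[name] for name in names)):
        params = dict(zip(names, values))
        score = evaluate(**params)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score
```

For example, `grid_search(score_fn, {"gamma": [0.5, 0.76], "num_data": [19, 20], "slack_prop": [0.05, 0.1]})` evaluates all eight combinations; the cost grows multiplicatively with each added hyperparameter, which is why grid searches are usually kept to a few values per axis.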

Installing

If you would like to try our code, just run pip3 install git+https://github.com/justinkterry/ASWS

Example

To determine the ASWS stopping point of a model, use the functions in analysis.py. To run the stopping test at any point during model training, you can do:

from ASWS.analysis import aswt_stopping

# ASWS hyperparameters: fill in as desired
gamma = 0.5
num_data = 20
slack_prop = 0.1
count = 20

test_acc = []  # test accuracy recorded after each epoch
for epoch in range(num_epochs):  # num_epochs, model and test_set are placeholders for your own setup
    model.train()  # train for one epoch
    test_accuracy = model.evaluate(test_set)
    test_acc.append(test_accuracy)

    # only run the test once more than `count` accuracies have been collected
    if len(test_acc) > count:
        aswt_stop_criterion = aswt_stopping(test_acc, gamma, count, num_data, slack_prop=slack_prop)
        if aswt_stop_criterion:
            print("Stop Training")
            break
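The online loop can also be replayed over an accuracy history that has already been recorded, which is handy for checking where a criterion would have fired. The sketch below is a hypothetical helper, not part of ASWS: `first_stop_epoch` accepts any criterion callable, and `no_recent_improvement` is a toy stand-in criterion, not the Shapiro-Wilk test.

```python
from typing import Callable, List, Optional, Sequence


def first_stop_epoch(
    accuracies: Sequence[float],
    criterion: Callable[[List[float]], bool],
    count: int,
) -> Optional[int]:
    """Return the 1-based epoch at which `criterion` first fires, or None.

    Mirrors the training loop: the criterion is only evaluated once more
    than `count` accuracies have been collected.
    """
    history: List[float] = []
    for epoch, acc in enumerate(accuracies, start=1):
        history.append(acc)
        if len(history) > count and criterion(history):
            return epoch
    return None


def no_recent_improvement(history: List[float], window: int = 3) -> bool:
    """Toy criterion (NOT ASWS): the last `window` values never beat the earlier best.

    Needs len(history) > window, which holds whenever count >= window.
    """
    return max(history[-window:]) <= max(history[:-window])
```

To replay the real test instead, you could pass `lambda h: aswt_stopping(h, gamma, count, num_data, slack_prop=slack_prop)` as the criterion, reusing the call from the loop above. Whether the returned epoch index matches the convention of `get_aswt_stopping_point_of_model` is not documented here, so treat it as an assumption.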

If you have already finished training the model and want to determine the ASWS stopping point after the fact, you need a CSV file with the columns Epoch, Training Loss, Training Acc, Test Loss and Test Acc. You can then use the following example:

from ASWS.analysis import get_aswt_stopping_point_of_model, read_file

_, _, _, test_acc = read_file("modelaccuracy.csv")
gamma = 0.5  # fill hyperparameters as desired
num_data = 20
slack_prop = 0.1
count = 20

stop_epoch, stop_accuracy = get_aswt_stopping_point_of_model(test_acc, gamma=gamma, num_data=num_data, count=count, slack_prop=slack_prop)
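If you want to inspect or validate such a CSV yourself before handing it to ASWS, the standard `csv` module is enough. This is a minimal sketch, not the repository's `read_file`: `read_history` and `COLUMNS` are hypothetical names, and it assumes a header row whose names match the five columns listed above exactly.

```python
import csv
import io
from typing import Dict, List, TextIO

# Column names as listed in this README.
COLUMNS = ["Epoch", "Training Loss", "Training Acc", "Test Loss", "Test Acc"]


def read_history(stream: TextIO) -> Dict[str, List[float]]:
    """Parse a training-log CSV into one list of floats per column.

    Hypothetical helper: raises ValueError if any expected column is
    missing; extra columns are ignored. Epoch values are also returned
    as floats.
    """
    reader = csv.DictReader(stream)
    # fieldnames is None for an empty file, hence the `or []`
    missing = [c for c in COLUMNS if c not in (reader.fieldnames or [])]
    if missing:
        raise ValueError(f"CSV is missing columns: {missing}")
    history: Dict[str, List[float]] = {c: [] for c in COLUMNS}
    for row in reader:
        for c in COLUMNS:
            history[c].append(float(row[c]))
    return history


if __name__ == "__main__":
    sample = (
        "Epoch,Training Loss,Training Acc,Test Loss,Test Acc\n"
        "1,2.1,0.30,2.0,0.35\n"
        "2,1.5,0.50,1.6,0.48\n"
    )
    print(read_history(io.StringIO(sample))["Test Acc"])  # [0.35, 0.48]
```

For a file on disk, open it with `open("modelaccuracy.csv", newline="")` (the `csv` module documentation recommends `newline=""`) and pass the file object; `history["Test Acc"]` is then the list of test accuracies.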

pytorch-training

The pytorch-training folder contains the driver file for training each model, along with the model files containing each network definition. main.py can be run out of the box for the models listed in the paper; the model to train is selected with the --model argument. All learning rate schedulers listed in the paper are available (e.g. --schedule step), and the ASWS learning rate scheduler is selected with --schedule ASWT. The corresponding ASWS hyperparameters are passed on the command line (for example --gamma 0.5).
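The README does not spell out how the ASWT schedule uses the stopping test. One plausible pattern, shown here purely as an assumption, is to treat each firing of the test as a plateau and decay the learning rate, similar in spirit to reduce-on-plateau schedules. `CriterionLRScheduler`, `factor`, `min_history` and the history reset are hypothetical and are not main.py's implementation.

```python
from typing import Callable, List


class CriterionLRScheduler:
    """Decay the learning rate each time a stopping criterion fires.

    Hypothetical sketch: `criterion` stands in for an ASWS-style test,
    `factor` and `min_history` are illustrative parameters, and clearing
    the accuracy history after each decay is an assumption.
    """

    def __init__(
        self,
        lr: float,
        criterion: Callable[[List[float]], bool],
        factor: float = 0.1,
        min_history: int = 20,
    ) -> None:
        self.lr = lr
        self.criterion = criterion
        self.factor = factor
        self.min_history = min_history
        self.history: List[float] = []

    def step(self, test_acc: float) -> float:
        """Record one epoch's test accuracy and return the LR for the next epoch."""
        self.history.append(test_acc)
        if len(self.history) > self.min_history and self.criterion(self.history):
            self.lr *= self.factor
            self.history = []  # collect fresh evidence at the new learning rate
        return self.lr
```

In PyTorch, the value returned by `step()` would be written to the `"lr"` key of each entry in `optimizer.param_groups` before the next epoch.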

Example

To recreate the GoogLeNet ASWT 1 scheduler from the paper, use the following command:

python3 main.py --model GoogLeNet --schedule ASWT --gamma 0.76 --num_data 19 --slack_prop 0.05 --lr 0.1
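To see how that command maps to typed values, here is a hypothetical `argparse` sketch covering only the flags named in this README. The defaults, help strings and the absence of any other options are assumptions; main.py's real parser may differ.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser for the flags mentioned in this README only."""
    parser = argparse.ArgumentParser(description="Train a model (sketch)")
    parser.add_argument("--model", required=True, help="network to train, e.g. GoogLeNet")
    parser.add_argument("--schedule", default="step", help="LR schedule, e.g. step or ASWT")
    parser.add_argument("--lr", type=float, default=0.1, help="initial learning rate")
    # ASWS hyperparameters, only relevant with --schedule ASWT (defaults are illustrative)
    parser.add_argument("--gamma", type=float, default=0.5)
    parser.add_argument("--num_data", type=int, default=20)
    parser.add_argument("--slack_prop", type=float, default=0.1)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(vars(args))
```

Typing each flag (`float` for gamma, slack_prop and lr, `int` for num_data) means malformed values are rejected at parse time rather than surfacing later inside the stopping test.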

Owner
Justin Terry
CS PhD student at UMD. I work in deep reinforcement learning.
LOT: A Benchmark for Evaluating Chinese Long Text Understanding and Generation

LOT: A Benchmark for Evaluating Chinese Long Text Understanding and Generation Tasks | Datasets | LongLM | Baselines | Paper Introduction LOT is a ben

46 Dec 28, 2022
justCTF [*] 2020 challenges sources

justCTF [*] 2020 This repo contains sources for justCTF [*] 2020 challenges hosted by justCatTheFish. TLDR: Run a challenge with ./run.sh (requires Do

justCatTheFish 25 Dec 27, 2022
A benchmark for evaluation and comparison of various NLP tasks in Persian language.

Persian NLP Benchmark The repository aims to track existing natural language processing models and evaluate their performance on well-known datasets.

Mofid AI 68 Dec 19, 2022
Code and data accompanying Natural Language Processing with PyTorch

Natural Language Processing with PyTorch Build Intelligent Language Applications Using Deep Learning By Delip Rao and Brian McMahan Welcome. This is a

Joostware 1.8k Jan 01, 2023
Recognition of 38 speech commands in russian. Based on Yandex Cup 2021 ML Challenge: ASR

Speech_38_ru_commands Recognition of 38 speech commands in Russian. Based on Yandex Cup 2021 ML Challenge: ASR The program can recognize 38 key

Andrey 9 May 05, 2022
Black for Python docstrings and reStructuredText (rst).

Style-Doc Style-Doc is Black for Python docstrings and reStructuredText (rst). It can be used to format docstrings (Google docstring format) in Python

Telekom Open Source Software 13 Oct 24, 2022
Ongoing research training transformer language models at scale, including: BERT & GPT-2

What is this fork of Megatron-LM and Megatron-DeepSpeed This is a detached fork of https://github.com/microsoft/Megatron-DeepSpeed, which in itself is

BigScience Workshop 316 Jan 03, 2023
Smart discord chatbot integrated with Dialogflow

academic-NLP-chatbot Smart discord chatbot integrated with Dialogflow to interact with students naturally and manage different classes in a school. De

Tom Huynh 5 Oct 24, 2022
This repo stores the codes for topic modeling on palliative care journals.

This repo stores the codes for topic modeling on palliative care journals. Data Preparation You first need to download the journal papers. bash 1_down

3 Dec 20, 2022
Applying "Load What You Need: Smaller Versions of Multilingual BERT" to LaBSE

smaller-LaBSE LaBSE(Language-agnostic BERT Sentence Embedding) is a very good method to get sentence embeddings across languages. But it is hard to fi

Jeong Ukjae 13 Sep 02, 2022
Program for reverse encryption and decryption of a message in python3.

Chiffrement Inverse En Python3 Program for reverse encryption and decryption of a message in python3. Explanation of reverse encryption with c

Malik Makkes 2 Mar 26, 2022
Spacy-ginza-ner-webapi - Named Entity Recognition API with spaCy and GiNZA

Named Entity Recognition API with spaCy and GiNZA I wrote a blog post about this

Yuki Okuda 3 Feb 27, 2022
To classify the News into Real/Fake using Features from the Text Content of the article

Hoax-Detector Authenticity of news has now become a major problem. The Idea is to classify the News into Real/Fake using Features from the Text Conten

Aravindhan 1 Feb 09, 2022
NLP-SentimentAnalysis - Coursera Course ( Duration : 5 weeks ) offered by DeepLearning.AI

Coursera Natural Language Processing Specialization This repository contains material related to Coursera Natural Language Processing Specialization.

Nishant Sharma 1 Jun 05, 2022
Quantifiers and Negations in RE Documents

Quantifiers-and-Negations-in-RE-Documents This project was part of my work for a

Nicolas Ruscher 1 Feb 01, 2022
💥 Fast State-of-the-Art Tokenizers optimized for Research and Production

Provides an implementation of today's most used tokenizers, with a focus on performance and versatility. Main features: Train new vocabularies and tok

Hugging Face 6.2k Dec 31, 2022
A Streamlit web app that generates Rick and Morty stories using GPT2.

Rick and Morty Story Generator This project uses a pre-trained GPT2 model, which was fine-tuned on Rick and Morty transcripts, to generate new stories

₸ornike 33 Oct 13, 2022
Unet-TTS: Improving Unseen Speaker and Style Transfer in One-shot Voice Cloning

Unet-TTS: Improving Unseen Speaker and Style Transfer in One-shot Voice Cloning English | Chinese ❗ Now we provide inferencing code and pre-training models

164 Jan 02, 2023
This project aims to conduct a text information retrieval and text mining on medical research publication regarding Covid19 - treatments and vaccinations.

Project: Text Analysis - This project aims to conduct a text information retrieval and text mining on medical research publication regarding Covid19 -

1 Mar 14, 2022
An implementation of model parallel GPT-2 and GPT-3-style models using the mesh-tensorflow library.

GPT Neo 🎉 1T or bust my dudes 🎉 An implementation of model & data parallel GPT3-like models using the mesh-tensorflow library. If you're just here t

EleutherAI 6.7k Dec 28, 2022