PyTorch Personal Trainer: My framework for deep learning experiments

Alex's PyTorch Personal Trainer (ptpt)

(name subject to change)

This repository contains my personal lightweight framework for deep learning projects in PyTorch.

Disclaimer: this project is very much a work in progress. Although technically usable, it is missing many features. Nonetheless, you may find some of the design patterns and code snippets useful in the meantime.

Installation

Simply run python -m build in the root of the repo, then run pip install on the resulting .whl file.

There is no pip package yet.

Usage

Import the library as you would any other Python library:

from ptpt.trainer import Trainer, TrainerConfig
from ptpt.log import debug, info, warning, error, critical

The core of the library is the trainer.Trainer class. In the simplest case, it takes the following as input:

net:            an `nn.Module` that is the model we wish to train.
loss_fn:        a function that takes an `nn.Module` and a batch as input.
                it returns the loss and optionally other metrics.
train_dataset:  the training dataset.
test_dataset:   the test dataset.
cfg:            a `TrainerConfig` instance that holds all
                hyperparameters.

Once this is instantiated, starting the training loop is as simple as calling trainer.train() where trainer is an instance of Trainer.

cfg stores most of the configuration options for Trainer. See the class definition of TrainerConfig for details on all options.

Examples

An example workflow would go like this:

Define your training and test datasets:

from torchvision import datasets, transforms

# convert images to tensors, then normalise with the MNIST mean and std
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('../data', train=False, download=True, transform=transform)

Define your model:

# in this case, we have imported `Net` from another file
net = Net()

Define a loss function that calls net and takes the full batch as input:

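import torch.nn.functional as F

# minimising classification error
def loss_fn(net, batch):
    X, y = batch
    # F.nll_loss expects log-probabilities, so `net` should end in log_softmax
    logits = net(X)
    loss = F.nll_loss(logits, y)

    # percentage of samples in the batch whose top prediction matches the label
    pred = logits.argmax(dim=-1, keepdim=True)
    accuracy = 100. * pred.eq(y.view_as(pred)).sum().item() / y.shape[0]
    return loss, accuracy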
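
Before handing a loss function to the trainer, it can help to check it on a single fake batch. The sketch below is only an illustration: `TinyNet` is a hypothetical stand-in for `Net` (which is not shown here), and the random batch merely mimics MNIST shapes. Because `F.nll_loss` expects log-probabilities, the stand-in ends in `log_softmax`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical stand-in for `Net`: one linear layer with log-softmax output."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x):
        # flatten (N, 1, 28, 28) -> (N, 784), then return log-probabilities
        return F.log_softmax(self.fc(x.flatten(1)), dim=-1)


def loss_fn(net, batch):
    X, y = batch
    logits = net(X)
    loss = F.nll_loss(logits, y)
    pred = logits.argmax(dim=-1, keepdim=True)
    accuracy = 100. * pred.eq(y.view_as(pred)).sum().item() / y.shape[0]
    return loss, accuracy


torch.manual_seed(0)
net = TinyNet()
# a fake batch with MNIST-like shapes: 8 greyscale 28x28 images and 8 labels
batch = (torch.randn(8, 1, 28, 28), torch.randint(0, 10, (8,)))
loss, accuracy = loss_fn(net, batch)

# the loss should be a scalar tensor that gradients can flow back from
loss.backward()
print(f"loss={loss.item():.4f} accuracy={accuracy:.1f}%")
```

If `backward()` succeeds and the accuracy lies between 0 and 100, the function at least has the shape of result described above: a differentiable loss plus a plain Python metric.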

Optionally create a configuration object:

# see class definition for full list of parameters
cfg = TrainerConfig(
    exp_name = 'mnist-conv',
    batch_size = 64,
    learning_rate = 4e-4,
    nb_workers = 4,
    save_outputs = False,
    metric_names = ['accuracy']
)
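
The `metric_names` entry presumably labels the extra values that `loss_fn` returns after the loss, in order (here, `'accuracy'`). The following is a minimal sketch of that pairing under that assumption; `split_loss_and_metrics` is a hypothetical helper written for illustration and is not part of ptpt.

```python
def split_loss_and_metrics(output, metric_names):
    """Separate the loss from any extra metrics returned by a loss function."""
    # a bare loss (no extra metrics) is treated as a one-element result
    if not isinstance(output, tuple):
        output = (output,)
    loss, *values = output
    if len(values) != len(metric_names):
        raise ValueError(
            f"expected {len(metric_names)} metric(s), got {len(values)}"
        )
    # pair each extra value with its name, preserving order
    return loss, dict(zip(metric_names, values))


loss, metrics = split_loss_and_metrics((0.25, 98.4), ['accuracy'])
print(loss, metrics)  # 0.25 {'accuracy': 98.4}
```

Checking the lengths up front makes a mismatch between `metric_names` and the values returned by `loss_fn` fail loudly rather than silently mislabel a metric.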

Initialise the Trainer class:

trainer = Trainer(
    net=net,
    loss_fn=loss_fn,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    cfg=cfg
)

Call trainer.train() to begin the training loop:

trainer.train() # Go!

See more examples here.

Motivation

I found myself repeating a lot of the same structure in many of my deep learning projects. This project is the culmination of my efforts to refine the typical structure of my projects into (what I hope to be) a wholly reusable and general-purpose library.

Additionally, there are many nice theoretical and engineering tricks that are available to deep learning researchers. Unfortunately, a lot of them are forgotten because they fall outside the typical workflow, despite them being very beneficial to include. Another goal of this project is to transparently include these tricks so they can be added and removed with minimal code change. Where it is sane to do so, some of these could be on by default.

Finally, I am guilty of forgetting to implement decent logging: both of standard output and of metrics. Logging of standard output is not hard, and is implemented using other libraries such as rich. However, metric logging is less obvious. I'd like to avoid larger dependencies such as tensorboard being an integral part of the project, so metrics will be logged to simple numpy arrays. The library will then provide functions to produce plots from these, or they can be used in another library.
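
As a rough illustration of logging metrics to simple numpy arrays, the sketch below collects one value per metric per step, converts the history to arrays, and round-trips it through an `.npz` file. The `MetricLog` class and its method names are hypothetical and do not reflect the library's actual implementation.

```python
import os
import tempfile

import numpy as np


class MetricLog:
    """Hypothetical metric store backed by plain numpy arrays."""

    def __init__(self, metric_names):
        self.history = {name: [] for name in metric_names}

    def append(self, **values):
        # record one value per named metric; unknown names raise KeyError
        for name, value in values.items():
            self.history[name].append(float(value))

    def as_arrays(self):
        return {name: np.asarray(vals) for name, vals in self.history.items()}

    def save(self, path):
        # one array per metric, stored under the metric's name
        np.savez(path, **self.as_arrays())


log = MetricLog(['loss', 'accuracy'])
log.append(loss=0.9, accuracy=71.0)
log.append(loss=0.4, accuracy=88.5)

path = os.path.join(tempfile.mkdtemp(), 'metrics.npz')
log.save(path)
with np.load(path) as saved:
    restored = {name: saved[name] for name in saved.files}
```

Arrays stored this way can be loaded by any plotting library without pulling in a heavier logging dependency.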

TODO:

  • Make a todo.

Owner
Alex McKinney