Implementation of the SUMO (Slim U-Net trained on MODA) model

SUMO - Slim U-Net trained on MODA

Implementation of the SUMO (Slim U-Net trained on MODA) model as described in:

TODO: add reference to paper once available

Installation Guide

On Linux with Anaconda or Miniconda installed, run the following commands to clone the repository, create a new environment and install the required dependencies:

git clone https://github.com/dslaborg/sumo.git
cd sumo
conda env create --file environment.yaml
conda activate sumo

Scripts - Quick Guide

Running and evaluating an experiment

The main model training and evaluation procedures are implemented in bin/train.py and bin/eval.py using the PyTorch Lightning framework. A configuration chosen to train the model is called an experiment; evaluation is carried out using a configuration together with the result folder of a training run.

train.py

Trains the model as specified in the corresponding configuration file, writes its log to the console, and saves a log file, intermediate results for TensorBoard, and model checkpoints to a result directory.

Arguments:

  • -e NAME, --experiment NAME: name of the experiment to run, for which a NAME.yaml file has to exist in the config directory; default is default
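
The experiment name is a thin indirection over a file name. The following is a minimal sketch of how the documented -e/--experiment flag could be parsed and mapped to its configuration file, assuming the config directory sits in the project root; the helper names parse_train_args and experiment_config_path are hypothetical and not taken from bin/train.py.

```python
import argparse
from pathlib import Path


def parse_train_args(argv=None):
    """Parse the documented -e/--experiment flag (hypothetical helper)."""
    parser = argparse.ArgumentParser(description='Train a SUMO model (sketch)')
    parser.add_argument('-e', '--experiment', default='default',
                        help='name of a NAME.yaml file in the config directory')
    return parser.parse_args(argv)


def experiment_config_path(experiment, config_dir='config'):
    """Map an experiment name to config/NAME.yaml (assumed layout)."""
    return Path(config_dir) / f'{experiment}.yaml'


args = parse_train_args(['--experiment', 'final'])
print(experiment_config_path(args.experiment))
```

Passing -e final would therefore select config/final.yaml, while omitting the flag falls back to config/default.yaml.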

eval.py

Evaluates a trained model on either the validation data or the test data and reports the achieved metrics.

Arguments:

  • -e NAME, --experiment NAME: name of the configuration used for evaluation, for which a NAME.yaml file has to exist in the config directory; usually the same as the experiment used to train the model; default is default
  • -i PATH, --input PATH: path containing the model to evaluate; the input can either be a model checkpoint, which is then used directly, or the output directory of a train.py run, in which case the best model is taken from PATH/models/; if the configuration has cross validation enabled, the output directory is expected and the best model of each fold is taken from PATH/fold_*/models/; no default value
  • -t, --test: if given, the test data is used instead of the validation data
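
The three cases of -i/--input can be summarized as a small path-resolution routine. This is a sketch under assumptions: it only locates where the models would be looked up and does not reproduce how eval.py picks the best checkpoint inside those directories; the function name resolve_model_locations is hypothetical.

```python
from pathlib import Path


def resolve_model_locations(input_path, cross_validation=False):
    """Return the locations eval.py would search for models (hypothetical sketch).

    - with cross validation enabled, one PATH/fold_*/models/ per fold
    - a checkpoint file is used directly
    - otherwise, a train.py output directory yields PATH/models/
    """
    path = Path(input_path)
    if cross_validation:
        if not path.is_dir():
            raise ValueError('cross validation expects the output directory of train.py')
        return sorted(fold / 'models' for fold in path.glob('fold_*') if fold.is_dir())
    if path.is_file():
        return [path]
    return [path / 'models']
```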

Further example scripts

In addition to the scripts used to create the figures in our manuscript (spindle_analysis.py, spindle_analysis_correlations.py and spindle_detection_example.py), the scripts directory contains two scripts that demonstrate the usage of this project.

create_data_splits.py

Demonstrates the procedure used to split the data into test and non-test subjects and the subsequent creation of a hold-out validation set and (alternatively) cross validation folds.

Arguments:

  • -i PATH, --input PATH: path containing the (necessary) input data, as produced by the MODA file MODA02_genEEGVectBlock.m; relative paths are resolved from the scripts directory; default is ../input/
  • -o PATH, --output PATH: path in which the generated data splits are stored; relative paths are resolved from the scripts directory; default is ../output/datasets_{datetime}
  • -n NUMBER, --n_datasets NUMBER: number of random split candidates to generate; default is 25
  • -t FRACTION, --test FRACTION: proportion of the data used as test data; 0<=FRACTION<=1; default is 0.2
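
The splitting idea can be illustrated with a simplified, seeded random split over subject IDs. This is an illustrative sketch, not the procedure in create_data_splits.py: it draws candidates but does not implement how the script chooses between them, and the hold-out fraction val_fraction as well as the key names train, validation and test are hypothetical. The fold_0, ..., fold_{k-1} naming follows the cross_validation convention from the configuration parameters.

```python
import random


def make_splits(subjects, test_fraction=0.2, val_fraction=0.2, n_folds=None, seed=0):
    """Randomly split subject IDs into test and non-test subjects, then split the
    non-test subjects into train/validation or into n_folds folds (sketch)."""
    if not 0 <= test_fraction <= 1:
        raise ValueError('test_fraction must lie in [0, 1]')
    rng = random.Random(seed)  # seeded, so a candidate split can be reproduced
    shuffled = list(subjects)
    rng.shuffle(shuffled)

    n_test = round(test_fraction * len(shuffled))
    splits = {'test': shuffled[:n_test]}
    non_test = shuffled[n_test:]

    if n_folds is None:
        # hold-out validation set
        n_val = round(val_fraction * len(non_test))
        splits['validation'] = non_test[:n_val]
        splits['train'] = non_test[n_val:]
    else:
        # alternatively: n_folds folds of (nearly) equal size
        for k in range(n_folds):
            splits[f'fold_{k}'] = non_test[k::n_folds]
    return splits


# several candidates, similar in spirit to -n; choosing among them is not shown
candidates = [make_splits(range(10), seed=s) for s in range(25)]
```

The result has the same shape as the data file described under Configuration Parameters (data set name as key, list of subjects as value), except that it holds subject IDs instead of Subject objects.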

predict_plain_data.py

Demonstrates how to predict spindles with a trained SUMO model on arbitrary EEG data, which is expected as a dict with the keys representing the EEG channels and the values the corresponding data vector.

Arguments:

  • -d PATH, --data_path PATH: path containing the input data, either in .pickle or .npy format, as a dict with the channel name as key and the EEG data as value; relative paths are resolved from the scripts directory; no default value
  • -m PATH, --model_path PATH: path to the model checkpoint used to predict spindles; relative paths are resolved from the scripts directory; default is ../output/final.ckpt
  • -g NUMBER, --gpus NUMBER: number of GPUs to use; if 0, calculations are done on the CPU; default is 0
  • -sr RATE, --sample_rate RATE: sample rate of the provided data; default is 100.0
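
The expected input can be produced with a few lines of NumPy and pickle. This is a minimal sketch, assuming one 1-D float array per channel, all recorded at the rate later passed via -sr; the channel names, the random placeholder signals and the helper save_eeg_dict are hypothetical.

```python
import pickle

import numpy as np


def save_eeg_dict(eeg, path):
    """Pickle a {channel name: 1-D EEG vector} dict (hypothetical helper)."""
    for name, signal in eeg.items():
        if np.ndim(signal) != 1:
            raise ValueError(f'channel {name!r} must be a 1-D vector')
    with open(path, 'wb') as f:
        pickle.dump(eeg, f)


sample_rate = 100.0  # matches the script's default -sr value
n_seconds = 60
rng = np.random.default_rng(0)
# placeholder signals; replace them with real EEG recordings
eeg = {
    'C3': rng.standard_normal(int(n_seconds * sample_rate)),
    'C4': rng.standard_normal(int(n_seconds * sample_rate)),
}
```

After calling save_eeg_dict(eeg, '../input/my_eeg.pickle') from the scripts directory, the prediction could be started there with python predict_plain_data.py -d ../input/my_eeg.pickle -sr 100.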

Project Setup

The project is set up as follows:

  • bin/: contains the train.py and eval.py scripts, which are used for model training and subsequent evaluation in experiments (as configured within the config directory) using the PyTorch Lightning framework
  • config/: contains the configurations of the experiments, configuring how to train or evaluate the model
    • default.yaml: provides a sensible default configuration
    • final.yaml: contains the configuration used to train the final model checkpoint (output/final.ckpt)
    • predict.yaml: configuration that can be used to predict spindles on arbitrary data, e.g. by using the script at scripts/predict_plain_data.py
  • input/: should contain the input files, e.g. the EEG data and annotated spindles as produced by the MODA repository and transformed as demonstrated in scripts/create_data_splits.py
  • output/: contains the output generated by experiment runs and scripts, e.g. the created figures
    • final.ckpt: the final model checkpoint, on which the test data performance, as reported in the paper, was obtained
  • scripts/: various scripts used to create the plots of our paper and to demonstrate the usage of this project
    • a7/: Python implementation of the A7 algorithm as described in:
      Karine Lacourse, Jacques Delfrate, Julien Beaudry, Paul E. Peppard and Simon C. Warby. "A sleep spindle detection algorithm that emulates human expert spindle scoring." Journal of Neuroscience Methods 316 (2019): 3-11.
    • create_data_splits.py: demonstrates how the data set splits were obtained, including the evaluation of the A7 algorithm
    • predict_plain_data.py: demonstrates the prediction of spindles on arbitrary EEG data, using a trained model checkpoint
    • spindle_analysis.py, spindle_analysis_correlations.py, spindle_detection_example.py: scripts used to create some of the figures used in our paper
  • sumo/: the implementation of the SUMO model and the classes and functions it uses; for more information, see the docstrings

Configuration Parameters

The configuration of an experiment is implemented using YAML configuration files. These files must be placed within the config directory, and their names must match the name passed as --experiment to the eval.py or train.py script. The default.yaml is always loaded as a set of default configuration parameters, and parameters specified in an additional file overwrite the default values. Any parameters or groups of parameters that should be None have to be configured as either null or Null, following the YAML definition.
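
The override behaviour can be pictured as a recursive dictionary merge applied to the two parsed YAML files (for example parsed with PyYAML's yaml.safe_load, which maps null and Null to None). Whether the project merges nested groups recursively or replaces them wholesale is an assumption here, and the function name merge_config is hypothetical.

```python
import copy


def merge_config(default, override):
    """Return default updated with override, recursing into nested groups.

    Hypothetical sketch: values from the experiment file win, and a None
    value (YAML null) replaces the default instead of being skipped.
    """
    merged = copy.deepcopy(default)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


default = {'data': {'batch_size': 12, 'preprocessing': True},
           'experiment': {'train': {'n_epochs': 800, 'lr_scheduler': None}}}
# e.g. a prediction-only experiment without input data and with fewer epochs
experiment = {'data': None, 'experiment': {'train': {'n_epochs': 100}}}
config = merge_config(default, experiment)
```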

The available parameters are as follows:

  • data: configuration of the used input data; optional, can be None if spindles should be annotated on arbitrary EEG data
    • directory and file_name: the input file containing the Subject objects (see scripts/create_data_splits.py) is expected to be located at ${directory}/${file_name}, where relative paths are resolved from the root project directory; the file should be a (pickled) dict with the name of a data set as key and the list of corresponding subjects as value; default is input/subjects.pickle
    • split: describing the keys of the data sets to be used, specifying either train and validation, or cross_validation, and optionally test
      • cross_validation: can be either an integer k>=2, in which case the keys fold_0, ..., fold_{k-1} are expected to exist, or a list of keys
    • batch_size: size of the mini-batches used during training; default is 12
    • preprocessing: whether z-scoring should be performed on the EEG data; default is True
  • experiment: definition of the performed experiment; mandatory
    • model: definition of the model configuration; mandatory
      • n_classes: number of output parameters; default is 2
      • activation: name of an activation function as defined in torch.nn package; default is ReLU
      • depth: number of layers of the U excluding the last layer; default is 2
      • channel_size: number of filters of the convolutions in the first layer; default is 16
      • pools: list containing the sizes of the pooling and upsampling operations; has to contain as many values as the value of depth; default is [4, 4]
      • convolution_params: parameters used by the Conv1d modules
      • moving_avg_size: width of the moving average filter; default is 42
    • train: configuration used in training the model; mandatory
      • n_epochs: maximum number of epochs to be run before stopping training; default is 800
      • early_stopping: number of epochs without any improvement in the val_f1_mean metric, after which training is stopped; default is 300
      • optimizer: configuration of an optimizer as defined in torch.optim package; contains class_name (default is Adam) and parameters, which are passed to the constructor of the used optimizer class
      • lr_scheduler: used learning rate scheduler; optional, default is None
      • loss: configuration of loss function as defined either in sumo.loss package (GeneralizedDiceLoss) or torch.nn package; contains class_name (default is GeneralizedDiceLoss) and parameters, which are passed to the constructor of the used loss class
    • validation: configuration used in evaluating the model; mandatory
      • overlap_threshold_step: step size of the overlap thresholds used to calculate (validation) F1 scores
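
Two of the constraints above are easy to check before training starts: pools must contain one entry per level of depth, and an integer cross_validation value k expands to the data set keys fold_0, ..., fold_{k-1}. The following sketch performs such checks on an already-loaded configuration; the function names are hypothetical, and the project itself may validate these values differently.

```python
def cross_validation_keys(cross_validation):
    """Expand the cross_validation setting into the data set keys it refers to."""
    if isinstance(cross_validation, bool):
        raise ValueError('cross_validation must be an integer k >= 2 or a list of keys')
    if isinstance(cross_validation, int):
        if cross_validation < 2:
            raise ValueError('cross_validation must be at least 2')
        return [f'fold_{i}' for i in range(cross_validation)]
    return list(cross_validation)


def check_model_config(model):
    """Check that pools has exactly one entry per level of the U (depth)."""
    depth = model.get('depth', 2)       # documented default
    pools = model.get('pools', [4, 4])  # documented default
    if len(pools) != depth:
        raise ValueError(f'pools has {len(pools)} entries, but depth is {depth}')
```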