This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).

Overview

Predicting Patient Outcomes with Graph Representation Learning

You can watch a video of the spotlight talk at W3PHIAI (AAAI workshop) here:

Watch the video

Citation

If you use this code or the models in your research, please cite the following:

@misc{rocheteautong2021,
      title={Predicting Patient Outcomes with Graph Representation Learning}, 
      author={Emma Rocheteau and Catherine Tong and Petar Veličković and Nicholas Lane and Pietro Liò},
      year={2021},
      eprint={2101.03940},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Motivation

Recent work on predicting patient outcomes in the Intensive Care Unit (ICU) has focused heavily on the physiological time series data, largely ignoring sparse data such as diagnoses and medications. When they are included, they are usually concatenated in the late stages of a model, which may struggle to learn from rarer disease patterns. Instead, we propose a strategy to exploit diagnoses as relational information by connecting similar patients in a graph. To this end, we propose LSTM-GNN for patient outcome prediction tasks: a hybrid model combining Long Short-Term Memory networks (LSTMs) for extracting temporal features and Graph Neural Networks (GNNs) for extracting the patient neighbourhood information. We demonstrate that LSTM-GNNs outperform the LSTM-only baseline on length of stay prediction tasks on the eICU database. More generally, our results indicate that exploiting information from neighbouring patient cases using graph neural networks is a promising research direction, yielding tangible returns in supervised learning performance on Electronic Health Records.

Pre-Processing Instructions

eICU Pre-Processing

  1. To run the SQL files, you must have the eICU database set up: https://physionet.org/content/eicu-crd/2.0/.

  2. Follow the instructions at https://eicu-crd.mit.edu/tutorials/install_eicu_locally/ to make sure the connection is configured correctly.

  3. Set eICU_path in paths.json to a convenient location on your computer, and make the same change in eICU_preprocessing/create_all_tables.sql by finding and replacing '/Users/emmarocheteau/PycharmProjects/eICU-GNN-LSTM/eICU_data/' with that location. Keep the trailing '/'.

  4. In your terminal, navigate to the project directory, then type the following commands:

    psql 'dbname=eicu user=eicu options=--search_path=eicu'
    

    Inside the psql console:

    \i eICU_preprocessing/create_all_tables.sql
    

    This step might take a couple of hours.

    To quit the psql console:

    \q
    
  5. Then run the pre-processing scripts in your terminal. This will need to run overnight:

    python3 -m eICU_preprocessing.run_all_preprocessing
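    If you prefer to script the path changes from step 3 rather than using find and replace, here is a minimal sketch. It assumes paths.json is a flat JSON object with an eICU_path key and that the placeholder path appears verbatim in create_all_tables.sql; set_eicu_path is a hypothetical helper, not code taken from this repository.

```python
import json
from pathlib import Path

# Hard-coded directory that ships in create_all_tables.sql (copied from step 3).
OLD_PATH = "/Users/emmarocheteau/PycharmProjects/eICU-GNN-LSTM/eICU_data/"


def set_eicu_path(repo_dir: str, new_path: str) -> None:
    """Hypothetical helper: point paths.json and the SQL script at new_path."""
    if not new_path.endswith("/"):
        new_path += "/"  # step 3 asks you to keep the trailing '/'
    repo = Path(repo_dir)

    # Update only the eICU_path entry; other keys in paths.json are preserved.
    paths_file = repo / "paths.json"
    paths = json.loads(paths_file.read_text())
    paths["eICU_path"] = new_path
    paths_file.write_text(json.dumps(paths, indent=4))

    # Same effect as a find-and-replace over the SQL script.
    sql_file = repo / "eICU_preprocessing" / "create_all_tables.sql"
    sql_file.write_text(sql_file.read_text().replace(OLD_PATH, new_path))
```

    Run it from the project directory, e.g. set_eicu_path(".", "/path/to/eICU_data/"), before step 4.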
    

Graph Construction

To make the graphs, you can use the following scripts.

The first makes most of the graphs that we use; you can alter the arguments given to it:

python3 -m graph_construction.create_graph --freq_adjust --penalise_non_shared --k 3 --mode k_closest
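The k_closest mode connects each patient to the k patients whose diagnoses are most similar. A minimal sketch of that idea follows, assuming a binary patient-by-diagnosis matrix. How --freq_adjust and --penalise_non_shared behave here is our assumption (inverse-frequency weights so that rare shared diagnoses count more, and a small, arbitrary penalty for diagnoses that only one of the two patients has); graph_construction.create_graph may score pairs differently. Edges are returned in pytorch-geometric's (2, num_edges) edge_index layout, with messages flowing from neighbour to patient.

```python
import numpy as np


def k_closest_graph(diag: np.ndarray, k: int = 3, freq_adjust: bool = True,
                    penalise_non_shared: bool = True) -> np.ndarray:
    """Return a (2, n*k) edge_index: row 0 = neighbour, row 1 = patient.

    diag is a binary (patients x diagnoses) matrix. The scoring rules are an
    illustrative assumption, not the repository's exact formula.
    """
    diag = diag.astype(float)
    n = diag.shape[0]
    # Inverse-frequency weights: a rare diagnosis shared by two patients counts more.
    freq = diag.sum(axis=0)
    weights = 1.0 / np.maximum(freq, 1.0) if freq_adjust else np.ones(diag.shape[1])
    # Weighted count of diagnoses that both patients have.
    scores = (diag * weights) @ diag.T
    if penalise_non_shared:
        # Weighted count of diagnoses present in exactly one of the two patients.
        total = diag @ weights
        non_shared = total[:, None] + total[None, :] - 2.0 * scores
        scores = scores - 0.1 * non_shared  # 0.1 is an arbitrary illustrative penalty
    np.fill_diagonal(scores, -np.inf)  # no self-loops
    # The k highest-scoring patients for each row.
    neighbours = np.argsort(-scores, axis=1)[:, :k]
    patients = np.repeat(np.arange(n), k)
    return np.stack([neighbours.ravel(), patients])
```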

Write the diagnosis strings into the eICU_data folder:

python3 -m graph_construction.get_diagnosis_strings

Get the BERT embeddings:

python3 -m graph_construction.bert

Create the graph from the BERT embeddings:

python3 -m graph_construction.create_bert_graph --k 3 --mode k_closest
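The BERT graph applies the same k-closest idea to embeddings of each patient's diagnosis strings. The sketch below assumes one embedding row per patient and uses cosine similarity; the metric used by graph_construction.create_bert_graph is not stated in this README, so treat it as an illustrative choice.

```python
import numpy as np


def cosine_k_closest(emb: np.ndarray, k: int = 3) -> np.ndarray:
    """Connect each patient to its k most cosine-similar patients.

    emb: (n_patients, dim) array of diagnosis-string embeddings (assumed input).
    Returns a (2, n*k) edge_index: row 0 = neighbour, row 1 = patient.
    """
    # L2-normalise rows so that the dot product equals cosine similarity.
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    unit = emb / np.maximum(norms, 1e-12)
    sim = unit @ unit.T
    np.fill_diagonal(sim, -np.inf)  # exclude self-matches
    # argpartition finds the top k without a full sort; order within the top k is arbitrary.
    top_k = np.argpartition(-sim, kth=k - 1, axis=1)[:, :k]
    patients = np.repeat(np.arange(emb.shape[0]), k)
    return np.stack([top_k.ravel(), patients])
```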

Alternatively, you can request access to download our graphs using this link: https://drive.google.com/drive/folders/1yWNLhGOTPhu6mxJRjKCgKRJCJjuToBS4?usp=sharing

Training the ML Models

Before training the ML models, do the following:

  1. Set data_dir, graph_dir, log_path and ray_dir in paths.json to convenient locations.

  2. Run the following to unpack the processed eICU data into mmap files for easy loading during training. The mmap files will be saved in data_dir.

    python3 -m src.dataloader.convert
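    Memory-mapped files let the dataloader read only the rows a batch needs instead of loading the whole dataset into RAM. A minimal sketch with numpy.memmap; the float32 dtype and the helper names are assumptions, not the format written by src.dataloader.convert.

```python
import numpy as np


def write_memmap(path: str, array: np.ndarray) -> None:
    """Write an array to a raw memory-mapped file stored as float32."""
    mm = np.memmap(path, dtype=np.float32, mode="w+", shape=array.shape)
    mm[:] = array
    mm.flush()  # make sure the data reaches the disk


def read_rows(path: str, shape: tuple, rows: np.ndarray) -> np.ndarray:
    """Read only the requested rows; the rest of the file stays on disk."""
    mm = np.memmap(path, dtype=np.float32, mode="r", shape=shape)
    return np.asarray(mm[rows])  # fancy indexing copies just these rows into memory
```

    A raw memmap stores no shape or dtype, so both have to be recorded separately and passed back in when reading.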
    

The following commands train and evaluate the models introduced in our paper.

N.B.

  • The models are structured using pytorch-lightning. Graph neural networks and neighbourhood sampling are implemented using pytorch-geometric.

  • Our models assume a default graph which is made with k=3 under a k-closest scheme. If you wish to use other graphs, refer to read_graph_edge_list in src/dataloader/pyg_reader.py to add a reference handle to version2filename for your graph.

  • The default task is In-Hospital Mortality Prediction (ihm); add --task los to the command to perform the Length-of-Stay Prediction (los) task instead.

  • These commands use the best set of hyperparameters; to use other hyperparameters, remove --read_best from the command and refer to src/args.py.
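To illustrate what "adding a reference handle" for a new graph involves, here is a hypothetical sketch. The names version2filename and read_graph_edge_list come from src/dataloader/pyg_reader.py, but the dictionary contents, file names, the 'src dst' line format and the read_edge_list helper below are assumptions; check the real file before editing it.

```python
import numpy as np

# Hypothetical mapping from a graph "version" handle to its edge-list file.
# The real version2filename in src/dataloader/pyg_reader.py may be structured differently.
version2filename = {
    "default": "k_closest_k3.txt",      # illustrative name for the k=3, k-closest graph
    "my_graph": "my_custom_graph.txt",  # a new handle you might add for your own graph
}


def read_edge_list(path: str) -> np.ndarray:
    """Parse whitespace-separated 'src dst' lines into a (2, num_edges) int array.

    (2, num_edges) is the edge_index layout that pytorch-geometric expects.
    """
    edges = np.loadtxt(path, dtype=np.int64, ndmin=2)  # ndmin=2 keeps one-edge files 2-D
    return edges.T
```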

a. LSTM-GNN

The following runs the training and evaluation for LSTM-GNN models. --gnn_name can be set as gat, sage, or mpnn. When mpnn is used, add --ns_sizes 10 to the command.

python3 -m train_ns_lstmgnn --bilstm --ts_mask --add_flat --class_weights --gnn_name gat --add_diag --read_best
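LSTM-GNN combines an LSTM over each patient's time series with a GNN over the patient graph. The actual models are PyTorch Lightning modules built on pytorch-geometric layers; the NumPy sketch below only illustrates the GNN half, using a GraphSAGE-style mean aggregation and assuming the LSTM has already produced one embedding per patient. The weight matrices are plain arrays here, not learned parameters.

```python
import numpy as np


def sage_mean_layer(h: np.ndarray, edge_index: np.ndarray, w_self: np.ndarray,
                    w_neigh: np.ndarray) -> np.ndarray:
    """One GraphSAGE-style mean-aggregation layer followed by ReLU.

    h: (n, d) per-patient features, e.g. the last LSTM hidden state.
    edge_index: (2, E) array; messages flow from edge_index[0] to edge_index[1].
    """
    src, dst = edge_index
    n = h.shape[0]
    # Sum incoming neighbour features, then divide by in-degree to get the mean.
    agg = np.zeros_like(h)
    np.add.at(agg, dst, h[src])
    deg = np.bincount(dst, minlength=n).reshape(-1, 1)
    agg = agg / np.maximum(deg, 1)  # patients with no neighbours keep a zero message
    # Combine each patient's own features with its neighbourhood summary.
    return np.maximum(h @ w_self + agg @ w_neigh, 0.0)
```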

The following runs a hyperparameter search.

python3 -m src.hyperparameters.lstmgnn_search --bilstm --ts_mask --add_flat --class_weights --gnn_name gat --add_diag
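A hyperparameter search repeatedly trains models with different configurations and keeps the best one. The sketch below is plain random search; the search space and its keys (lr, hidden_dim, dropout) are hypothetical, and the repository's search scripts (which may rely on Ray, given the ray_dir entry in paths.json) could sample configurations differently.

```python
import random
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical search space; the real one is defined under src/hyperparameters.
SEARCH_SPACE: Dict[str, List] = {
    "lr": [1e-4, 3e-4, 1e-3],
    "hidden_dim": [64, 128, 256],
    "dropout": [0.1, 0.3, 0.5],
}


def random_search(objective: Callable[[Dict], float], n_trials: int = 10,
                  seed: int = 0) -> Tuple[Optional[Dict], float]:
    """Sample n_trials configurations uniformly and return the best (config, score).

    objective should train a model with the config and return a validation
    score where higher is better (e.g. AUROC for ihm).
    """
    rng = random.Random(seed)
    best_config, best_score = None, float("-inf")
    for _ in range(n_trials):
        config = {name: rng.choice(values) for name, values in SEARCH_SPACE.items()}
        score = objective(config)
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score
```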

b. Dynamic LSTM-GNN

The following runs the training & evaluation for dynamic LSTM-GNN models. --gnn_name can be set as gcn, gat, or mpnn.

python3 -m train_dynamic --bilstm --random_g --ts_mask --add_flat --class_weights --gnn_name mpnn --read_best
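In our reading, the dynamic variant differs from LSTM-GNN in that the graph is not loaded from a precomputed file but rebuilt from the LSTM's current representations, so neighbourhoods change as training progresses. This is an assumption to check against train_dynamic, and the meaning of --random_g is not covered here. The sketch builds a k-nearest-neighbour graph over one minibatch using Euclidean distance.

```python
import numpy as np


def batch_knn_edges(h: np.ndarray, k: int = 3) -> np.ndarray:
    """Build a k-nearest-neighbour graph over one minibatch of embeddings.

    h: (batch, d) LSTM representations for the current batch (assumed input).
    Returns a (2, batch*k) edge_index with messages flowing neighbour -> patient.
    """
    # Pairwise squared Euclidean distances via ||a||^2 + ||b||^2 - 2ab.
    sq = (h ** 2).sum(axis=1)
    dist = sq[:, None] + sq[None, :] - 2.0 * h @ h.T
    np.fill_diagonal(dist, np.inf)  # a patient is not its own neighbour
    nbrs = np.argsort(dist, axis=1)[:, :k]
    patients = np.repeat(np.arange(h.shape[0]), k)
    return np.stack([nbrs.ravel(), patients])
```

Because the edges are a function of h, calling this on every batch makes the graph track the LSTM as it trains.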

The following runs a hyperparameter search.

python3 -m src.hyperparameters.dynamic_lstmgnn_search --bilstm --random_g --ts_mask --add_flat --class_weights --gnn_name mpnn

c. GNN

The following runs the GNN models (with neighbourhood sampling). --gnn_name can be set as gat, sage, or mpnn. When mpnn is used, add --ns_sizes 10 to the command.

python3 -m train_ns_gnn --ts_mask --add_flat --class_weights --gnn_name gat --add_diag --read_best
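Neighbourhood sampling keeps mini-batches small by sampling a limited number of neighbours per node rather than using the full neighbourhood. We read --ns_sizes as that per-node budget, similar in spirit to the sizes argument of pytorch-geometric's neighbour samplers; this is an assumption, so check src/args.py for the exact semantics. The sketch performs one hop of uniform sampling without replacement.

```python
import numpy as np


def sample_neighbours(edge_index: np.ndarray, targets: np.ndarray, size: int,
                      rng: np.random.Generator) -> np.ndarray:
    """One hop of uniform neighbourhood sampling.

    For each target node, keep at most `size` of its incoming edges, chosen
    without replacement. Returns the sampled (2, E') edge_index.
    """
    src, dst = edge_index
    kept = []
    for t in targets:
        incoming = np.flatnonzero(dst == t)  # positions of edges pointing at t
        if len(incoming) > size:
            incoming = rng.choice(incoming, size=size, replace=False)
        kept.append(incoming)
    idx = np.concatenate(kept) if kept else np.empty(0, dtype=np.int64)
    return np.stack([src[idx], dst[idx]])
```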

The following runs a hyperparameter search.

python3 -m src.hyperparameters.ns_gnn_search --ts_mask --add_flat --class_weights --gnn_name gat --add_diag

d. LSTM (Baselines)

The following runs the baseline bi-LSTMs. To remove diagnoses from the input vector, remove --add_diag from the command.

python3 -m train_ns_lstm --bilstm --ts_mask --add_flat --class_weights --num_workers 0 --add_diag --read_best
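The flags above control what goes into the model's input vector. A minimal sketch, assuming --add_flat and --add_diag simply append the flat (static) features and the diagnosis vector to the LSTM's per-patient output before the prediction head; build_head_input is a hypothetical helper, and the repository may concatenate these at a different stage.

```python
from typing import Optional

import numpy as np


def build_head_input(lstm_out: np.ndarray, flat: Optional[np.ndarray] = None,
                     diag: Optional[np.ndarray] = None) -> np.ndarray:
    """Concatenate per-patient feature blocks along the feature axis.

    lstm_out: (n, d_lstm) temporal features; flat: (n, d_flat) static features
    (used when --add_flat is set); diag: (n, d_diag) diagnoses (--add_diag).
    """
    parts = [lstm_out]
    if flat is not None:
        parts.append(flat)
    if diag is not None:
        parts.append(diag)  # dropping --add_diag corresponds to leaving this out
    return np.concatenate(parts, axis=1)
```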

The following runs a hyperparameter search.

python3 -m src.hyperparameters.lstm_search --bilstm --ts_mask --add_flat --class_weights --num_workers 0 --add_diag