Code for "Optimizing risk-based breast cancer screening policies with reinforcement learning"


Tempo: Optimizing risk-based breast cancer screening policies with reinforcement learning

Introduction

This repository was used to develop Tempo, as described in: Optimizing risk-based breast cancer screening policies with reinforcement learning.

Screening programs must balance the benefits of early detection against the costs of over-screening. Here, we introduce a novel reinforcement learning-based framework for personalized screening, Tempo, and demonstrate its efficacy in the context of breast cancer. We trained our risk-based screening policies on a large screening mammography dataset from Massachusetts General Hospital (MGH), USA, and validated them on held-out patients from MGH and on external datasets from Emory, USA; Karolinska, Sweden; and Chang Gung Memorial Hospital (CGMH), Taiwan. Across all test sets, we found that a Tempo policy combined with an image-based AI risk model, Mirai [1], was significantly more efficient than current regimes used in clinical practice in terms of simulated early detection per screen frequency. Moreover, we showed that the same Tempo policy can be easily adapted to a wide range of possible screening preferences, allowing clinicians to select their desired trade-off between early detection and screening cost without training new policies. Finally, we demonstrated that Tempo policies based on AI-based risk models outperformed Tempo policies based on less accurate clinical risk models. Altogether, our results show that pairing AI-based risk models with agile AI-designed screening policies has the potential to improve screening programs, advancing early detection while reducing over-screening.

This code base is meant to provide exact implementation details for the development of Tempo.

Aside on Software Dependencies

This code assumes Python 3.6 and a Linux environment. The package requirements can be installed with pip:

pip install -r requirements.txt

Tempo-Mirai assumes access to Mirai risk assessments. Resources for using Mirai are shown here.

Method

[Figure: overview of the Tempo method]

Our full framework, named Tempo, is depicted above. We first train a risk progression neural network to predict future risk assessments given previous assessments. This model is then used to estimate patient risk at unobserved timepoints, which enables us to simulate risk-based screening policies. Next, we train our screening policy, which is implemented as a neural network, to maximize the reward (i.e., a combination of early detection and screening cost) on our retrospective training set. We train our screening policy to support all possible early detection vs. screening cost trade-offs using envelope Q-learning [2], an RL algorithm designed to balance multiple objectives.

The input of our screening policy is the patient's risk assessment and the desired weighting between rewards (i.e., the screening preference). The output of the policy is a recommendation for when to return for the next screen, ranging from six months to three years in the future, in multiples of six months. Our reward balances two contrasting aspects: one reflecting the imaging cost, i.e., the average number of mammograms per year recommended by the policy, and one modeling the early detection benefit relative to the retrospective screening trajectory. Our early detection reward measures the time difference in months between each patient's recommended screening date, if it was after their last negative mammogram, and their actual diagnosis date. We evaluate screening policies by simulating their recommendations for held-out patients.
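The action space and the two reward components described above can be sketched in a few lines of Python. This is a minimal illustration, not the repository's implementation: the function names, the use of month offsets as plain numbers, and the sign convention (a screen before diagnosis gives a positive value) are all assumptions made for this example.

```python
# Follow-up intervals the policy can recommend: 6 to 36 months in 6-month steps.
ACTIONS_MONTHS = list(range(6, 37, 6))


def screens_per_year(recommended_intervals_months):
    """Average number of mammograms per year implied by a sequence of
    recommended follow-up intervals (hypothetical helper)."""
    total_months = sum(recommended_intervals_months)
    return 12.0 * len(recommended_intervals_months) / total_months


def early_detection_months(recommended_month, last_negative_month, diagnosis_month):
    """Months between a recommended screen and the actual diagnosis.

    Assumption: the screen only counts if it falls after the last negative
    mammogram; otherwise this sketch returns None.
    """
    if recommended_month <= last_negative_month:
        return None
    return diagnosis_month - recommended_month
```

For instance, recommending annual screening gives `screens_per_year([12, 12])` equal to one mammogram per year, while a recommended screen at month 30 for a patient whose last negative mammogram was at month 24 and who was diagnosed at month 36 gives six months of early detection under these assumptions.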

Training Risk Progression Models

We experimented with different learning rates, hidden sizes, numbers of layers and dropout, and chose the model that obtained the lowest validation KL divergence on the MGH validation set. Our final risk progression RNN had two layers, a hidden dimension size of 100, a dropout of 0.25, and was trained for 30 epochs with a learning rate of 1e-3 using the Adam optimizer.
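The selection procedure above (enumerate hyperparameter combinations, keep the one with the lowest validation KL divergence) can be sketched as follows. The search space values are hypothetical; only the final configuration (two layers, hidden size 100, dropout 0.25, learning rate 1e-3) comes from the text. The `kl_divergence` helper is a generic discrete KL, since how the repository computes it over risk assessments is not specified here.

```python
import itertools
import math


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as equal-length lists."""
    return sum(
        pi * math.log((pi + eps) / (qi + eps))
        for pi, qi in zip(p, q)
        if pi > 0
    )


# Hypothetical search space, for illustration only.
SEARCH_SPACE = {
    "num_layers": [1, 2],
    "hidden_dim": [50, 100],
    "dropout": [0.0, 0.25],
    "lr": [1e-3, 1e-4],
}


def grid(space):
    """Yield every combination of hyperparameters as a dict."""
    keys = sorted(space)
    for values in itertools.product(*(space[k] for k in keys)):
        yield dict(zip(keys, values))


def select_best(configs, validation_kl):
    """Return the config with the lowest validation KL.

    `validation_kl` is a caller-supplied function mapping a config to a
    float, standing in for training a model and evaluating it.
    """
    return min(configs, key=validation_kl)
```

In practice `validation_kl` would train a model for each configuration and score it on the validation set; the dispatcher script presumably handles that step.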

To reproduce our grid search for our Mirai risk progression model, you can run:

python scripts/dispatcher.py --experiment_config_path configs/risk_progression/gru.json

Given a trained risk progression model, we can now estimate unobserved risk assessments auto-regressively. At each time step, the model takes as input the previous risk assessment and the prior hidden state, and predicts the risk assessment at the next time step. If the real previous assessment is not available, the previously predicted assessment is used in its place.
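The auto-regressive estimation can be sketched as a simple loop. Here `step_fn` is a hypothetical stand-in for the trained risk progression model, and `toy_step` is a made-up function used only to make the example runnable; neither name comes from the repository. For brevity, risk assessments are represented as single floats.

```python
def rollout_risk(observed, step_fn, initial_hidden):
    """Fill in unobserved risk assessments auto-regressively.

    observed: list with one entry per time step; an entry is a risk
        assessment when a mammogram exists, or None when it does not.
        The first entry must be observed.
    step_fn(risk, hidden) -> (predicted_next_risk, next_hidden): stand-in
        for the trained model.
    """
    if not observed or observed[0] is None:
        raise ValueError("the first time step must have a real assessment")
    filled = [observed[0]]
    hidden = initial_hidden
    for t in range(1, len(observed)):
        # The input is the previous assessment: the real one when it was
        # observed, otherwise the model's own earlier prediction.
        prediction, hidden = step_fn(filled[t - 1], hidden)
        filled.append(observed[t] if observed[t] is not None else prediction)
    return filled


def toy_step(risk, hidden):
    """Made-up dynamics: risk grows by 0.01 per step."""
    return risk + 0.01, hidden + 1
```

Calling `rollout_risk([0.10, None, None, 0.20, None], toy_step, 0)` keeps the two observed values and replaces each `None` with a prediction conditioned on the most recent available value.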

Training Tempo Personalized Screening Policies

We implemented our personalized screening policy as a multilayer perceptron, which took as input a risk assessment and a weighting between rewards, and predicted the Q-value for each action (i.e., each follow-up recommendation) across the rewards. This network was trained using Envelope Q-Learning [2]. We experimented with different numbers of layers, hidden dimension sizes, learning rates, dropouts, exploration epsilons, target network reset rates and weight decay rates.
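To illustrate how a vector of Q-values per action turns into a single recommendation, the sketch below scalarizes each action's Q-vector with the preference weights and picks the best action, with optional epsilon-greedy exploration. It covers only action selection, not the envelope update from [2], and all function names are assumptions made for this example.

```python
import random

ACTIONS_MONTHS = [6, 12, 18, 24, 30, 36]


def scalarize(q_vector, preference):
    """Weighted sum of one action's Q-values, one weight per reward."""
    return sum(w * q for w, q in zip(preference, q_vector))


def greedy_action(q_values, preference):
    """Index of the action with the highest preference-weighted Q-value.

    q_values: one Q-vector per action, each with one entry per reward.
    """
    scores = [scalarize(q, preference) for q in q_values]
    return max(range(len(scores)), key=scores.__getitem__)


def epsilon_greedy_action(q_values, preference, epsilon, rng=random):
    """With probability epsilon explore uniformly, otherwise act greedily."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return greedy_action(q_values, preference)
```

Because the preference is an input rather than a fixed constant, the same Q-values can yield a short interval for a preference that favors early detection and a long interval for one that favors low screening cost.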

To reproduce our grid search for our screening policy, you can run:

python scripts/dispatcher.py --experiment_config_path configs/screening/neural.json

Data availability

All datasets were used under license to the respective hospital system for the current study and are not publicly available. To access the MGH dataset, investigators should reach out to C.L. to apply for an IRB-approved research collaboration and obtain an appropriate Data Use Agreement. To access the Karolinska dataset, investigators should reach out to F.S. to apply for an approved research collaboration and sign a Data Use Agreement. To access the CGMH dataset, investigators should contact G.L. to apply for an IRB-approved research collaboration. To access the Emory dataset, investigators should reach out to H.T. to apply for an approved collaboration.

References

[1] Yala, Adam, et al. "Toward robust mammography-based models for breast cancer risk." Science Translational Medicine 13.578 (2021).

[2] Yang, Runzhe, Xingyuan Sun, and Karthik Narasimhan. "A generalized algorithm for multi-objective reinforcement learning and policy adaptation." arXiv preprint arXiv:1908.08342 (2019).

Citing Tempo

@article{yala2021optimizing,
  title={Optimizing risk-based breast cancer screening policies with reinforcement learning},
  author={Yala, Adam and Mikhael, Peter and Lehman, Constance and Lin, Gigin and Strand, Fredrik and Wang, Yung-Liang and Hughes, Kevin and Satuluru, Siddharth and Kim, Thomas and Banerjee, Imon and others},
  year={2021}
}
Releases: v1.0

Owner: Adam Yala, PhD Candidate at MIT CSAIL