Image-to-image regression with uncertainty quantification in PyTorch

Overview

im2im-uq

A platform for image-to-image regression with rigorous, distribution-free uncertainty quantification.


An algorithmic MRI reconstruction with uncertainty. A rapidly acquired but undersampled MR image of a knee (A) is fed into a model that predicts a sharp reconstruction (B) along with a calibrated notion of uncertainty (C). In (C), red means high uncertainty and blue means low uncertainty. Wherever the reconstruction contains hallucinations, the uncertainty is high; see the hallucination in the image patch (E), which has high uncertainty in (F), and does not exist in the ground truth (G).

Summary

This repository provides a convenient way to train deep-learning models in PyTorch for image-to-image regression---any task where the input and output are both images---along with rigorous uncertainty quantification. The uncertainty quantification takes the form of an interval for each pixel which is guaranteed to contain most true pixel values with high-probability no matter the choice of model or the dataset used (it is a risk-controlling prediction set). The training pipeline is already built to handle more than one GPU and all training/calibration should run automatically.

The basic workflow is

  • Define your dataset in core/datasets/.
  • Create a folder for your experiment experiments/new_experiment, along with a file experiments/new_experiment/config.yml defining the model architecture, hyperparameters, and method of uncertainty quantification. You can use experiments/fastmri_test/config.yml as a template.
  • Edit core/scripts/router.py to point to your data directory.
  • From the root folder, run wandb sweep experiments/new_experiment/config.yml, and run the resulting sweep.
  • After the sweep is complete, models will be saved in experiments/new_experiment/checkpoints, the metrics will be printed to the terminal, and outputs will be in experiments/new_experiment/output/. See experiments/fastmri_test/plot.py for an example of how to make plots from the raw outputs.

Following this procedure will train one or more models (depending on config.yml) that perform image-to-image regression with rigorous uncertainty quantification.

There are two pre-baked examples that you can run on your own after downloading the open-source data: experiments/fastmri_test/config.yml and experiments/temca_test/config.yml. The third pre-baked example, experiments/bsbcm_test/config.yml, reiles on data collected at Berkeley that has not yet been publicly released (but will be soon).

Paper

Image-to-Image Regression with Distribution-Free Uncertainty Quantification and Applications in Imaging

@article{angelopoulos2022image,
  title={Image-to-Image Regression with Distribution-Free Uncertainty Quantification and Applications in Imaging},
  author={Angelopoulos, Anastasios N and Kohli, Amit P and Bates, Stephen and Jordan, Michael I and Malik, Jitendra and Alshaabi, Thayer and Upadhyayula, Srigokul and Romano, Yaniv},
  journal={arXiv preprint arXiv:2202.05265},
  year={2022}
}

Installation

You will need to execute

conda env create -f environment.yml
conda activate im2im-uq

You will also need to go through the Weights and Biases setup process that initiates when you run your first sweep. You may need to make an account on their website.

Reproducing the results

FastMRI dataset

  • Download the FastMRI dataset to your machine and unzip it. We worked with the knee_singlecoil_train dataset.
  • Edit Line 71 of core/scripts/router to point to the your local dataset.
  • From the root folder, run wandb sweep experiments/fastmri_test/config.yml
  • After the run is complete, run cd experiments/fastmri_test/plot.py to plot the results.

TEMCA2 dataset

  • Download the TEMCA2 dataset to your machine and unzip it. We worked with sections 3501 through 3839.
  • Edit Line 78 of core/scripts/router to point to the your local dataset.
  • From the root folder, run wandb sweep experiments/temca_test/config.yml
  • After the run is complete, run cd experiments/temca_test/plot.py to plot the results.

Adding a new experiment

If you want to extend this code to a new experiment, you will need to write some code compatible with our infrastructure. If adding a new dataset, you will need to write a valid PyTorch dataset object; you need to add a new model architecture, you will need to specify it; and so on. Usually, you will want to start by creating a folder experiments/new_experiment along with a config file experiments/new_experiment/config.yml. The easiest way is to start from an existing config, like experiments/fastmri_test/config.yml.

Adding new datasets

To add a new dataset, use the following procedure.

  • Download the dataset to your machine.
  • In core/datasets, make a new folder for your dataset core/datasets/new_dataset.
  • Make a valid PyTorch Dataset class for your new dataset. The most critical part is writing a __get_item__ method that returns an image-image pair in CxHxW order; see core/datasets/bsbcm/BSBCMDataset.py for a simple example.
  • Make a file core/datasets/new_dataset/__init__.py and export your dataset by adding the line from .NewDataset.py import NewDatasetClass (substituting in your filename and classname appropriately).
  • Edit core/scripts/router.py to load your new dataset, near Line 64, following the pattern therein. You will also need to import your dataset object.
  • Populate your new config file experiments/new_experiment/config.yml with the correct directories and experiment name.
  • Execute wandb sweep experiments/new_experiment/config.yml and proceed as normal!

Adding new models

In our system, there are two parts to a model---the base architecture, which we call a trunk (e.g. a U-Net), and the final layer. Defining a trunk is as simple as writing a regular PyTorch nn.module and adding it near Line 87 of core/scripts/router.py (you will also need to import it); see core/models/trunks/unet.py for an example.

The process for adding a final layer is a bit more involved. The final layer is simply a Pytorch nn.module, but it also must come with two functions: a loss function and a nested prediction set function. See core/models/finallayers/quantile_layer.py for an example. The steps are:

  • Create a final layer nn.module object. The final layer should also have a heuristic notion of uncertainty built in, like quantile outputs.
  • Specify the loss function is used to train a network with this final layer.
  • Specify a nested prediction set function that uses output of the final layer to form a prediction set. The prediction set should scale up and down with a free factor lam, which will later be calibrated. The function should have the same prototype as that on Line 34 of core/models/finallayers/quantile_layer.py for an example.
  • After creating the new final layer and related functions, add it to core/models/add_uncertainty.py as in Line 59.
  • Edit wandb sweep experiments/new_experiment/config.yml to include your new final layer, and run the sweep as normal!
Owner
Anastasios Angelopoulos
Ph.D. student at UC Berkeley AI Research.
Anastasios Angelopoulos
Info and sample codes for "NTU RGB+D Action Recognition Dataset"

"NTU RGB+D" Action Recognition Dataset "NTU RGB+D 120" Action Recognition Dataset "NTU RGB+D" is a large-scale dataset for human action recognition. I

Amir Shahroudy 578 Dec 30, 2022
Next-Best-View Estimation based on Deep Reinforcement Learning for Active Object Classification

next_best_view_rl Setup Clone the repository: git clone --recurse-submodules ... In 'third_party/zed-ros-wrapper': git checkout devel Install mujoco `

Christian Korbach 1 Feb 15, 2022
Code repository for the paper "Doubly-Trained Adversarial Data Augmentation for Neural Machine Translation" with instructions to reproduce the results.

Doubly Trained Neural Machine Translation System for Adversarial Attack and Data Augmentation Languages Experimented: Data Overview: Source Target Tra

Steven Tan 1 Aug 18, 2022
Tutorial: Introduction to Graph Machine Learning, with Jupyter notebooks

GraphMLTutorialNLDL22 Tutorial NLDL22: Introduction to Graph Machine Learning, with Jupyter notebooks This tutorial takes place during the conference

UiT Machine Learning Group 3 Jan 10, 2022
Pytorch-3dunet - 3D U-Net model for volumetric semantic segmentation written in pytorch

pytorch-3dunet PyTorch implementation 3D U-Net and its variants: Standard 3D U-Net based on 3D U-Net: Learning Dense Volumetric Segmentation from Spar

Adrian Wolny 1.3k Dec 28, 2022
YOLOX is a high-performance anchor-free YOLO, exceeding yolov3~v5 with ONNX, TensorRT, ncnn, and OpenVINO supported.

Introduction YOLOX is an anchor-free version of YOLO, with a simpler design but better performance! It aims to bridge the gap between research and ind

7.7k Jan 03, 2023
A minimalist implementation of score-based diffusion model

sdeflow-light This is a minimalist codebase for training score-based diffusion models (supporting MNIST and CIFAR-10) used in the following paper "A V

Chin-Wei Huang 89 Dec 20, 2022
PyTorch implementation of CVPR 2020 paper (Reference-Based Sketch Image Colorization using Augmented-Self Reference and Dense Semantic Correspondence) and pre-trained model on ImageNet dataset

Reference-Based-Sketch-Image-Colorization-ImageNet This is a PyTorch implementation of CVPR 2020 paper (Reference-Based Sketch Image Colorization usin

Yuzhi ZHAO 11 Jul 28, 2022
Code repository for Semantic Terrain Classification for Off-Road Autonomous Driving

BEVNet Datasets Datasets should be put inside data/. For example, data/semantic_kitti_4class_100x100. Training BEVNet-S Example: cd experiments bash t

(Brian) JoonHo Lee 24 Dec 12, 2022
The fundamental package for scientific computing with Python.

NumPy is the fundamental package needed for scientific computing with Python. Website: https://www.numpy.org Documentation: https://numpy.org/doc Mail

NumPy 22.4k Jan 09, 2023
Code for "Learning to Regrasp by Learning to Place"

Learning2Regrasp Learning to Regrasp by Learning to Place, CoRL 2021. Introduction We propose a point-cloud-based system for robots to predict a seque

Shuo Cheng (成硕) 18 Aug 27, 2022
People Interaction Graph

Gihan Jayatilaka*, Jameel Hassan*, Suren Sritharan*, Janith Senananayaka, Harshana Weligampola, et. al., 2021. Holistic Interpretation of Public Scenes Using Computer Vision and Temporal Graphs to Id

University of Peradeniya : COVID Research Group 1 Aug 24, 2022
Provide baselines and evaluation metrics of the task: traffic flow prediction

Note: This repo is adpoted from https://github.com/UNIMIBInside/Smart-Mobility-Prediction. Due to technical reasons, I did not fork their code. Introd

Zhangzhi Peng 11 Nov 02, 2022
Autonomous Robots Kalman Filters

Autonomous Robots Kalman Filters The Kalman Filter is an easy topic. However, ma

20 Jul 18, 2022
Replication package for the manuscript "Using Personality Detection Tools for Software Engineering Research: How Far Can We Go?" submitted to TOSEM

tosem2021-personality-rep-package Replication package for the manuscript "Using Personality Detection Tools for Software Engineering Research: How Far

Collaborative Development Group 1 Dec 13, 2021
Deep learning for spiking neural networks

A deep learning library for spiking neural networks. Norse aims to exploit the advantages of bio-inspired neural components, which are sparse and even

Electronic Vision(s) Group — BrainScaleS Neuromorphic Hardware 59 Nov 28, 2022
Prototypical Cross-Attention Networks for Multiple Object Tracking and Segmentation, NeurIPS 2021 Spotlight

PCAN for Multiple Object Tracking and Segmentation This is the offical implementation of paper PCAN for MOTS. We also present a trailer that consists

ETH VIS Group 328 Dec 29, 2022
Awesome Human Pose Estimation

Human Pose Estimation Related Publication

Zhe Wang 1.2k Dec 26, 2022
This repository contains the map content ontology used in narrative cartography

Narrative-cartography-ontology This repository contains the map content ontology used in narrative cartography, which is associated with a submission

Weiming Huang 0 Oct 31, 2021
A fast, scalable, high performance Gradient Boosting on Decision Trees library, used for ranking, classification, regression and other machine learning tasks for Python, R, Java, C++. Supports computation on CPU and GPU.

Website | Documentation | Tutorials | Installation | Release Notes CatBoost is a machine learning method based on gradient boosting over decision tree

CatBoost 6.9k Jan 04, 2023