Official implementation of UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation

Related tags

Deep LearningUTNet
Overview

UTNet (Accepted at MICCAI 2021)

Official implementation of UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation

Introduction

Transformer architecture has emerged to be successful in a number of natural language processing tasks. However, its applications to medical vision remain largely unexplored. In this study, we present UTNet, a simple yet powerful hybrid Transformer architecture that integrates self-attention into a convolutional neural network for enhancing medical image segmentation. UTNet applies self-attention modules in both encoder and decoder for capturing long-range dependency at dif- ferent scales with minimal overhead. To this end, we propose an efficient self-attention mechanism along with relative position encoding that reduces the complexity of self-attention operation significantly from O(n2) to approximate O(n). A new self-attention decoder is also proposed to recover fine-grained details from the skipped connections in the encoder. Our approach addresses the dilemma that Transformer requires huge amounts of data to learn vision inductive bias. Our hybrid layer design allows the initialization of Transformer into convolutional networks without a need of pre-training. We have evaluated UTNet on the multi- label, multi-vendor cardiac magnetic resonance imaging cohort. UTNet demonstrates superior segmentation performance and robustness against the state-of-the-art approaches, holding the promise to generalize well on other medical image segmentations.

image image

Supportting models

UTNet

TransUNet

ResNet50-UTNet

ResNet50-UNet

SwinUNet

To be continue ...

Getting Started

Currently, we only support M&Ms dataset.

Prerequisites

Python >= 3.6
pytorch = 1.8.1
SimpleITK = 2.0.2
numpy = 1.19.5
einops = 0.3.2

Preprocess

Resample all data to spacing of 1.2x1.2 mm in x-y plane. We don't change the spacing of z-axis, as UTNet is a 2D network. Then put all data into 'dataset/'

Training

The M&M dataset provides data from 4 venders, where vendor AB are provided for training while ABCD for testing. The '--domain' is used to control using which vendor for training. '--domain A' for using vender A only. '--domain B' for using vender B only. '--domain AB' for using both vender A and B. For testing, all 4 venders will be used.

UTNet

For default UTNet setting, training with:

python train_deep.py -m UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 1234 --num_blocks 1,1,1,1 --domain AB --gpu 0 --aux_loss

Or you can use '-m UTNet_encoder' to use transformer blocks in the encoder only. This setting is more stable than the default setting in some cases.

To optimize UTNet in your own task, there are several hyperparameters to tune:

'--block_list': indicates apply transformer blocks in which resolution. The number means the number of downsamplings, e.g. 3,4 means apply transformer blocks in features after 3 and 4 times downsampling. Apply transformer blocks in higher resolution feature maps will introduce much more computation.

'--num_blocks': indicates the number of transformer blocks applied in each level. e.g. block_list='3,4', num_blocks=2,4 means apply 2 transformer blocks in 3-times downsampling level and apply 4 transformer blocks in 4-time downsampling level.

'--reduce_size': indicates the size of downsampling for efficient attention. In our experiments, reduce_size 8 and 16 don't have much difference, but 16 will introduce more computation, so we choost 8 as our default setting. 16 might have better performance in other applications.

'--aux_loss': applies deep supervision in training, will introduce some computation overhead but has slightly better performance.

Here are some recomended parameter setting:

--block_list 1234 --num_blocks 1,1,1,1

Our default setting, most efficient setting. Suitable for tasks with limited training data, and most errors occur in the boundary of ROI where high resolution information is important.

--block_list 1234 --num_blocks 1,1,4,8

Similar to the previous one. The model capacity is larger as more transformer blocks are including, but needs larger dataset for training.

--block_list 234 --num_blocks 2,4,8

Suitable for tasks that has complex contexts and errors occurs inside ROI. More transformer blocks can help learn higher-level relationship.

Feel free to try other combinations of the hyperparameter like base_chan, reduce_size and num_blocks in each level etc. to trade off between capacity and efficiency to fit your own tasks and datasets.

TransUNet

We borrow code from the original TransUNet repo and fit it into our training framework. If you want to use pre-trained weight, please download from the original repo. The configuration is not parsed by command line, so if you want change the configuration of TransUNet, you need change it inside the train_deep.py.

python train_deep.py -m TransUNet -u EXP_NAME --data_path YOUR_OWN_PATH --gpu 0

ResNet50-UTNet

For fair comparison with TransUNet, we implement the efficient attention proposed in UTNet into ResNet50 backbone, which is basically append transformer blocks into specified level after ResNet blocks. ResNet50-UTNet is slightly better in performance than the default UTNet in M&M dataset.

python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 123 --num_blocks 1,1,1 --gpu 0

Similar to UTNet, this is the most efficient setting, suitable for tasks with limited training data.

--block_list 23 --num_blocks 2,4

Suitable for tasks that has complex contexts and errors occurs inside ROI. More transformer blocks can help learn higher-level relationship.

ResNet50-UNet

If you don't use Transformer blocks in ResNet50-UTNet, it is actually ResNet50-UNet. So you can use this as the baseline to compare the performance improvement from Transformer for fair comparision with TransUNet and our UTNet.

python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --block_list ''  --gpu 0

SwinUNet

Download pre-trained model from the origin repo. As Swin-Transformer's input size is related to window size and is hard to change after pretraining, so we adapt our input size to 224. Without pre-training, SwinUNet's performance is very low.

python train_deep.py -m SwinUNet -u EXP_NAME --data_path YOUR_OWN_PATH --crop_size 224

Citation

If you find this repo helps, please kindly cite our paper, thanks!

@inproceedings{gao2021utnet,
  title={UTNet: a hybrid transformer architecture for medical image segmentation},
  author={Gao, Yunhe and Zhou, Mu and Metaxas, Dimitris N},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={61--71},
  year={2021},
  organization={Springer}
}
Python PID Tuner - Makes a model of the System from a Process Reaction Curve and calculates PID Gains

PythonPID_Tuner_SOPDT Step 1: Takes a Process Reaction Curve in csv format - assumes data at 100ms interval (column names CV and PV) Step 2: Makes a r

1 Jan 18, 2022
Official repository for the NeurIPS 2021 paper Get Fooled for the Right Reason: Improving Adversarial Robustness through a Teacher-guided curriculum Learning Approach

Get Fooled for the Right Reason Official repository for the NeurIPS 2021 paper Get Fooled for the Right Reason: Improving Adversarial Robustness throu

Sowrya Gali 1 Apr 25, 2022
Eth brownie struct encoding example

eth-brownie struct encoding example Overview This repository contains an example of encoding a struct, so that it can be used in a function call, usin

Ittai Svidler 2 Mar 04, 2022
Pytorch implementation of the unsupervised object discovery method LOST.

LOST Pytorch implementation of the unsupervised object discovery method LOST. More details can be found in the paper: Localizing Objects with Self-Sup

Valeo.ai 189 Dec 25, 2022
CvT-ASSD: Convolutional vision-Transformerbased Attentive Single Shot MultiBox Detector (ICTAI 2021 CCF-C 会议)The 33rd IEEE International Conference on Tools with Artificial Intelligence

CvT-ASSD including extra CvT, CvT-SSD, VGG-ASSD models original-code-website: https://github.com/albert-jin/CvT-SSD new-code-website: https://github.c

金伟强 -上海大学人工智能小渣渣~ 5 Mar 07, 2022
Python Classes: Medical Insurance Project using Object Oriented Programming Concepts

Medical-Insurance-Project-OOP Python Classes: Medical Insurance Project using Object Oriented Programming Concepts Classes are an incredibly useful pr

Hugo B. 0 Feb 04, 2022
This project provides a stock market environment using OpenGym with Deep Q-learning and Policy Gradient.

Stock Trading Market OpenAI Gym Environment with Deep Reinforcement Learning using Keras Overview This project provides a general environment for stoc

Kim, Ki Hyun 769 Dec 25, 2022
Unofficial implementation of One-Shot Free-View Neural Talking Head Synthesis

face-vid2vid Usage Dataset Preparation cd datasets wget https://yt-dl.org/downloads/latest/youtube-dl -O youtube-dl chmod a+rx youtube-dl python load_

worstcoder 68 Dec 30, 2022
Fuse radar and camera for detection

SAF-FCOS: Spatial Attention Fusion for Obstacle Detection using MmWave Radar and Vision Sensor This project hosts the code for implementing the SAF-FC

ChangShuo 18 Jan 01, 2023
Learning an Adaptive Meta Model-Generator for Incrementally Updating Recommender Systems

Learning an Adaptive Meta Model-Generator for Incrementally Updating Recommender Systems This is our experimental code for RecSys 2021 paper "Learning

11 Jul 28, 2022
Efficient Lottery Ticket Finding: Less Data is More

The lottery ticket hypothesis (LTH) reveals the existence of winning tickets (sparse but critical subnetworks) for dense networks, that can be trained in isolation from random initialization to match

VITA 20 Sep 04, 2022
Localized representation learning from Vision and Text (LoVT)

Localized Vision-Text Pre-Training Contrastive learning has proven effective for pre- training image models on unlabeled data and achieved great resul

Philip Müller 10 Dec 07, 2022
AutoML library for deep learning

Official Website: autokeras.com AutoKeras: An AutoML system based on Keras. It is developed by DATA Lab at Texas A&M University. The goal of AutoKeras

Keras 8.7k Jan 08, 2023
Implementation of Shape Generation and Completion Through Point-Voxel Diffusion

Shape Generation and Completion Through Point-Voxel Diffusion Project | Paper Implementation of Shape Generation and Completion Through Point-Voxel Di

Linqi Zhou 103 Dec 29, 2022
Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion (CVPR'2021, Oral)

DSA^2 F: Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion (CVPR'2021, Oral) This repo is the official imp

如今我已剑指天涯 46 Dec 21, 2022
Code for paper ECCV 2020 paper: Who Left the Dogs Out? 3D Animal Reconstruction with Expectation Maximization in the Loop.

Who Left the Dogs Out? Evaluation and demo code for our ECCV 2020 paper: Who Left the Dogs Out? 3D Animal Reconstruction with Expectation Maximization

Benjamin Biggs 29 Dec 28, 2022
Discovering and Achieving Goals via World Models

Discovering and Achieving Goals via World Models [Project Website] [Benchmark Code] [Video (2min)] [Oral Talk (13min)] [Paper] Russell Mendonca*1, Ole

Oleg Rybkin 71 Dec 22, 2022
Code & Models for 3DETR - an End-to-end transformer model for 3D object detection

3DETR: An End-to-End Transformer Model for 3D Object Detection PyTorch implementation and models for 3DETR. 3DETR (3D DEtection TRansformer) is a simp

Facebook Research 487 Dec 31, 2022
Just playing with getting CLIP Guided Diffusion running locally, rather than having to use colab.

CLIP-Guided-Diffusion Just playing with getting CLIP Guided Diffusion running locally, rather than having to use colab. Original colab notebooks by Ka

Nerdy Rodent 336 Dec 09, 2022
clDice - a Novel Topology-Preserving Loss Function for Tubular Structure Segmentation

README clDice - a Novel Topology-Preserving Loss Function for Tubular Structure Segmentation CVPR 2021 Authors: Suprosanna Shit and Johannes C. Paetzo

110 Dec 29, 2022