PyTorch Code for "Medical Transformer: Gated Axial-Attention for Medical Image Segmentation"


Medical-Transformer


About this repo:

This repo hosts the code for the following networks:

  1. Gated Axial Attention U-Net
  2. MedT

Introduction

The majority of existing Transformer-based network architectures proposed for vision applications require large-scale datasets to train properly. However, medical imaging datasets contain relatively few samples compared to those used in general vision applications, which makes it difficult to train transformers efficiently for medical applications. To this end, we propose a Gated Axial-Attention model that extends existing architectures by introducing an additional control mechanism in the self-attention module. Furthermore, to train the model effectively on medical images, we propose a Local-Global training strategy (LoGo) that further improves performance. Specifically, we operate on the whole image and on patches to learn global and local features, respectively. The proposed Medical Transformer (MedT) applies the LoGo training strategy to the Gated Axial Attention U-Net.
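To make the "additional control mechanism" more concrete, here is a minimal NumPy sketch of single-head gated axial attention along the width axis of one feature map. It assumes a formulation in which learnable scalar gates (here called G_Q, G_K, G_V1, G_V2) scale the positional terms for queries and keys and the value and value-positional terms. The function name, the (H, W, C) layout, and the dense (W, W, d) positional tensors indexed by (query, key) position are hypothetical choices for illustration. This is not the repository's implementation, which works on batched PyTorch tensors with multiple heads and also attends along the height axis.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def gated_axial_attention_width(x, Wq, Wk, Wv, rq, rk, rv, gates):
    """Hypothetical single-head gated axial attention along the width axis.

    x:       feature map of shape (H, W, C)
    Wq, Wk:  query/key projections of shape (C, d)
    Wv:      value projection of shape (C, dv)
    rq, rk:  positional encodings of shape (W, W, d), indexed [query, key]
    rv:      positional encodings of shape (W, W, dv), indexed [query, key]
    gates:   dict of scalar gates "G_Q", "G_K", "G_V1", "G_V2" (assumed names)
    """
    q, k, v = x @ Wq, x @ Wk, x @ Wv  # (H, W, d), (H, W, d), (H, W, dv)

    # Content term: each query attends only to keys in the same row.
    logits = np.einsum("hjd,hwd->hjw", q, k)
    # Gated positional terms for queries and keys.
    logits += gates["G_Q"] * np.einsum("hjd,jwd->hjw", q, rq)
    logits += gates["G_K"] * np.einsum("hwd,jwd->hjw", k, rk)

    attn = softmax(logits, axis=-1)  # normalise over key positions w

    # Gated mix of the value term and the value positional term.
    out = gates["G_V1"] * np.einsum("hjw,hwd->hjd", attn, v)
    out += gates["G_V2"] * np.einsum("hjw,jwd->hjd", attn, rv)
    return out  # (H, W, dv)
```

Setting G_Q = G_K = G_V2 = 0 and G_V1 = 1 reduces the sketch to plain row-wise self-attention. One way to read the gates, then, is as a learned knob that lets the network down-weight positional encodings when too little data is available to learn them reliably.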

Using the code:

  • Clone this repository:
git clone https://github.com/jeya-maria-jose/Medical-Transformer
cd Medical-Transformer

The code is stable with Python 3.6.10 and PyTorch 1.4.0.

To install all the dependencies using conda:

conda env create -f environment.yml
conda activate medt

To install all the dependencies using pip:

pip install -r requirements.txt

Links for downloading the public datasets:

  1. GLAS Dataset - Link (Original) | Link (Resized)
  2. MoNuSeg Dataset - Link (Original)
  3. The Brain Anatomy US dataset from the paper will be made public soon.

Using the Code for your dataset

Dataset Preparation

Prepare the dataset in the following format for easy use of the code. The train, validation, and test folders should each contain two subfolders: img and label. Make sure that the images and their corresponding segmentation masks are placed in these folders and share the same file name, so that each image can be matched to its mask. If you prefer not to prepare the dataset in this format, modify the data loaders to suit your needs.

Train Folder-----
      img----
          0001.png
          0002.png
          .......
      label---
          0001.png
          0002.png
          .......
Validation Folder-----
      img----
          0001.png
          0002.png
          .......
      label---
          0001.png
          0002.png
          .......
Test Folder-----
      img----
          0001.png
          0002.png
          .......
      label---
          0001.png
          0002.png
          .......
  • Ground-truth pixel values should correspond to the class labels. For example, for binary segmentation, every pixel in the GT mask should be either 0 or 255.
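If you are preparing your own binary masks, a small NumPy check like the one below can confirm that they contain only 0 and 255, and it shows one possible way to map them to 0/1 class indices. This is an assumption about what a loader might do, not a copy of the repository's data loader, and both helper names are hypothetical. Note that resizing masks with bilinear interpolation introduces intermediate values; nearest-neighbour resizing keeps them strictly binary.

```python
import numpy as np


def binary_mask_to_classes(mask, threshold=127):
    """Map a 0/255 ground-truth mask to class indices 0/1.

    Thresholding (rather than dividing by 255) also tolerates stray
    intermediate values, e.g. from resizing with interpolation.
    """
    mask = np.asarray(mask)
    return (mask > threshold).astype(np.int64)


def is_strictly_binary(mask):
    """Return True if the mask contains only the values 0 and 255."""
    return set(np.unique(mask).tolist()) <= {0, 255}
```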

Training Command:

python train.py --train_dataset "enter train directory" --val_dataset "enter validation directory" --direc 'path for results to be saved' --batch_size 4 --epoch 400 --save_freq 10 --modelname "gatedaxialunet" --learning_rate 0.001 --imgsize 128 --gray "no"
To train the other models, set --modelname to MedT or logo.
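MedT combines gated axial attention with the Local-Global (LoGo) strategy described in the introduction: one branch sees the whole image and another sees patches. The NumPy sketch below illustrates only the patch bookkeeping, with placeholder callables standing in for the two networks. The function names, the 4 × 4 default grid, and combining the branches by simple addition are assumptions for illustration, not the repository's code, which operates on feature maps inside a PyTorch network.

```python
import numpy as np


def split_into_patches(img, n):
    """Split an (H, W, ...) array into an n x n grid of equal, non-overlapping patches."""
    H, W = img.shape[:2]
    assert H % n == 0 and W % n == 0, "image size must be divisible by n"
    ph, pw = H // n, W // n
    return [[img[i * ph:(i + 1) * ph, j * pw:(j + 1) * pw] for j in range(n)]
            for i in range(n)]


def merge_patches(grid):
    """Inverse of split_into_patches: stitch the grid back into one array."""
    return np.concatenate([np.concatenate(row, axis=1) for row in grid], axis=0)


def logo_forward(img, global_branch, local_branch, n=4):
    """Hypothetical LoGo-style forward pass.

    The global branch sees the whole image; the local branch sees each patch
    separately. The patch outputs are placed back on the image grid and the
    two results are summed (an assumed combination rule).
    """
    global_out = global_branch(img)
    local_grid = [[local_branch(p) for p in row] for row in split_into_patches(img, n)]
    local_out = merge_patches(local_grid)
    return global_out + local_out
```

Placing each patch output back at its original location is what allows the local and global results to be combined pixel by pixel.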

Testing Command:

python test.py --loaddirec "./saved_model_path/model_name.pth" --val_dataset "test dataset directory" --direc 'path for results to be saved' --batch_size 1 --modelname "gatedaxialunet" --imgsize 128 --gray "no"

The results, including the predicted segmentation maps, will be placed in the results folder along with the model weights. Run the performance metrics code in MATLAB to calculate the F1 score and mIoU.
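For readers without MATLAB, the following NumPy sketch computes the same two quantities per image for binary masks: F1 = 2TP / (2TP + FP + FN), which equals the Dice coefficient, and IoU = TP / (TP + FP + FN). It is not a port of the repository's MATLAB script. The threshold, the handling of empty masks, and averaging over images are assumptions, so the numbers may differ slightly from the reported ones.

```python
import numpy as np


def f1_and_iou(pred, gt, threshold=127):
    """F1 (Dice) score and IoU for one binary prediction/ground-truth pair.

    Both inputs may be 0/255 images; values above `threshold` count as foreground.
    If both masks are empty, the prediction is treated as perfect (1.0, 1.0).
    """
    p = np.asarray(pred) > threshold
    g = np.asarray(gt) > threshold
    tp = np.logical_and(p, g).sum()
    fp = np.logical_and(p, ~g).sum()
    fn = np.logical_and(~p, g).sum()
    if tp + fp + fn == 0:
        return 1.0, 1.0
    f1 = 2 * tp / (2 * tp + fp + fn)
    iou = tp / (tp + fp + fn)
    return float(f1), float(iou)


def mean_scores(pairs):
    """Average per-image F1 and IoU over an iterable of (pred, gt) pairs."""
    scores = [f1_and_iou(p, g) for p, g in pairs]
    f1s, ious = zip(*scores)
    return sum(f1s) / len(f1s), sum(ious) / len(ious)
```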

Notes:

Note that these experiments were conducted on an NVIDIA Quadro 8000 GPU with 48 GB of memory.

Acknowledgement:

The data loader code is inspired by pytorch-UNet. The axial attention code is adapted from axial-deeplab.

Citation:

To add

Open an issue or email me directly if you have any questions or suggestions.

Owner
Jeya Maria Jose
PhD Student at Johns Hopkins University.