Public repo for the ICCV2021-CVAMD paper "Is it Time to Replace CNNs with Transformers for Medical Images?"

Overview

Is it Time to Replace CNNs with Transformers for Medical Images?

Accepted at ICCV-2021: Workshop on Computer Vision for Automated Medical Diagnosis (CVAMD)

Convolutional Neural Networks (CNNs) have reigned for a decade as the de facto approach to automated medical image diagnosis. Recently, vision transformers (ViTs) have appeared as a competitive alternative to CNNs, yielding similar levels of performance while possessing several interesting properties that could prove beneficial for medical imaging tasks. In this work, we explore whether it is time to move to transformer-based models or if we should keep working with CNNs - can we trivially switch to transformers? If so, what are the advantages and drawbacks of switching to ViTs for medical image diagnosis? We consider these questions in a series of experiments on three mainstream medical image datasets. Our findings show that, while CNNs perform better when trained from scratch, off-the-shelf vision transformers using default hyperparameters are on par with CNNs when pretrained on ImageNet, and outperform their CNN counterparts when pretrained using self-supervision.

Environment setup

To build the Docker image using the provided Dockerfile, run the following command:
docker build -f Dockerfile -t med_trans \
--build-arg UID=$(id -u) \
--build-arg GID=$(id -g) \
--build-arg USER=$(whoami) \
--build-arg GROUP=$(id -g -n) .
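
Once the image is built, you can start a container with it. A minimal sketch (assuming the NVIDIA container toolkit is available for --gpus, and using placeholder paths for your local copy of this repository and the datasets):

docker run --rm -it --gpus all \
-v /path/to/this/repo:/workspace \
-v /path/to/datasets:/data \
med_trans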

Usage:

  • Training: python classification.py
  • Training with DINO: python classification.py --dino
  • Testing (using json file): python classification.py --test
  • Testing (using saved checkpoint): python classification.py --checkpoint CheckpointName --test
  • Learning rate search: python classification.py --lr_finder
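
For example, training a model and then evaluating it from the saved checkpoint (the checkpoint name below is a placeholder) could look like:

python classification.py
python classification.py --checkpoint CheckpointName --test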

Configuration (json file)
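
All of the options below live in the JSON configuration file; a minimal example of the overall structure is sketched after this list.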

  • dataset_params
    • dataset: Name of the dataset (ISIC2019, APTOS2019, DDSM)
    • data_location: Path to the directory where the datasets are stored
    • train_transforms: Defines the augmentations for the training set
    • val_transforms: Defines the augmentations for the validation set
    • test_transforms: Defines the augmentations for the test set
  • dataloader_params: Defines the dataloader parameters (batch size, num_workers etc)
  • model_params
    • backbone_type: Type of the backbone model (e.g. resnet50, deit_small)
    • transformers_params: Additional hyperparameters for the transformers
      • img_size: The size of the input images
      • patch_size: The patch size to use for patching the input
      • pretrained_type: If supervised, it loads ImageNet weights obtained with supervised learning. If dino, it loads ImageNet weights obtained with self-supervised learning using DINO.
    • pretrained: If True, it uses ImageNet pretrained weights
    • freeze_backbone: If True, it freezes the backbone network
    • DINO: Controls the hyperparameters used when training with DINO
  • optimization_params: Defines learning rate, weight decay, learning rate schedule etc.
    • optimizer: The default optimizer's parameters
      • type: The optimizer's type
      • autoscale_rl: If True, it scales the learning rate based on the batch size
      • params: Defines the learning rate and the weight decay value
    • LARS_params: If use=True and the batch size >= batch_act_thresh, it uses LARS as the optimizer
    • scheduler: Defines the learning rate schedule
      • type: A list of schedulers to use
      • params: Sets the hyperparameters of the schedulers
  • training_params: Defines the training parameters
    • model_name: The model's name
    • val_every: Sets the frequency of the validation step (epochs - float)
    • log_every: Sets the frequency of the logging (iterations - int)
    • save_best_model: If True, it will save the best model based on the validation metrics
    • log_embeddings: If True, it creates UMAP projections of the embeddings on each validation step
    • knn_eval: If True, during validation it will also calculate the scores based on k-NN evaluation
    • grad_clipping: If > 0, it clips the gradients
    • use_tensorboard: If True, it will use tensorboard for logging instead of wandb
    • use_mixed_precision: If True, it will use mixed precision
    • save_dir: The dir to save the model's checkpoints etc.
  • system_params: Defines if GPUs are used, which GPUs etc.
  • log_params: Project and run name for the logger (we are using Weights & Biases by default)
  • lr_finder: Defines the learning rate search parameters
    • grid_search_params
      • min_pow, max_pow: The min and max powers of 10 for the search range
      • resolution: How many different learning rates to try
      • n_epochs: Maximum number of epochs of the training session
      • random_lr: If True, it uses random learning rates within the accepted range
      • keep_schedule: If True, it keeps the learning rate schedule
      • report_intermediate_steps: If True, it logs and validates throughout the training sessions
  • transfer_learning_params: Turns on or off transfer learning from pretrained models
    • use_pretrained: If True, it will use a pretrained model as a backbone
    • pretrained_model_name: The pretrained model's name
    • pretrained_path: The directory of the pretrained model
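
A minimal sketch of such a configuration file is shown below, following the structure described above. The optional sections (DINO, LARS_params, lr_finder, transfer_learning_params) are omitted, and the concrete values as well as the exact key names inside params, dataloader_params and log_params are illustrative assumptions rather than the repository's defaults:

{
    "dataset_params": {
        "dataset": "APTOS2019",
        "data_location": "/data",
        "train_transforms": {},
        "val_transforms": {},
        "test_transforms": {}
    },
    "dataloader_params": {"batch_size": 64, "num_workers": 4},
    "model_params": {
        "backbone_type": "deit_small",
        "transformers_params": {"img_size": 224, "patch_size": 16, "pretrained_type": "supervised"},
        "pretrained": true,
        "freeze_backbone": false
    },
    "optimization_params": {
        "optimizer": {
            "type": "Adam",
            "autoscale_rl": false,
            "params": {"lr": 1e-4, "weight_decay": 1e-5}
        },
        "scheduler": {"type": ["cosine"], "params": {}}
    },
    "training_params": {
        "model_name": "deit_small_aptos",
        "val_every": 1.0,
        "log_every": 50,
        "save_best_model": true,
        "knn_eval": false,
        "grad_clipping": 0,
        "use_tensorboard": false,
        "use_mixed_precision": true,
        "save_dir": "checkpoints"
    },
    "system_params": {},
    "log_params": {"project": "med_trans", "run_name": "deit_small_aptos"}
}
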
Owner

Christos Matsoukas
PhD student in Deep Learning @ KTH Royal Institute of Technology