Self-supervised learning algorithms provide a way to train Deep Neural Networks in an unsupervised way using contrastive losses

Overview

Self-supervised learning

Paper Conference

CI testing

Self-supervised learning algorithms provide a way to train Deep Neural Networks in an unsupervised way using contrastive losses. The idea is to learn a representation which can discriminate between negative examples and be as close as possible to augmentations and transformations of itself. In this approach, we first train a ResNet on the unlabeled dataset which is then fine-tuned on a relatively small labeled one. This approach drastically reduces the amount of labeled data required, a big problem in applying deep learning in the real world. Surprisingly, this approach actually leads to increase in robustness as well as raw performance, when compared to fully supervised counterparts, even with the same architecture.

In case, the user wants to skip the pre-training part, the pre-trained weights can be downloaded from here to use for fine-tuning tasks and directly skip to the second part of the tutorial which is using the 'ssl_finetune_train.py'.

Steps to run the tutorial

1.) Download the two datasets TCIA-Covid19 & BTCV (More detail about them in the Data section)
2.) Modify the paths for data_root, json_path & logdir in ssl_script_train.py
3.) Run the 'ssl_script_train.py'
4.) Modify the paths for data_root, json_path, pre-trained_weights_path from 2.) and logdir_path in 'ssl_finetuning_train.py'
5.) Run the 'ssl_finetuning_script.py'
6.) And that's all folks, use the model to your needs

1.Data

Pre-training Dataset: The TCIA Covid-19 dataset was used for generating the pre-trained weights. The dataset contains a total of 771 3D CT Volumes. The volumes were split into training and validation sets of 600 and 171 3D volumes correspondingly. The data is available for download at this link. If this dataset is being used in your work, please use [1] as reference. A json file is provided which contains the training and validation splits that were used for the training. The json file can be found in the json_files directory of the self-supervised training tutorial.

Fine-tuning Dataset: The dataset from Beyond the Cranial Vault Challenge (BTCV) 2015 hosted at MICCAI, was used as a fully supervised fine-tuning task on the pre-trained weights. The dataset consists of 30 3D Volumes with annotated labels of up to 13 different organs [2]. There are 3 json files provided in the json_files directory for the dataset. They correspond to having different number of training volumes ranging from 6, 12 and 24. All 3 json files have the same validation split.

References:

1.) Harmon, Stephanie A., et al. "Artificial intelligence for the detection of COVID-19 pneumonia on chest CT using multinational datasets." Nature communications 11.1 (2020): 1-7.

2.) Tang, Yucheng, et al. "High-resolution 3D abdominal segmentation with random patch network fusion." Medical Image Analysis 69 (2021): 101894.

2. Network Architectures

For pre-training a modified version of ViT [1] has been used, it can be referred here from MONAI. The original ViT was modified by attachment of two 3D Convolutional Transpose Layers to achieve a similar reconstruction size as that of the input image. The ViT is the backbone for the UNETR [2] network architecture which was used for the fine-tuning fully supervised tasks.

The pre-trained backbone of ViT weights were loaded to UNETR and the decoder head still relies on random initialization for adaptability of the new downstream task. This flexibility also allows the user to adapt the ViT backbone to their own custom created network architectures as well.

References:

1.) Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale." arXiv preprint arXiv:2010.11929 (2020).

2.) Hatamizadeh, Ali, et al. "Unetr: Transformers for 3d medical image segmentation." arXiv preprint arXiv:2103.10504 (2021).

3. Self-supervised Tasks

The pre-training pipeline has two aspects to it (Refer figure shown below). First, it uses augmentation (top row) to mutate the data and second, it utilizes regularized contrastive loss [3] to learn feature representations of the unlabeled data. The multiple augmentations are applied on a randomly selected 3D foreground patch from a 3D volume. Two augmented views of the same 3D patch are generated for the contrastive loss as it functions by drawing the two augmented views closer to each other if the views are generated from the same patch, if not then it tries to maximize the disagreement. The CL offers this functionality on a mini-batch.

image

The augmentations mutate the 3D patch in various ways, the primary task of the network is to reconstruct the original image. The different augmentations used are classical techniques such as in-painting [1], out-painting [1] and noise augmentation to the image by local pixel shuffling [2]. The secondary task of the network is to simultaneously reconstruct the two augmented views as similar to each other as possible via regularized contrastive loss [3] as its objective is to maximize the agreement. The term regularized has been used here because contrastive loss is adjusted by the reconstruction loss as a dynamic weight itself.

The below example image depicts the usage of the augmentation pipeline where two augmented views are drawn of the same 3D patch:

image

Multiple axial slices of a 96x96x96 patch are shown before the augmentation (Ref Original Patch in the above figure). Augmented View 1 & 2 are different augmentations generated via the transforms on the same cubic patch. The objective of the SSL network is to reconstruct the original top row image from the first view. The contrastive loss is driven by maximizing agreement of the reconstruction based on input of the two augmented views. matshow3d from monai.visualize was used for creating this figure, a tutorial for using can be found here

References:

1.) Pathak, Deepak, et al. "Context encoders: Feature learning by inpainting." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.

2.) Chen, Liang, et al. "Self-supervised learning for medical image analysis using image context restoration." Medical image analysis 58 (2019): 101539.

3.) Chen, Ting, et al. "A simple framework for contrastive learning of visual representations." International conference on machine learning. PMLR, 2020.

4. Experiment Hyper-parameters

Training Hyper-Parameters for SSL:
Epochs: 300
Validation Frequency: 2
Learning Rate: 1e-4
Batch size: 4 3D Volumes (Total of 8 as 2 samples were drawn per 3D Volume)
Loss Function: L1 Contrastive Loss Temperature: 0.005

Training Hyper-parameters for Fine-tuning BTCV task (All settings have been kept consistent with prior UNETR 3D Segmentation tutorial):
Number of Steps: 30000
Validation Frequency: 100 steps
Batch Size: 1 3D Volume (4 samples are drawn per 3D volume)
Learning Rate: 1e-4
Loss Function: DiceCELoss

4. Training & Validation Curves for pre-training SSL

image

L1 error reported for training and validation when performing the SSL training. Please note contrastive loss is not L1.

5. Results of the Fine-tuning vs Random Initialization on BTCV

Training Volumes Validation Volumes Random Init Dice score Pre-trained Dice Score Relative Performance Improvement
6 6 63.07 70.09 ~11.13%
12 6 76.06 79.55 ~4.58%
24 6 78.91 82.30 ~4.29%

Citation

@article{Arijit Das,
  title={Self-supervised learning for medical data},
  author={Arijit Das},
  journal={https://github.com/das-projects/selfsupervised-learning},
  year={2020}
}
Owner
Arijit Das
Data Scientist who is passionate about developing and implementing robust and explainable Machine Learning algorithms.
Arijit Das
Classification models 1D Zoo - Keras and TF.Keras

Classification models 1D Zoo - Keras and TF.Keras This repository contains 1D variants of popular CNN models for classification like ResNets, DenseNet

Roman Solovyev 12 Jan 06, 2023
Learning to Segment Instances in Videos with Spatial Propagation Network

Learning to Segment Instances in Videos with Spatial Propagation Network This paper is available at the 2017 DAVIS Challenge website. Check our result

Jingchun Cheng 145 Sep 28, 2022
A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations

jaxdf - JAX-based Discretization Framework Overview | Example | Installation | Documentation ⚠️ This library is still in development. Breaking changes

UCL Biomedical Ultrasound Group 65 Dec 23, 2022
Constructing Neural Network-Based Models for Simulating Dynamical Systems

Constructing Neural Network-Based Models for Simulating Dynamical Systems Note this repo is work in progress prior to reviewing This is a companion re

Christian Møldrup Legaard 21 Nov 25, 2022
A curated list of awesome Machine Learning frameworks, libraries and software.

Awesome Machine Learning A curated list of awesome machine learning frameworks, libraries and software (by language). Inspired by awesome-php. If you

Joseph Misiti 57.1k Jan 03, 2023
EFENet: Reference-based Video Super-Resolution with Enhanced Flow Estimation

EFENet EFENet: Reference-based Video Super-Resolution with Enhanced Flow Estimation Code is a bit messy now. I woud clean up soon. For training the EF

Yaping Zhao 19 Nov 05, 2022
This is the official code release for the paper Shape and Material Capture at Home

This is the official code release for the paper Shape and Material Capture at Home. The code enables you to reconstruct a 3D mesh and Cook-Torrance BRDF from one or more images captured with a flashl

89 Dec 10, 2022
Compositional and Parameter-Efficient Representations for Large Knowledge Graphs

NodePiece - Compositional and Parameter-Efficient Representations for Large Knowledge Graphs NodePiece is a "tokenizer" for reducing entity vocabulary

Michael Galkin 107 Jan 04, 2023
Stock-Prediction - prediction of stock market movements using sentiment analysis and deep learning.

Stock-Prediction- In this project, we aim to enhance the prediction of stock market movements using sentiment analysis and deep learning. We divide th

5 Jan 25, 2022
Pytorch Implementation for (STANet+ and STANet)

Pytorch Implementation for (STANet+ and STANet) V2-Weakly Supervised Visual-Auditory Saliency Detection with Multigranularity Perception (arxiv), pdf:

GuotaoWang 14 Nov 29, 2022
Tensorflow Implementation of the paper "Spectral Normalization for Generative Adversarial Networks" (ICML 2017 workshop)

tf-SNDCGAN Tensorflow implementation of the paper "Spectral Normalization for Generative Adversarial Networks" (https://www.researchgate.net/publicati

Nhat M. Nguyen 248 Nov 25, 2022
Artificial intelligence technology inferring issues and logically supporting facts from raw text

개요 비정형 텍스트를 학습하여 쟁점별 사실과 논리적 근거 추론이 가능한 인공지능 원천기술 Artificial intelligence techno

6 Dec 29, 2021
Pytorch implementation for Semantic Segmentation/Scene Parsing on MIT ADE20K dataset

Semantic Segmentation on MIT ADE20K dataset in PyTorch This is a PyTorch implementation of semantic segmentation models on MIT ADE20K scene parsing da

MIT CSAIL Computer Vision 4.5k Jan 08, 2023
Supercharging Imbalanced Data Learning WithCausal Representation Transfer

ECRT: Energy-based Causal Representation Transfer Code for Supercharging Imbalanced Data Learning With Energy-basedContrastive Representation Transfer

Zidi Xiu 11 May 02, 2022
This Jupyter notebook shows one way to implement a simple first-order low-pass filter on sampled data in discrete time.

How to Implement a First-Order Low-Pass Filter in Discrete Time We often teach or learn about filters in continuous time, but then need to implement t

Joshua Marshall 4 Aug 24, 2022
Code for the ICCV2021 paper "Personalized Image Semantic Segmentation"

PSS: Personalized Image Semantic Segmentation Paper PSS: Personalized Image Semantic Segmentation Yu Zhang, Chang-Bin Zhang, Peng-Tao Jiang, Ming-Ming

张宇 15 Jul 09, 2022
Repository of 3D Object Detection with Pointformer (CVPR2021)

3D Object Detection with Pointformer This repository contains the code for the paper 3D Object Detection with Pointformer (CVPR 2021) [arXiv]. This wo

Zhuofan Xia 117 Jan 06, 2023
Potato Disease Classification - Training, Rest APIs, and Frontend to test.

Potato Disease Classification Setup for Python: Install Python (Setup instructions) Install Python packages pip3 install -r training/requirements.txt

codebasics 95 Dec 21, 2022
Deep learning image registration library for PyTorch

TorchIR: Pytorch Image Registration TorchIR is a image registration library for deep learning image registration (DLIR). I have integrated several ide

Bob de Vos 40 Dec 16, 2022
Temporal Dynamic Convolutional Neural Network for Text-Independent Speaker Verification and Phonemetic Analysis

TDY-CNN for Text-Independent Speaker Verification Official implementation of Temporal Dynamic Convolutional Neural Network for Text-Independent Speake

Seong-Hu Kim 16 Oct 17, 2022