Self-supervised learning algorithms provide a way to train Deep Neural Networks in an unsupervised way using contrastive losses

Overview

Self-supervised learning

Paper Conference

CI testing

Self-supervised learning algorithms provide a way to train Deep Neural Networks in an unsupervised way using contrastive losses. The idea is to learn a representation which can discriminate between negative examples and be as close as possible to augmentations and transformations of itself. In this approach, we first train a ResNet on the unlabeled dataset which is then fine-tuned on a relatively small labeled one. This approach drastically reduces the amount of labeled data required, a big problem in applying deep learning in the real world. Surprisingly, this approach actually leads to increase in robustness as well as raw performance, when compared to fully supervised counterparts, even with the same architecture.

In case, the user wants to skip the pre-training part, the pre-trained weights can be downloaded from here to use for fine-tuning tasks and directly skip to the second part of the tutorial which is using the 'ssl_finetune_train.py'.

Steps to run the tutorial

1.) Download the two datasets TCIA-Covid19 & BTCV (More detail about them in the Data section)
2.) Modify the paths for data_root, json_path & logdir in ssl_script_train.py
3.) Run the 'ssl_script_train.py'
4.) Modify the paths for data_root, json_path, pre-trained_weights_path from 2.) and logdir_path in 'ssl_finetuning_train.py'
5.) Run the 'ssl_finetuning_script.py'
6.) And that's all folks, use the model to your needs

1.Data

Pre-training Dataset: The TCIA Covid-19 dataset was used for generating the pre-trained weights. The dataset contains a total of 771 3D CT Volumes. The volumes were split into training and validation sets of 600 and 171 3D volumes correspondingly. The data is available for download at this link. If this dataset is being used in your work, please use [1] as reference. A json file is provided which contains the training and validation splits that were used for the training. The json file can be found in the json_files directory of the self-supervised training tutorial.

Fine-tuning Dataset: The dataset from Beyond the Cranial Vault Challenge (BTCV) 2015 hosted at MICCAI, was used as a fully supervised fine-tuning task on the pre-trained weights. The dataset consists of 30 3D Volumes with annotated labels of up to 13 different organs [2]. There are 3 json files provided in the json_files directory for the dataset. They correspond to having different number of training volumes ranging from 6, 12 and 24. All 3 json files have the same validation split.

References:

1.) Harmon, Stephanie A., et al. "Artificial intelligence for the detection of COVID-19 pneumonia on chest CT using multinational datasets." Nature communications 11.1 (2020): 1-7.

2.) Tang, Yucheng, et al. "High-resolution 3D abdominal segmentation with random patch network fusion." Medical Image Analysis 69 (2021): 101894.

2. Network Architectures

For pre-training a modified version of ViT [1] has been used, it can be referred here from MONAI. The original ViT was modified by attachment of two 3D Convolutional Transpose Layers to achieve a similar reconstruction size as that of the input image. The ViT is the backbone for the UNETR [2] network architecture which was used for the fine-tuning fully supervised tasks.

The pre-trained backbone of ViT weights were loaded to UNETR and the decoder head still relies on random initialization for adaptability of the new downstream task. This flexibility also allows the user to adapt the ViT backbone to their own custom created network architectures as well.

References:

1.) Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale." arXiv preprint arXiv:2010.11929 (2020).

2.) Hatamizadeh, Ali, et al. "Unetr: Transformers for 3d medical image segmentation." arXiv preprint arXiv:2103.10504 (2021).

3. Self-supervised Tasks

The pre-training pipeline has two aspects to it (Refer figure shown below). First, it uses augmentation (top row) to mutate the data and second, it utilizes regularized contrastive loss [3] to learn feature representations of the unlabeled data. The multiple augmentations are applied on a randomly selected 3D foreground patch from a 3D volume. Two augmented views of the same 3D patch are generated for the contrastive loss as it functions by drawing the two augmented views closer to each other if the views are generated from the same patch, if not then it tries to maximize the disagreement. The CL offers this functionality on a mini-batch.

image

The augmentations mutate the 3D patch in various ways, the primary task of the network is to reconstruct the original image. The different augmentations used are classical techniques such as in-painting [1], out-painting [1] and noise augmentation to the image by local pixel shuffling [2]. The secondary task of the network is to simultaneously reconstruct the two augmented views as similar to each other as possible via regularized contrastive loss [3] as its objective is to maximize the agreement. The term regularized has been used here because contrastive loss is adjusted by the reconstruction loss as a dynamic weight itself.

The below example image depicts the usage of the augmentation pipeline where two augmented views are drawn of the same 3D patch:

image

Multiple axial slices of a 96x96x96 patch are shown before the augmentation (Ref Original Patch in the above figure). Augmented View 1 & 2 are different augmentations generated via the transforms on the same cubic patch. The objective of the SSL network is to reconstruct the original top row image from the first view. The contrastive loss is driven by maximizing agreement of the reconstruction based on input of the two augmented views. matshow3d from monai.visualize was used for creating this figure, a tutorial for using can be found here

References:

1.) Pathak, Deepak, et al. "Context encoders: Feature learning by inpainting." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.

2.) Chen, Liang, et al. "Self-supervised learning for medical image analysis using image context restoration." Medical image analysis 58 (2019): 101539.

3.) Chen, Ting, et al. "A simple framework for contrastive learning of visual representations." International conference on machine learning. PMLR, 2020.

4. Experiment Hyper-parameters

Training Hyper-Parameters for SSL:
Epochs: 300
Validation Frequency: 2
Learning Rate: 1e-4
Batch size: 4 3D Volumes (Total of 8 as 2 samples were drawn per 3D Volume)
Loss Function: L1 Contrastive Loss Temperature: 0.005

Training Hyper-parameters for Fine-tuning BTCV task (All settings have been kept consistent with prior UNETR 3D Segmentation tutorial):
Number of Steps: 30000
Validation Frequency: 100 steps
Batch Size: 1 3D Volume (4 samples are drawn per 3D volume)
Learning Rate: 1e-4
Loss Function: DiceCELoss

4. Training & Validation Curves for pre-training SSL

image

L1 error reported for training and validation when performing the SSL training. Please note contrastive loss is not L1.

5. Results of the Fine-tuning vs Random Initialization on BTCV

Training Volumes Validation Volumes Random Init Dice score Pre-trained Dice Score Relative Performance Improvement
6 6 63.07 70.09 ~11.13%
12 6 76.06 79.55 ~4.58%
24 6 78.91 82.30 ~4.29%

Citation

@article{Arijit Das,
  title={Self-supervised learning for medical data},
  author={Arijit Das},
  journal={https://github.com/das-projects/selfsupervised-learning},
  year={2020}
}
Owner
Arijit Das
Data Scientist who is passionate about developing and implementing robust and explainable Machine Learning algorithms.
Arijit Das
Alleviating Over-segmentation Errors by Detecting Action Boundaries

Alleviating Over-segmentation Errors by Detecting Action Boundaries Forked from ASRF offical code. This repo is the a implementation of replacing orig

13 Dec 12, 2022
Pytorch implementation of the Variational Recurrent Neural Network (VRNN).

VariationalRecurrentNeuralNetwork Pytorch implementation of the Variational RNN (VRNN), from A Recurrent Latent Variable Model for Sequential Data. Th

emmanuel 251 Dec 17, 2022
SAGE: Sensitivity-guided Adaptive Learning Rate for Transformers

SAGE: Sensitivity-guided Adaptive Learning Rate for Transformers This repo contains our codes for the paper "No Parameters Left Behind: Sensitivity Gu

Chen Liang 23 Nov 07, 2022
Learning Temporal Consistency for Low Light Video Enhancement from Single Images (CVPR2021)

StableLLVE This is a Pytorch implementation of "Learning Temporal Consistency for Low Light Video Enhancement from Single Images" in CVPR 2021, by Fan

99 Dec 19, 2022
IRON Kaggle project done while doing IRONHACK Bootcamp where we had to analyze and use a Machine Learning Project to predict future sales

IRON Kaggle project done while doing IRONHACK Bootcamp where we had to analyze and use a Machine Learning Project to predict future sales. In this case, we ended up using XGBoost because it was the o

1 Jan 04, 2022
【ACMMM 2021】DSANet: Dynamic Segment Aggregation Network for Video-Level Representation Learning

DSANet: Dynamic Segment Aggregation Network for Video-Level Representation Learning (ACMMM 2021) Overview We release the code of the DSANet (Dynamic S

Wenhao Wu 46 Dec 27, 2022
SalGAN: Visual Saliency Prediction with Generative Adversarial Networks

SalGAN: Visual Saliency Prediction with Adversarial Networks Junting Pan Cristian Canton Ferrer Kevin McGuinness Noel O'Connor Jordi Torres Elisa Sayr

Image Processing Group - BarcelonaTECH - UPC 347 Nov 22, 2022
Tensorflow implementation for "Improved Transformer for High-Resolution GANs" (NeurIPS 2021).

HiT-GAN Official TensorFlow Implementation HiT-GAN presents a Transformer-based generator that is trained based on Generative Adversarial Networks (GA

Google Research 78 Oct 31, 2022
Implementation for the IJCAI2021 work "Beyond the Spectrum: Detecting Deepfakes via Re-synthesis"

Beyond the Spectrum Implementation for the IJCAI2021 work "Beyond the Spectrum: Detecting Deepfakes via Re-synthesis" by Yang He, Ning Yu, Margret Keu

Yang He 27 Jan 07, 2023
Individual Treatment Effect Estimation

CAPE Individual Treatment Effect Estimation Run CAPE python train_causal.py --loop 10 -m cape_cau -d NI --i_t 1 Run a baseline model python train_cau

S. Deng 4 Sep 02, 2022
Code for the bachelors-thesis flaky fault localization

Flaky_Fault_Localization Scripts for the Bachelors-Thesis: "Flaky Fault Localization" by Christian Kasberger. The thesis examines the usefulness of sp

Christian Kasberger 1 Oct 26, 2021
Boosted CVaR Classification (NeurIPS 2021)

Boosted CVaR Classification Runtian Zhai, Chen Dan, Arun Sai Suggala, Zico Kolter, Pradeep Ravikumar NeurIPS 2021 Table of Contents Quick Start Train

Runtian Zhai 4 Feb 15, 2022
The official implementation of the Hybrid Self-Attention NEAT algorithm

PUREPLES - Pure Python Library for ES-HyperNEAT About This is a library of evolutionary algorithms with a focus on neuroevolution, implemented in pure

Adrian Westh 91 Dec 12, 2022
Non-Vacuous Generalisation Bounds for Shallow Neural Networks

This package requires jax, tensorflow, and numpy. Either tensorflow or scikit-learn can be used for loading data. To run in a nix-shell with required

Felix Biggs 0 Feb 04, 2022
Loopy belief propagation for factor graphs on discrete variables, in JAX!

PGMax implements general factor graphs for discrete probabilistic graphical models (PGMs), and hardware-accelerated differentiable loopy belief propagation (LBP) in JAX.

Vicarious 62 Dec 23, 2022
[CVPR 2022 Oral] Versatile Multi-Modal Pre-Training for Human-Centric Perception

Versatile Multi-Modal Pre-Training for Human-Centric Perception Fangzhou Hong1  Liang Pan1  Zhongang Cai1,2,3  Ziwei Liu1* 1S-Lab, Nanyang Technologic

Fangzhou Hong 96 Jan 03, 2023
Heterogeneous Deep Graph Infomax

Heterogeneous-Deep-Graph-Infomax Parameter Setting: HDGI-A: Node-level dimension: 16 Attention head: 4 Semantic-level attention vector: 8 learning rat

52 Oct 31, 2022
(AAAI2020)Grapy-ML: Graph Pyramid Mutual Learning for Cross-dataset Human Parsing

Grapy-ML: Graph Pyramid Mutual Learning for Cross-dataset Human Parsing This repository contains pytorch source code for AAAI2020 oral paper: Grapy-ML

54 Aug 04, 2022
The Fundamental Clustering Problems Suite (FCPS) summaries 54 state-of-the-art clustering algorithms, common cluster challenges and estimations of the number of clusters as well as the testing for cluster tendency.

FCPS Fundamental Clustering Problems Suite The package provides over sixty state-of-the-art clustering algorithms for unsupervised machine learning pu

9 Nov 27, 2022
Satellite labelling tool for manual labelling of storm top features such as overshooting tops, above-anvil plumes, cold U/Vs, rings etc.

Satellite labelling tool About this app A tool for manual labelling of storm top features such as overshooting tops, above-anvil plumes, cold U/Vs, ri

Czech Hydrometeorological Institute - Satellite Department 10 Sep 14, 2022