Self-supervised learning algorithms provide a way to train Deep Neural Networks in an unsupervised way using contrastive losses

Overview

Self-supervised learning

Paper Conference

CI testing

Self-supervised learning algorithms provide a way to train Deep Neural Networks in an unsupervised way using contrastive losses. The idea is to learn a representation which can discriminate between negative examples and be as close as possible to augmentations and transformations of itself. In this approach, we first train a ResNet on the unlabeled dataset which is then fine-tuned on a relatively small labeled one. This approach drastically reduces the amount of labeled data required, a big problem in applying deep learning in the real world. Surprisingly, this approach actually leads to increase in robustness as well as raw performance, when compared to fully supervised counterparts, even with the same architecture.

In case, the user wants to skip the pre-training part, the pre-trained weights can be downloaded from here to use for fine-tuning tasks and directly skip to the second part of the tutorial which is using the 'ssl_finetune_train.py'.

Steps to run the tutorial

1.) Download the two datasets TCIA-Covid19 & BTCV (More detail about them in the Data section)
2.) Modify the paths for data_root, json_path & logdir in ssl_script_train.py
3.) Run the 'ssl_script_train.py'
4.) Modify the paths for data_root, json_path, pre-trained_weights_path from 2.) and logdir_path in 'ssl_finetuning_train.py'
5.) Run the 'ssl_finetuning_script.py'
6.) And that's all folks, use the model to your needs

1.Data

Pre-training Dataset: The TCIA Covid-19 dataset was used for generating the pre-trained weights. The dataset contains a total of 771 3D CT Volumes. The volumes were split into training and validation sets of 600 and 171 3D volumes correspondingly. The data is available for download at this link. If this dataset is being used in your work, please use [1] as reference. A json file is provided which contains the training and validation splits that were used for the training. The json file can be found in the json_files directory of the self-supervised training tutorial.

Fine-tuning Dataset: The dataset from Beyond the Cranial Vault Challenge (BTCV) 2015 hosted at MICCAI, was used as a fully supervised fine-tuning task on the pre-trained weights. The dataset consists of 30 3D Volumes with annotated labels of up to 13 different organs [2]. There are 3 json files provided in the json_files directory for the dataset. They correspond to having different number of training volumes ranging from 6, 12 and 24. All 3 json files have the same validation split.

References:

1.) Harmon, Stephanie A., et al. "Artificial intelligence for the detection of COVID-19 pneumonia on chest CT using multinational datasets." Nature communications 11.1 (2020): 1-7.

2.) Tang, Yucheng, et al. "High-resolution 3D abdominal segmentation with random patch network fusion." Medical Image Analysis 69 (2021): 101894.

2. Network Architectures

For pre-training a modified version of ViT [1] has been used, it can be referred here from MONAI. The original ViT was modified by attachment of two 3D Convolutional Transpose Layers to achieve a similar reconstruction size as that of the input image. The ViT is the backbone for the UNETR [2] network architecture which was used for the fine-tuning fully supervised tasks.

The pre-trained backbone of ViT weights were loaded to UNETR and the decoder head still relies on random initialization for adaptability of the new downstream task. This flexibility also allows the user to adapt the ViT backbone to their own custom created network architectures as well.

References:

1.) Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale." arXiv preprint arXiv:2010.11929 (2020).

2.) Hatamizadeh, Ali, et al. "Unetr: Transformers for 3d medical image segmentation." arXiv preprint arXiv:2103.10504 (2021).

3. Self-supervised Tasks

The pre-training pipeline has two aspects to it (Refer figure shown below). First, it uses augmentation (top row) to mutate the data and second, it utilizes regularized contrastive loss [3] to learn feature representations of the unlabeled data. The multiple augmentations are applied on a randomly selected 3D foreground patch from a 3D volume. Two augmented views of the same 3D patch are generated for the contrastive loss as it functions by drawing the two augmented views closer to each other if the views are generated from the same patch, if not then it tries to maximize the disagreement. The CL offers this functionality on a mini-batch.

image

The augmentations mutate the 3D patch in various ways, the primary task of the network is to reconstruct the original image. The different augmentations used are classical techniques such as in-painting [1], out-painting [1] and noise augmentation to the image by local pixel shuffling [2]. The secondary task of the network is to simultaneously reconstruct the two augmented views as similar to each other as possible via regularized contrastive loss [3] as its objective is to maximize the agreement. The term regularized has been used here because contrastive loss is adjusted by the reconstruction loss as a dynamic weight itself.

The below example image depicts the usage of the augmentation pipeline where two augmented views are drawn of the same 3D patch:

image

Multiple axial slices of a 96x96x96 patch are shown before the augmentation (Ref Original Patch in the above figure). Augmented View 1 & 2 are different augmentations generated via the transforms on the same cubic patch. The objective of the SSL network is to reconstruct the original top row image from the first view. The contrastive loss is driven by maximizing agreement of the reconstruction based on input of the two augmented views. matshow3d from monai.visualize was used for creating this figure, a tutorial for using can be found here

References:

1.) Pathak, Deepak, et al. "Context encoders: Feature learning by inpainting." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.

2.) Chen, Liang, et al. "Self-supervised learning for medical image analysis using image context restoration." Medical image analysis 58 (2019): 101539.

3.) Chen, Ting, et al. "A simple framework for contrastive learning of visual representations." International conference on machine learning. PMLR, 2020.

4. Experiment Hyper-parameters

Training Hyper-Parameters for SSL:
Epochs: 300
Validation Frequency: 2
Learning Rate: 1e-4
Batch size: 4 3D Volumes (Total of 8 as 2 samples were drawn per 3D Volume)
Loss Function: L1 Contrastive Loss Temperature: 0.005

Training Hyper-parameters for Fine-tuning BTCV task (All settings have been kept consistent with prior UNETR 3D Segmentation tutorial):
Number of Steps: 30000
Validation Frequency: 100 steps
Batch Size: 1 3D Volume (4 samples are drawn per 3D volume)
Learning Rate: 1e-4
Loss Function: DiceCELoss

4. Training & Validation Curves for pre-training SSL

image

L1 error reported for training and validation when performing the SSL training. Please note contrastive loss is not L1.

5. Results of the Fine-tuning vs Random Initialization on BTCV

Training Volumes Validation Volumes Random Init Dice score Pre-trained Dice Score Relative Performance Improvement
6 6 63.07 70.09 ~11.13%
12 6 76.06 79.55 ~4.58%
24 6 78.91 82.30 ~4.29%

Citation

@article{Arijit Das,
  title={Self-supervised learning for medical data},
  author={Arijit Das},
  journal={https://github.com/das-projects/selfsupervised-learning},
  year={2020}
}
Owner
Arijit Das
Data Scientist who is passionate about developing and implementing robust and explainable Machine Learning algorithms.
Arijit Das
A fast MoE impl for PyTorch

An easy-to-use and efficient system to support the Mixture of Experts (MoE) model for PyTorch.

Rick Ho 873 Jan 09, 2023
Simple tutorials using Google's TensorFlow Framework

TensorFlow-Tutorials Introduction to deep learning based on Google's TensorFlow framework. These tutorials are direct ports of Newmu's Theano Tutorial

Nathan Lintz 6k Jan 06, 2023
code for "Feature Importance-aware Transferable Adversarial Attacks"

Feature Importance-aware Attack(FIA) This repository contains the code for the paper: Feature Importance-aware Transferable Adversarial Attacks (ICCV

Hengchang Guo 44 Nov 24, 2022
This repository contains code and data for "On the Multimodal Person Verification Using Audio-Visual-Thermal Data"

trimodal_person_verification This repository contains the code, and preprocessed dataset featured in "A Study of Multimodal Person Verification Using

ISSAI 7 Aug 31, 2022
Vehicle detection using machine learning and computer vision techniques for Udacity's Self-Driving Car Engineer Nanodegree.

Vehicle Detection Video demo Overview Vehicle detection using these machine learning and computer vision techniques. Linear SVM HOG(Histogram of Orien

hata 1.1k Dec 18, 2022
Contrastively Disentangled Sequential Variational Audoencoder

Contrastively Disentangled Sequential Variational Audoencoder (C-DSVAE) Overview This is the implementation for our C-DSVAE, a novel self-supervised d

Junwen Bai 35 Dec 24, 2022
Teaches a student network from the knowledge obtained via training of a larger teacher network

Distilling-the-knowledge-in-neural-network Teaches a student network from the knowledge obtained via training of a larger teacher network This is an i

Abhishek Sinha 146 Dec 11, 2022
Mask-invariant Face Recognition through Template-level Knowledge Distillation

Mask-invariant Face Recognition through Template-level Knowledge Distillation This is the official repository of "Mask-invariant Face Recognition thro

Fadi Boutros 35 Dec 06, 2022
A toolkit for making real world machine learning and data analysis applications in C++

dlib C++ library Dlib is a modern C++ toolkit containing machine learning algorithms and tools for creating complex software in C++ to solve real worl

Davis E. King 11.6k Jan 01, 2023
4th place solution to datafactory challenge by Intermarché.

Solution to Datafactory challenge by Intermarché. 4th place solution to datafactory challenge by Intermarché. The objective of the challenge is to pre

Raphael Sourty 11 Mar 19, 2022
Semantic Segmentation for Real Point Cloud Scenes via Bilateral Augmentation and Adaptive Fusion (CVPR 2021)

Semantic Segmentation for Real Point Cloud Scenes via Bilateral Augmentation and Adaptive Fusion (CVPR 2021) This repository is for BAAF-Net introduce

90 Dec 29, 2022
Implementing a simplified copy of Shazam application from scratch using MinHashing and LSH.

Building Shazam from scratch In this repository we tried to implement a simplified copy of the Shazam application able to tell you the name of a song

Arturo Ghinassi 0 Nov 17, 2022
Domain Adaptation with Invariant RepresentationLearning: What Transformations to Learn?

Domain Adaptation with Invariant RepresentationLearning: What Transformations to Learn? Repository Structure: DSAN |└───amazon |    └── dataset (Amazo

DMIRLAB 17 Jan 04, 2023
Deep Learning agent of Starcraft2, similar to AlphaStar of DeepMind except size of network.

Introduction This repository is for Deep Learning agent of Starcraft2. It is very similar to AlphaStar of DeepMind except size of network. I only test

Dohyeong Kim 136 Jan 04, 2023
这是一个yolox-keras的源码,可以用于训练自己的模型。

YOLOX:You Only Look Once目标检测模型在Keras当中的实现 目录 性能情况 Performance 实现的内容 Achievement 所需环境 Environment 小技巧的设置 TricksSet 文件下载 Download 训练步骤 How2train 预测步骤 Ho

Bubbliiiing 64 Nov 10, 2022
The author's officially unofficial PyTorch BigGAN implementation.

BigGAN-PyTorch The author's officially unofficial PyTorch BigGAN implementation. This repo contains code for 4-8 GPU training of BigGANs from Large Sc

Andy Brock 2.6k Jan 02, 2023
Context Axial Reverse Attention Network for Small Medical Objects Segmentation

CaraNet: Context Axial Reverse Attention Network for Small Medical Objects Segmentation This repository contains the implementation of a novel attenti

401 Dec 23, 2022
In-place Parallel Super Scalar Samplesort (IPS⁴o)

In-place Parallel Super Scalar Samplesort (IPS⁴o) This is the implementation of the algorithm IPS⁴o presented in the paper Engineering In-place (Share

82 Dec 22, 2022
Weakly Supervised Learning of Instance Segmentation with Inter-pixel Relations, CVPR 2019 (Oral)

Weakly Supervised Learning of Instance Segmentation with Inter-pixel Relations The code of: Weakly Supervised Learning of Instance Segmentation with I

Jiwoon Ahn 472 Dec 29, 2022
Trainable Bilateral Filter Layer (PyTorch)

Trainable Bilateral Filter Layer (PyTorch) This repository contains our GPU-accelerated trainable bilateral filter layer (three spatial and one range

FabianWagner 26 Dec 25, 2022