Self-supervised learning algorithms provide a way to train Deep Neural Networks in an unsupervised way using contrastive losses

Overview

Self-supervised learning

Paper Conference

CI testing

Self-supervised learning algorithms provide a way to train Deep Neural Networks in an unsupervised way using contrastive losses. The idea is to learn a representation which can discriminate between negative examples and be as close as possible to augmentations and transformations of itself. In this approach, we first train a ResNet on the unlabeled dataset which is then fine-tuned on a relatively small labeled one. This approach drastically reduces the amount of labeled data required, a big problem in applying deep learning in the real world. Surprisingly, this approach actually leads to increase in robustness as well as raw performance, when compared to fully supervised counterparts, even with the same architecture.

In case, the user wants to skip the pre-training part, the pre-trained weights can be downloaded from here to use for fine-tuning tasks and directly skip to the second part of the tutorial which is using the 'ssl_finetune_train.py'.

Steps to run the tutorial

1.) Download the two datasets TCIA-Covid19 & BTCV (More detail about them in the Data section)
2.) Modify the paths for data_root, json_path & logdir in ssl_script_train.py
3.) Run the 'ssl_script_train.py'
4.) Modify the paths for data_root, json_path, pre-trained_weights_path from 2.) and logdir_path in 'ssl_finetuning_train.py'
5.) Run the 'ssl_finetuning_script.py'
6.) And that's all folks, use the model to your needs

1.Data

Pre-training Dataset: The TCIA Covid-19 dataset was used for generating the pre-trained weights. The dataset contains a total of 771 3D CT Volumes. The volumes were split into training and validation sets of 600 and 171 3D volumes correspondingly. The data is available for download at this link. If this dataset is being used in your work, please use [1] as reference. A json file is provided which contains the training and validation splits that were used for the training. The json file can be found in the json_files directory of the self-supervised training tutorial.

Fine-tuning Dataset: The dataset from Beyond the Cranial Vault Challenge (BTCV) 2015 hosted at MICCAI, was used as a fully supervised fine-tuning task on the pre-trained weights. The dataset consists of 30 3D Volumes with annotated labels of up to 13 different organs [2]. There are 3 json files provided in the json_files directory for the dataset. They correspond to having different number of training volumes ranging from 6, 12 and 24. All 3 json files have the same validation split.

References:

1.) Harmon, Stephanie A., et al. "Artificial intelligence for the detection of COVID-19 pneumonia on chest CT using multinational datasets." Nature communications 11.1 (2020): 1-7.

2.) Tang, Yucheng, et al. "High-resolution 3D abdominal segmentation with random patch network fusion." Medical Image Analysis 69 (2021): 101894.

2. Network Architectures

For pre-training a modified version of ViT [1] has been used, it can be referred here from MONAI. The original ViT was modified by attachment of two 3D Convolutional Transpose Layers to achieve a similar reconstruction size as that of the input image. The ViT is the backbone for the UNETR [2] network architecture which was used for the fine-tuning fully supervised tasks.

The pre-trained backbone of ViT weights were loaded to UNETR and the decoder head still relies on random initialization for adaptability of the new downstream task. This flexibility also allows the user to adapt the ViT backbone to their own custom created network architectures as well.

References:

1.) Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale." arXiv preprint arXiv:2010.11929 (2020).

2.) Hatamizadeh, Ali, et al. "Unetr: Transformers for 3d medical image segmentation." arXiv preprint arXiv:2103.10504 (2021).

3. Self-supervised Tasks

The pre-training pipeline has two aspects to it (Refer figure shown below). First, it uses augmentation (top row) to mutate the data and second, it utilizes regularized contrastive loss [3] to learn feature representations of the unlabeled data. The multiple augmentations are applied on a randomly selected 3D foreground patch from a 3D volume. Two augmented views of the same 3D patch are generated for the contrastive loss as it functions by drawing the two augmented views closer to each other if the views are generated from the same patch, if not then it tries to maximize the disagreement. The CL offers this functionality on a mini-batch.

image

The augmentations mutate the 3D patch in various ways, the primary task of the network is to reconstruct the original image. The different augmentations used are classical techniques such as in-painting [1], out-painting [1] and noise augmentation to the image by local pixel shuffling [2]. The secondary task of the network is to simultaneously reconstruct the two augmented views as similar to each other as possible via regularized contrastive loss [3] as its objective is to maximize the agreement. The term regularized has been used here because contrastive loss is adjusted by the reconstruction loss as a dynamic weight itself.

The below example image depicts the usage of the augmentation pipeline where two augmented views are drawn of the same 3D patch:

image

Multiple axial slices of a 96x96x96 patch are shown before the augmentation (Ref Original Patch in the above figure). Augmented View 1 & 2 are different augmentations generated via the transforms on the same cubic patch. The objective of the SSL network is to reconstruct the original top row image from the first view. The contrastive loss is driven by maximizing agreement of the reconstruction based on input of the two augmented views. matshow3d from monai.visualize was used for creating this figure, a tutorial for using can be found here

References:

1.) Pathak, Deepak, et al. "Context encoders: Feature learning by inpainting." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.

2.) Chen, Liang, et al. "Self-supervised learning for medical image analysis using image context restoration." Medical image analysis 58 (2019): 101539.

3.) Chen, Ting, et al. "A simple framework for contrastive learning of visual representations." International conference on machine learning. PMLR, 2020.

4. Experiment Hyper-parameters

Training Hyper-Parameters for SSL:
Epochs: 300
Validation Frequency: 2
Learning Rate: 1e-4
Batch size: 4 3D Volumes (Total of 8 as 2 samples were drawn per 3D Volume)
Loss Function: L1 Contrastive Loss Temperature: 0.005

Training Hyper-parameters for Fine-tuning BTCV task (All settings have been kept consistent with prior UNETR 3D Segmentation tutorial):
Number of Steps: 30000
Validation Frequency: 100 steps
Batch Size: 1 3D Volume (4 samples are drawn per 3D volume)
Learning Rate: 1e-4
Loss Function: DiceCELoss

4. Training & Validation Curves for pre-training SSL

image

L1 error reported for training and validation when performing the SSL training. Please note contrastive loss is not L1.

5. Results of the Fine-tuning vs Random Initialization on BTCV

Training Volumes Validation Volumes Random Init Dice score Pre-trained Dice Score Relative Performance Improvement
6 6 63.07 70.09 ~11.13%
12 6 76.06 79.55 ~4.58%
24 6 78.91 82.30 ~4.29%

Citation

@article{Arijit Das,
  title={Self-supervised learning for medical data},
  author={Arijit Das},
  journal={https://github.com/das-projects/selfsupervised-learning},
  year={2020}
}
Owner
Arijit Das
Data Scientist who is passionate about developing and implementing robust and explainable Machine Learning algorithms.
Arijit Das
The code for our CVPR paper PISE: Person Image Synthesis and Editing with Decoupled GAN, Project Page, supp.

PISE The code for our CVPR paper PISE: Person Image Synthesis and Editing with Decoupled GAN, Project Page, supp. Requirement conda create -n pise pyt

jinszhang 110 Nov 21, 2022
TSIT: A Simple and Versatile Framework for Image-to-Image Translation

TSIT: A Simple and Versatile Framework for Image-to-Image Translation This repository provides the official PyTorch implementation for the following p

Liming Jiang 255 Nov 23, 2022
SC-GlowTTS: an Efficient Zero-Shot Multi-Speaker Text-To-Speech Model

SC-GlowTTS: an Efficient Zero-Shot Multi-Speaker Text-To-Speech Model Edresson Casanova, Christopher Shulby, Eren Gölge, Nicolas Michael Müller, Frede

Edresson Casanova 92 Dec 09, 2022
Official PyTorch repo for JoJoGAN: One Shot Face Stylization

JoJoGAN: One Shot Face Stylization This is the PyTorch implementation of JoJoGAN: One Shot Face Stylization. Abstract: While there have been recent ad

1.3k Dec 29, 2022
Unified Pre-training for Self-Supervised Learning and Supervised Learning for ASR

UniSpeech The family of UniSpeech: UniSpeech (ICML 2021): Unified Pre-training for Self-Supervised Learning and Supervised Learning for ASR UniSpeech-

Microsoft 282 Jan 09, 2023
Weakly Supervised Scene Text Detection using Deep Reinforcement Learning

Weakly Supervised Scene Text Detection using Deep Reinforcement Learning This repository contains the setup for all experiments performed in our Paper

Emanuel Metzenthin 3 Dec 16, 2022
A curated list of awesome open source libraries to deploy, monitor, version and scale your machine learning

Awesome production machine learning This repository contains a curated list of awesome open source libraries that will help you deploy, monitor, versi

The Institute for Ethical Machine Learning 12.9k Jan 04, 2023
Byte-based multilingual transformer TTS for low-resource/few-shot language adaptation.

One model to speak them all 🌎 Audio Language Text ▷ Chinese 人人生而自由,在尊严和权利上一律平等。 ▷ English All human beings are born free and equal in dignity and rig

Mutian He 60 Nov 14, 2022
Sleep staging from ECG, assisted with EEG

Sleep_Staging_Knowledge Distillation This codebase implements knowledge distillation approach for ECG based sleep staging assisted by EEG based sleep

2 Dec 12, 2022
A3C LSTM Atari with Pytorch plus A3G design

NEWLY ADDED A3G A NEW GPU/CPU ARCHITECTURE OF A3C FOR SUBSTANTIALLY ACCELERATED TRAINING!! RL A3C Pytorch NEWLY ADDED A3G!! New implementation of A3C

David Griffis 532 Jan 02, 2023
This project is based on our SIGGRAPH 2021 paper, ROSEFusion: Random Optimization for Online DenSE Reconstruction under Fast Camera Motion .

ROSEFusion 🌹 This project is based on our SIGGRAPH 2021 paper, ROSEFusion: Random Optimization for Online DenSE Reconstruction under Fast Camera Moti

219 Dec 27, 2022
Transformers4Rec is a flexible and efficient library for sequential and session-based recommendation, available for both PyTorch and Tensorflow.

Transformers4Rec is a flexible and efficient library for sequential and session-based recommendation, available for both PyTorch and Tensorflow.

730 Jan 09, 2023
Code for the ICML 2021 paper "Bridging Multi-Task Learning and Meta-Learning: Towards Efficient Training and Effective Adaptation", Haoxiang Wang, Han Zhao, Bo Li.

Bridging Multi-Task Learning and Meta-Learning Code for the ICML 2021 paper "Bridging Multi-Task Learning and Meta-Learning: Towards Efficient Trainin

AI Secure 57 Dec 15, 2022
PyoMyo - Python Opensource Myo library

PyoMyo Python module for the Thalmic Labs Myo armband. Cross platform and multithreaded and works without the Myo SDK. pip install pyomyo Documentati

PerlinWarp 81 Jan 08, 2023
Non-Homogeneous Poisson Process Intensity Modeling and Estimation using Measure Transport

Non-Homogeneous Poisson Process Intensity Modeling and Estimation using Measure Transport This GitHub page provides code for reproducing the results i

Andrew Zammit Mangion 1 Nov 08, 2021
Few-shot Learning of GPT-3

Few-shot Learning With Language Models This is a codebase to perform few-shot "in-context" learning using language models similar to the GPT-3 paper.

Tony Z. Zhao 224 Dec 28, 2022
Official pytorch implementation of the IrwGAN for unaligned image-to-image translation

IrwGAN (ICCV2021) Unaligned Image-to-Image Translation by Learning to Reweight [Update] 12/15/2021 All dataset are released, trained models and genera

37 Nov 09, 2022
Tensorflow Tutorials using Jupyter Notebook

Tensorflow Tutorials using Jupyter Notebook TensorFlow tutorials written in Python (of course) with Jupyter Notebook. Tried to explain as kindly as po

Sungjoon 2.6k Dec 22, 2022
An Efficient Implementation of Analytic Mesh Algorithm for 3D Iso-surface Extraction from Neural Networks

AnalyticMesh Analytic Marching is an exact meshing solution from neural networks. Compared to standard methods, it completely avoids geometric and top

Karbo 45 Dec 21, 2022
RobustART: Benchmarking Robustness on Architecture Design and Training Techniques

The first comprehensive Robustness investigation benchmark on large-scale dataset ImageNet regarding ARchitecture design and Training techniques towards diverse noises.

132 Dec 23, 2022