PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

Overview

SupContrast: Supervised Contrastive Learning

This repo covers an reference implementation for the following papers in PyTorch, using CIFAR as an illustrative example:
(1) Supervised Contrastive Learning. Paper
(2) A Simple Framework for Contrastive Learning of Visual Representations. Paper

Loss Function

The loss function SupConLoss in losses.py takes features (L2 normalized) and labels as input, and return the loss. If labels is None or not passed to the it, it degenerates to SimCLR.

Usage:

from losses import SupConLoss

# define loss with a temperature `temp`
criterion = SupConLoss(temperature=temp)

# features: [bsz, n_views, f_dim]
# `n_views` is the number of crops from each image
# better be L2 normalized in f_dim dimension
features = ...
# labels: [bsz]
labels = ...

# SupContrast
loss = criterion(features, labels)
# or SimCLR
loss = criterion(features)
...

Comparison

Results on CIFAR-10:

Arch Setting Loss Accuracy(%)
SupCrossEntropy ResNet50 Supervised Cross Entropy 95.0
SupContrast ResNet50 Supervised Contrastive 96.0
SimCLR ResNet50 Unsupervised Contrastive 93.6

Results on CIFAR-100:

Arch Setting Loss Accuracy(%)
SupCrossEntropy ResNet50 Supervised Cross Entropy 75.3
SupContrast ResNet50 Supervised Contrastive 76.5
SimCLR ResNet50 Unsupervised Contrastive 70.7

Results on ImageNet (Stay tuned):

Arch Setting Loss Accuracy(%)
SupCrossEntropy ResNet50 Supervised Cross Entropy -
SupContrast ResNet50 Supervised Contrastive 79.1 (MoCo trick)
SimCLR ResNet50 Unsupervised Contrastive -

Running

You might use CUDA_VISIBLE_DEVICES to set proper number of GPUs, and/or switch to CIFAR100 by --dataset cifar100.
(1) Standard Cross-Entropy

python main_ce.py --batch_size 1024 \
  --learning_rate 0.8 \
  --cosine --syncBN \

(2) Supervised Contrastive Learning
Pretraining stage:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5 \
  --temp 0.1 \
  --cosine

You can also specify --syncBN but I found it not crucial for SupContrast (syncBN 95.9% v.s. BN 96.0%).
Linear evaluation stage:

python main_linear.py --batch_size 512 \
  --learning_rate 5 \
  --ckpt /path/to/model.pth

(3) SimCLR
Pretraining stage:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5 \
  --temp 0.5 \
  --cosine --syncBN \
  --method SimCLR

The --method SimCLR flag simply stops labels from being passed to SupConLoss criterion. Linear evaluation stage:

python main_linear.py --batch_size 512 \
  --learning_rate 1 \
  --ckpt /path/to/model.pth

On custom dataset:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5  \ 
  --temp 0.1 --cosine \
  --dataset path \
  --data_folder ./path \
  --mean "(0.4914, 0.4822, 0.4465)" \
  --std "(0.2675, 0.2565, 0.2761)" \
  --method SimCLR

The --data_folder must be of form ./path/label/xxx.png folowing https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.ImageFolder convension.

and

t-SNE Visualization

(1) Standard Cross-Entropy

(2) Supervised Contrastive Learning

(3) SimCLR

Reference

@Article{khosla2020supervised,
    title   = {Supervised Contrastive Learning},
    author  = {Prannay Khosla and Piotr Teterwak and Chen Wang and Aaron Sarna and Yonglong Tian and Phillip Isola and Aaron Maschinot and Ce Liu and Dilip Krishnan},
    journal = {arXiv preprint arXiv:2004.11362},
    year    = {2020},
}
Owner
Yonglong Tian
CS Ph.D. student in AI @ MIT
Yonglong Tian
Library of various Few-Shot Learning frameworks for text classification

FewShotText This repository contains code for the paper A Neural Few-Shot Text Classification Reality Check Environment setup # Create environment pyt

Thomas Dopierre 47 Jan 03, 2023
Dieser Scanner findet Websites, die nicht direkt in Suchmaschinen auftauchen, aber trotzdem erreichbar sind.

Deep Web Scanner Dieses Script findet Websites, die per IPv4-Adresse erreichbar sind und speichert deren Metadaten. Die Ausgabe im Terminal wird nach

Alex K. 30 Nov 18, 2022
Adversarial Learning for Modeling Human Motion

Adversarial Learning for Modeling Human Motion This repository contains the open source code which reproduces the results for the paper: Adversarial l

wangqi 6 Jun 15, 2021
A sample pytorch Implementation of ACL 2021 research paper "Learning Span-Level Interactions for Aspect Sentiment Triplet Extraction".

Span-ASTE-Pytorch This repository is a pytorch version that implements Ali's ACL 2021 research paper Learning Span-Level Interactions for Aspect Senti

来自丹麦的天籁 10 Dec 06, 2022
Apache Spark - A unified analytics engine for large-scale data processing

Apache Spark Spark is a unified analytics engine for large-scale data processing. It provides high-level APIs in Scala, Java, Python, and R, and an op

The Apache Software Foundation 34.7k Jan 04, 2023
GarmentNets: Category-Level Pose Estimation for Garments via Canonical Space Shape Completion

GarmentNets This repository contains the source code for the paper GarmentNets: Category-Level Pose Estimation for Garments via Canonical Space Shape

Columbia Artificial Intelligence and Robotics Lab 43 Nov 21, 2022
Hydra Lightning Template for Structured Configs

Hydra Lightning Template for Structured Configs Template for creating projects with pytorch-lightning and hydra. How to use this template? Create your

Model-driven Machine Learning 4 Jul 19, 2022
A basic duplicate image detection service using perceptual image hash functions and nearest neighbor search, implemented using faiss, fastapi, and imagehash

Duplicate Image Detection Getting Started Install dependencies pip install -r requirements.txt Run service python main.py Testing Test with pytest How

Matthew Podolak 21 Nov 11, 2022
《Rethinking Sptil Dimensions of Vision Trnsformers》(2021)

Rethinking Spatial Dimensions of Vision Transformers Byeongho Heo, Sangdoo Yun, Dongyoon Han, Sanghyuk Chun, Junsuk Choe, Seong Joon Oh | Paper NAVER

NAVER AI 224 Dec 27, 2022
A Partition Filter Network for Joint Entity and Relation Extraction EMNLP 2021

EMNLP 2021 - A Partition Filter Network for Joint Entity and Relation Extraction

zhy 127 Jan 04, 2023
Implementation of ConvMixer-Patches Are All You Need? in TensorFlow and Keras

Patches Are All You Need? - ConvMixer ConvMixer, an extremely simple model that is similar in spirit to the ViT and the even-more-basic MLP-Mixer in t

Sayan Nath 8 Oct 03, 2022
Implementation of MA-Trace - a general-purpose multi-agent RL algorithm for cooperative environments.

Off-Policy Correction For Multi-Agent Reinforcement Learning This repository is the official implementation of Off-Policy Correction For Multi-Agent R

4 Aug 18, 2022
Official pytorch implementation of "Scaling-up Disentanglement for Image Translation", ICCV 2021.

Official pytorch implementation of "Scaling-up Disentanglement for Image Translation", ICCV 2021.

Aviv Gabbay 41 Nov 29, 2022
This reposityory contains the PyTorch implementation of our paper "Generative Dynamic Patch Attack".

Generative Dynamic Patch Attack This reposityory contains the PyTorch implementation of our paper "Generative Dynamic Patch Attack". Requirements PyTo

Xiang Li 8 Nov 17, 2022
Improving Calibration for Long-Tailed Recognition (CVPR2021)

MiSLAS Improving Calibration for Long-Tailed Recognition Authors: Zhisheng Zhong, Jiequan Cui, Shu Liu, Jiaya Jia [arXiv] [slide] [BibTeX] Introductio

DV Lab 116 Dec 20, 2022
Code implementation for the paper 'Conditional Gaussian PAC-Bayes'.

CondGauss This repository contains PyTorch code for the paper Stochastic Gaussian PAC-Bayes. A novel PAC-Bayesian training method is implemented. Ther

0 Nov 01, 2021
Model-based 3D Hand Reconstruction via Self-Supervised Learning, CVPR2021

S2HAND: Model-based 3D Hand Reconstruction via Self-Supervised Learning S2HAND presents a self-supervised 3D hand reconstruction network that can join

Yujin Chen 72 Dec 12, 2022
Fast and customizable reconnaissance workflow tool based on simple YAML based DSL.

Fast and customizable reconnaissance workflow tool based on simple YAML based DSL, with support of notifications and distributed workload of that work

Américo Júnior 3 Mar 11, 2022
Bridging Composite and Real: Towards End-to-end Deep Image Matting

Bridging Composite and Real: Towards End-to-end Deep Image Matting Please note that the official repository of the paper Bridging Composite and Real:

Jizhizi_Li 30 Oct 31, 2022
Geometric Algebra package for JAX

JAXGA - JAX Geometric Algebra GitHub | Docs JAXGA is a Geometric Algebra package on top of JAX. It can handle high dimensional algebras by storing onl

Robin Kahlow 36 Dec 22, 2022