PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

Overview

Blog post with full documentation: Exploring SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

[Figure: SimCLR architecture]

See also PyTorch Implementation for BYOL - Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning.

Installation

$ conda env create --name simclr --file env.yml
$ conda activate simclr
$ python run.py

Config file

Before running SimCLR, make sure you choose the correct run configuration. You can change it by passing command-line arguments to run.py.

$ python run.py -data ./datasets --dataset-name stl10 --log-every-n-steps 100 --epochs 100 

If you want to run it on CPU (for debugging purposes), use the --disable-cuda option.

For 16-bit precision GPU training, there is NO need to install NVIDIA apex. Just use the --fp16_precision flag and this implementation will use PyTorch's built-in AMP training.
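
For readers unfamiliar with the built-in AMP, the following is a minimal sketch of what a mixed-precision training step usually looks like with `torch.cuda.amp`. The toy model, the optimizer, and the `use_fp16` variable are placeholders for illustration and are not taken from this repository's code.

```python
import torch
import torch.nn as nn

def train_step(model, optimizer, scaler, x, y, use_fp16):
    """One optimization step; autocast and loss scaling are no-ops when use_fp16 is False."""
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_fp16):
        loss = nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()   # scale the loss to limit fp16 gradient underflow
    scaler.step(optimizer)          # unscale gradients, skip the step on inf/NaN
    scaler.update()                 # adjust the scale factor for the next step
    return loss.item()

device = "cuda" if torch.cuda.is_available() else "cpu"
use_fp16 = device == "cuda"         # in this sketch fp16 is only enabled on GPU
model = nn.Linear(8, 3).to(device)  # toy model standing in for the encoder
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)

x = torch.randn(4, 8, device=device)
y = torch.randint(0, 3, (4,), device=device)
loss_value = train_step(model, optimizer, scaler, x, y, use_fp16)
```

With `enabled=False`, both `autocast` and `GradScaler` pass values through unchanged, so the same step function can serve CPU debugging runs.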

Feature Evaluation

Feature evaluation is done using a linear model protocol.

First, we learn features using SimCLR on the STL10 unsupervised set. Then, we train a linear classifier on top of the frozen SimCLR features. The linear model is trained on features extracted from the STL10 train set and evaluated on the STL10 test set.
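
The protocol can be sketched as follows. This is an illustration under stated assumptions, not the evaluation notebook itself: the small `encoder` stands in for a pretrained SimCLR backbone, and random tensors stand in for the STL10 train split.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for a pretrained SimCLR backbone (hypothetical; a real run would load ResNet weights).
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 512), nn.ReLU())
for p in encoder.parameters():
    p.requires_grad = False          # freeze the feature extractor
encoder.eval()

# Stand-in labeled data (random tensors instead of real images).
images = torch.randn(64, 3, 8, 8)
targets = torch.randint(0, 10, (64,))

with torch.no_grad():
    features = encoder(images)       # (64, 512) frozen features

classifier = nn.Linear(512, 10)      # the only trainable part
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

losses = []
for _ in range(20):
    optimizer.zero_grad()
    loss = criterion(classifier(features), targets)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

Only the classifier's parameters are handed to the optimizer, so the encoder stays untouched throughout training.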

Check the Colab notebook for reproducibility.

Note that SimCLR benefits from longer training.

| Linear Classification | Dataset | Feature Extractor | Architecture | Feature dimensionality | Projection Head dimensionality | Epochs | Top1 % |
|---|---|---|---|---|---|---|---|
| Logistic Regression (Adam) | STL10 | SimCLR | ResNet-18 | 512 | 128 | 100 | 74.45 |
| Logistic Regression (Adam) | CIFAR10 | SimCLR | ResNet-18 | 512 | 128 | 100 | 69.82 |
| Logistic Regression (Adam) | STL10 | SimCLR | ResNet-50 | 2048 | 128 | 50 | 70.075 |

Comments
  • A question about the "labels"

    Hi! I have a question about the definition of "labels" in the script "simclr.py".

    On line 54 of "simclr.py", the authors defined:

    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)

    So all entries of "labels" are zero. But according to the paper, I think there should be an entry of 1 for the positive pair?

    Thanks in advance for your reply!

    opened by kekehia123 6
  • size of tensors in cosine_similarity function

    Hi, I'm trying to understand the code in loss/nt_xent.py.

    We pass "representations" as both arguments:

        def forward(self, zis, zjs):
            representations = torch.cat([zjs, zis], dim=0)
            similarity_matrix = self.similarity_function(representations, representations)
    

    But inside the cosine similarity function, the shapes are documented as x: (N, 1, C) and y: (1, 2N, C). How can one be double the other if the same tensor is passed for both arguments?

        def _cosine_simililarity(self, x, y):
            # x shape: (N, 1, C)
            # y shape: (1, 2N, C)
            # v shape: (N, 2N)
            v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
            return v
    

    Thanks for your help.

    opened by BattashB 5
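
On the shape question above: assuming `zis` and `zjs` each have N rows, the concatenated tensor has 2N rows, so both arguments are (2N, C) and the two `unsqueeze` calls yield (2N, 1, C) and (1, 2N, C). The sketch below checks this with `torch.nn.CosineSimilarity`; the sizes are arbitrary illustration values.

```python
import torch

N, C = 4, 16                       # N pairs per batch, C-dimensional projections
zis = torch.randn(N, C)
zjs = torch.randn(N, C)
representations = torch.cat([zjs, zis], dim=0)    # (2N, C)

cosine = torch.nn.CosineSimilarity(dim=-1)
x = representations.unsqueeze(1)   # (2N, 1, C)
y = representations.unsqueeze(0)   # (1, 2N, C)
similarity_matrix = cosine(x, y)   # broadcast over the first two dims, reduced over C

print(x.shape, y.shape, similarity_matrix.shape)
```

Broadcasting pairs every row with every other row, which is why the result is a full (2N, 2N) matrix with ones on the diagonal.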
  • How do I train the SimCLR model with my local dataset?

    Dear researcher, Thank you for the open-source code you provided, it is of great help to me for understanding contrastive learning. But I still have some confusion when training the SimCLR model with my local dataset, could you give me some guidance or tips? I would appreciate it if you could reply to this issue.

    opened by bestalllen 4
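
One common ingredient for training on a local dataset is a transform wrapper that returns several augmented views of each sample. The sketch below is an assumption-laden illustration: `TwoViewTransform` and `jitter` are hypothetical names, and a toy numeric "augmentation" is used so the example runs without image libraries.

```python
import random

class TwoViewTransform:
    """Apply the same random augmentation pipeline n_views times to one sample."""

    def __init__(self, base_transform, n_views=2):
        self.base_transform = base_transform
        self.n_views = n_views

    def __call__(self, sample):
        return [self.base_transform(sample) for _ in range(self.n_views)]

def jitter(value):
    """Toy stand-in for a random image augmentation."""
    return value + random.uniform(-0.1, 0.1)

random.seed(0)
views = TwoViewTransform(jitter)(1.0)
print(views)

# With torchvision installed, the same wrapper could be passed as the transform of
# torchvision.datasets.ImageFolder("path/to/images", transform=TwoViewTransform(aug)),
# where `aug` is a torchvision transform pipeline and the path is a placeholder.
```

Each dataset item then yields a list of views plus a label that contrastive training can ignore.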
  • Question about CE Loss

    Hello,

    Thanks for sharing the code, nice implementation.

    The way you calculate the loss using a mask is quite brilliant, but I have a question.

        logits = torch.cat((positives, negatives), dim=1)

    If I'm not wrong, the first column of logits holds the positives and the rest are negatives.

        labels = torch.zeros(2 * self.batch_size).to(self.device).long()

    But your labels are all zeros, which would seem to mean that the similarity should be low for positives and negatives alike.

    So I wonder whether the first column of labels is supposed to be 1 instead of 0.

    Thanks for your help.

    opened by WShijun1991 4
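
On the labels question: with integer targets, `torch.nn.functional.cross_entropy` reads each label as the index of the correct column, not as a one-hot value. Since the positives sit in column 0 of `logits`, a label of 0 marks the positive as the correct class. The toy values below are made up to show the effect.

```python
import torch
import torch.nn.functional as F

# Column 0 holds the positive similarity, columns 1.. hold negatives (toy values).
positives = torch.tensor([[5.0], [5.0]])
negatives = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
logits = torch.cat((positives, negatives), dim=1)             # (2, 4)

labels = torch.zeros(logits.shape[0], dtype=torch.long)       # class index 0 for every row
loss_index0 = F.cross_entropy(logits, labels)

wrong_labels = torch.ones(logits.shape[0], dtype=torch.long)  # would point at the first negative
loss_index1 = F.cross_entropy(logits, wrong_labels)

print(loss_index0.item(), loss_index1.item())
```

The loss is small when the label points at the high-similarity positive column and large when it points at a negative column.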
  • Issue with batch-size

    In the function info_nce_loss, line 28 creates the labels from batch_size. The STL10 dataset has 100,000 images, which is divisible by a batch size of 32, but a batch size of 64 or 128 leaves a remainder of 32.

    With batch_size != 32 this causes an error on line 42, because the similarity matrix is built from the features while the labels are built from the batch size.

    For instance, with batch size 128, the last iteration of the data loader has only 32 images left. Since we create two views of each image, that gives 64 feature vectors. Line 28 still produces 128 x 2 = 256 labels, so the similarity matrix is (64 x 128) x (128 x 64) => (64 x 64) while the mask is (256 x 256), causing a "dimension mismatch".

    Solution: change line 28 as below

    labels = torch.cat([torch.arange(features.shape[0]//2) for i in range(self.args.n_views)], dim=0)

    opened by Mayurji 3
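
The arithmetic in the batch-size report can be checked with plain Python. The helper names below are hypothetical, and `labels_from_features` is only a list-based analogue of the proposed one-line fix.

```python
def last_batch_size(num_images, batch_size):
    """Size of the final mini-batch when the loader does not drop the remainder."""
    remainder = num_images % batch_size
    return remainder if remainder else batch_size

def labels_from_features(num_features, n_views=2):
    """List version of the proposed fix: derive labels from the actual feature count."""
    per_view = num_features // n_views
    return [i for _ in range(n_views) for i in range(per_view)]

num_images = 100_000
for bs in (32, 64, 128):
    print(bs, last_batch_size(num_images, bs))

# Last batch with batch size 128: 32 images, two views each, 64 feature vectors.
short_batch_labels = labels_from_features(2 * last_batch_size(num_images, 128))
print(len(short_batch_labels))
```

Another option, not mentioned in the report, would be to pass `drop_last=True` to the `DataLoader` so that the short final batch never reaches the loss.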
  • 'CosineAnnealingLR' never works with the wrong position of 'scheduler.step()'

    Given the setting 'scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1)', I think 'scheduler.step()' should be called at every step inside 'for (xis, xjs), _ in train_loader'. Otherwise the lr will barely change, because the schedule spans 'len(train_loader)' epochs rather than 'len(train_loader)' steps.

    opened by GuohongLi 3
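
The effect described in this report can be seen from the closed form of cosine annealing, eta_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t / T_max)). The sketch below evaluates that formula directly instead of calling the scheduler; `base_lr`, `steps_per_epoch`, and `epochs` are illustrative assumptions.

```python
import math

def cosine_annealing_lr(base_lr, t, t_max, eta_min=0.0):
    """Closed-form cosine annealing value after t scheduler steps (0 <= t <= t_max)."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max))

base_lr = 3e-4            # illustrative value
steps_per_epoch = 3125    # illustrative: 100,000 images at batch size 32
epochs = 100

# Stepping once per epoch with T_max = steps_per_epoch: after 100 epochs, t = 100.
lr_epoch_stepping = cosine_annealing_lr(base_lr, epochs, steps_per_epoch)

# Stepping once per batch with T_max = steps_per_epoch: after one epoch, t = T_max.
lr_batch_stepping = cosine_annealing_lr(base_lr, steps_per_epoch, steps_per_epoch)

print(lr_epoch_stepping / base_lr, lr_batch_stepping)
```

Under these assumptions, per-epoch stepping leaves the learning rate within about one percent of its initial value after 100 epochs, while per-batch stepping reaches eta_min after a single epoch. Setting T_max to the total number of scheduler steps planned for the run would be one way to make the two agree.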
  • Is there something wrong with the training model for CIFAR-10 experiments?

    Hi,

    I find that the ResNet20 model for CIFAR-10 experiments is not fully correct. The head conv structure should be modified (stride=1 and no pooling) because CIFAR-10 images are very small.

    opened by timqqt 2
  • GPU utilization rate is low

    Hi, thanks for the code!

    When I tried to run it on a single GPU (V100), the utilization rate is very low (~0-10%), even if I increase num_worker. Do you know why this happens and how to solve it? Thanks!

    opened by LiJunnan1992 2
  • Why cos_sim after L2 norm?

    Hi, this code is really useful to me. Thanks! But I have a question about the NT-Xent loss. I noticed that you apply an L2 norm to z and then use cos_similarity. But cos_similarity already includes the L2 normalization. Why apply the L2 norm first?

    opened by BoPang1996 2
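
The observation in this question can be checked numerically: cosine similarity is unchanged by L2-normalizing its inputs first, and for unit vectors it reduces to a plain dot product. The pure-Python sketch below uses made-up vectors and says nothing about why the author kept both steps.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def l2_normalize(v):
    norm = math.sqrt(dot(v, v))
    return [a / norm for a in v]

def cosine_similarity(u, v):
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))

u, v = [3.0, 4.0], [1.0, 2.0]
un, vn = l2_normalize(u), l2_normalize(v)

raw = cosine_similarity(u, v)          # cosine on the raw vectors
after_norm = cosine_similarity(un, vn) # cosine after explicit L2 normalization
plain_dot = dot(un, vn)                # dot product of the unit vectors

print(raw, after_norm, plain_dot)
```

All three values agree up to floating-point error, so the explicit normalization is mathematically redundant when cosine similarity follows it.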
  • NT_Xent Loss function: all negatives are not being used?

    Hi @sthalles , Thank you for sharing your code!

    Please correct me if I am wrong: in loss/nt_xent.py line 57 (below), it seems you are not computing the contrastive loss over all negative pairs, since you reshape the negatives into a 2D array, i.e. only a part of the negative pairs is used for a single positive pair, right?

        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
        logits = torch.cat((positives, negatives), dim=1)
    

    Hope to hear from you soon.

    -Ishan

    opened by DAVEISHAN 2
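
Whether the reshape drops negatives depends on how the mask is built, which the quoted lines do not show. The sketch below assumes a mask that excludes only the diagonal (self-similarity) and the two off-diagonals holding each sample's positive partner. This is a hypothetical reconstruction, not the repository's code.

```python
import torch

batch_size = 3
n = 2 * batch_size

# Hypothetical mask: False on the diagonal and on the +/- batch_size off-diagonals.
eye = torch.eye(n)
upper = torch.diag(torch.ones(batch_size), diagonal=batch_size)
lower = torch.diag(torch.ones(batch_size), diagonal=-batch_size)
mask = (1 - (eye + upper + lower)).bool()

# Use entry indices as values so the surviving columns are easy to read off.
similarity_matrix = torch.arange(n * n, dtype=torch.float32).view(n, n)
negatives = similarity_matrix[mask].view(n, -1)

print(negatives.shape)
print(negatives[0])
```

Under that assumption every row keeps 2N - 2 entries, so the `view` only restores the row structure that boolean indexing flattened and no negative is lost.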
  • Validation Loss calculation

    First of all, thank you for your great work!

    Method _validate in simclr.py will raise ZeroDivisionError at line 148 if the validation data loader performs only one iteration (since counter starts from 0).

    opened by alessiamarcolini 2
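
A guard against this kind of division error could look like the sketch below. The function name and the plain list of losses are assumptions for illustration; the repository's `_validate` method is not reproduced here.

```python
def mean_validation_loss(batch_losses):
    """Average per-batch losses, dividing by the number of batches actually seen."""
    total, num_batches = 0.0, 0
    for num_batches, loss in enumerate(batch_losses, start=1):  # count from 1, not 0
        total += loss
    if num_batches == 0:
        raise ValueError("validation loader produced no batches")
    return total / num_batches

single = mean_validation_loss([0.8])            # a loader with exactly one iteration
several = mean_validation_loss([1.0, 2.0, 3.0])
print(single, several)
```

Counting from 1 makes the divisor equal to the number of batches, so a single-iteration loader no longer divides by zero.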
  • evaluation code batch_size & validation process

    I really appreciate your good work :) I have a question because I got confused while studying your code.

    First, I wonder why the test_loader in "mini_batch_logistic_regression_valuator.ipynb" uses "batch_size=batch_size*2", unlike the train_loader. Is it related to creating 2 views during data augmentation?

    Also, in the last cell of this file, I'm not sure whether the second of the two inner "for" loops inside the epoch loop is a test process or a validation process. I assumed it was a test process, because the loss update, backpropagation, and optimization are done only in the first "for" loop, and the second only yields accuracy. Is that right? Or is the second loop a validation process, given that both loops run together in every epoch?

    opened by YejinS 0
  • Review Training | Fine-Tune | Test details

    Hi, I just want to check all the experiment details and make sure I didn't miss anything.

    1. Training phase: use SimCLR (two encoder branches) to train on ImageNet for 1000 epochs to get initial pretrained weights.
    2. Fine-tuning: load the pretrained weights into the resnet18 (50/101/...) with frozen parameters, attach a linear classifier, and train the classifier on the CIFAR10/STL10 training dataset for 100 epochs.
    3. Test phase: freeze all encoder and classifier parameters, and test on the CIFAR10/STL10 test dataset.

    Is this how you get the top-1 accuracy in the README?

    opened by Howeng98 0
  • Confusion matrix

    Does anyone know how to add a confusion matrix to this code? I added one following an example I found online, but something went wrong and I can't figure out what is wrong in my code. Please help me! Thanks.

        def confusion_matrix(output, labels, conf_matrix):
            preds = torch.argmax(output, dim=-1)
            for p, t in zip(preds, labels):
                conf_matrix[p, t] += 1
            return conf_matrix
    
    opened by here101 0
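
Since the report does not say what went wrong, the following is only one possible way to build a confusion matrix: a vectorized sketch using `torch.bincount`. Note the convention differs from the snippet above: here rows are true classes and columns are predictions, and `num_classes` is an added parameter.

```python
import torch

def confusion_matrix(output, labels, num_classes):
    """Count (true class, predicted class) pairs; rows are true, columns are predicted."""
    preds = torch.argmax(output, dim=-1)
    flat_index = labels * num_classes + preds        # unique id per (true, pred) pair
    counts = torch.bincount(flat_index, minlength=num_classes * num_classes)
    return counts.view(num_classes, num_classes)

# Toy logits for four samples and three classes.
output = torch.tensor([[2.0, 0.1, 0.1],
                       [0.1, 2.0, 0.1],
                       [0.1, 2.0, 0.1],
                       [0.1, 0.1, 2.0]])
labels = torch.tensor([0, 1, 2, 2])
cm = confusion_matrix(output, labels, num_classes=3)
print(cm)
```

To accumulate over a whole evaluation loop, the per-batch matrices can simply be summed, provided `output` and `labels` live on the same device.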
  • batch size effect

    Hi, I'm experimenting with CIFAR-10 using the default hyperparameters, and a smaller batch size seems to yield a better score (e.g. 72% with batch size 256 but 78% with batch size 128). Is anyone else seeing the same thing?

    opened by VietHoang1512 1
  • ModuleNotFoundError: No module named 'torch.cuda'

    I am using Python 3.7 on Win10 with Anaconda Jupyter. I have successfully installed torch-1.10.0+cu113, torchaudio-0.10.0+cu113, and torchvision-0.11.1+cu113. When trying to import torch, I get ModuleNotFoundError: No module named 'torch.cuda'. Detailed error:

    ModuleNotFoundError                       Traceback (most recent call last)
    <ipython-input-1-bfd2c657fa76> in <module>
          1 import numpy as np
          2 import pandas as pd
    ----> 3 import torch
          4 import torch.nn as nn
          5 from sklearn.model_selection import train_test_split
    
    ~\AppData\Roaming\Python\Python38\site-packages\torch\__init__.py in <module>
        603 
        604 # Shared memory manager needs to know the exact location of manager executable
    --> 605 _C._initExtension(manager_path())
        606 del manager_path
        607 
    
    ModuleNotFoundError: No module named 'torch.cuda'
    

    I found posts about a similar error, No module named 'torch.cuda.amp'. However, none of the suggested solutions worked. Please advise.

    opened by m-bor 0
Releases(v1.0.1)