Usable implementation of "Bootstrap Your Own Latent" self-supervised learning, from DeepMind, in Pytorch

Overview

Bootstrap Your Own Latent (BYOL), in Pytorch

Practical implementation of an astoundingly simple method for self-supervised learning that achieves a new state of the art (surpassing SimCLR) without contrastive learning or the need to designate negative pairs.

This repository offers a module with which one can easily wrap any image-based neural network (residual network, discriminator, policy network) to immediately start benefiting from unlabelled image data.

Update 1: There is now new evidence that batch normalization is key to making this technique work well

Update 2: A new paper has successfully replaced batch norm with group norm + weight standardization, refuting the claim that batch statistics are needed for BYOL to work

Update 3: Finally, we have some analysis of why this works

Yannic Kilcher's excellent explanation

Now go save your organization from having to pay for labels :)

Install

$ pip install byol-pytorch

Usage

Simply plug in your neural network, specifying (1) the image dimensions and (2) the name (or index) of the hidden layer whose output is used as the latent representation for self-supervised training.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')

That's pretty much it. After much training, the residual network should now perform better on its downstream supervised tasks.
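Conceptually, update_moving_average moves each target-encoder weight a small step toward the matching online-encoder weight. The sketch below shows that exponential moving average on two plain nn.Module instances; ema_update is a hypothetical helper rather than the library's internal function, and its beta argument plays the role of moving_average_decay.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def ema_update(target: nn.Module, online: nn.Module, beta: float = 0.99) -> None:
    # Hypothetical helper: target <- beta * target + (1 - beta) * online, parameter by parameter
    for t_param, o_param in zip(target.parameters(), online.parameters()):
        t_param.mul_(beta).add_(o_param, alpha=1 - beta)


online = nn.Linear(4, 2)
target = copy.deepcopy(online)  # the target starts as an exact copy of the online network
for p in target.parameters():
    p.requires_grad_(False)  # the target is never updated by the optimizer directly

with torch.no_grad():
    online.weight.add_(1.0)  # pretend an optimizer step moved the online weights by +1

ema_update(target, online, beta=0.99)
# each target weight has now moved 1% of the way toward its online counterpart
```

Note that this sketch only averages parameters; buffers such as batch-norm running statistics are left untouched.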

BYOL → SimSiam

A new paper from Xinlei Chen and Kaiming He suggests that BYOL does not even need the target encoder to be an exponential moving average of the online encoder. I've decided to build in this option so that you can easily use that variant for training, simply by setting the use_momentum flag to False. If you go this route, you no longer need to invoke update_moving_average, as shown in the example below.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    use_momentum = False       # turn off momentum in the target encoder
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')
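In the SimSiam paper, the ingredient that keeps the two branches from collapsing without a momentum encoder is the stop-gradient on the target side. As a sketch on plain tensors (the function names and shapes below are illustrative assumptions, not this library's internals), the symmetric negative-cosine loss from that paper looks like this:

```python
import torch
import torch.nn.functional as F


def negative_cosine(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # z.detach() is the stop-gradient: no gradient flows back through the target branch
    return -F.cosine_similarity(p, z.detach(), dim=-1).mean()


def simsiam_loss(p1, p2, z1, z2):
    # Symmetrized: each view's prediction is matched to the other view's projection
    return 0.5 * negative_cosine(p1, z2) + 0.5 * negative_cosine(p2, z1)


# Hypothetical predictor outputs (p) and projector outputs (z) for two views of a batch of 8
p1 = torch.randn(8, 256, requires_grad=True)
p2 = torch.randn(8, 256, requires_grad=True)
z1 = torch.randn(8, 256, requires_grad=True)
z2 = torch.randn(8, 256, requires_grad=True)

loss = simsiam_loss(p1, p2, z1, z2)
loss.backward()
# Gradients reach p1 and p2 only; z1.grad and z2.grad stay None because of the stop-gradient
```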

Advanced

While the hyperparameters have already been set to the values the paper found optimal, you can change them with extra keyword arguments to the base wrapper class.

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    projection_size = 256,           # the projection size
    projection_hidden_size = 4096,   # the hidden dimension of the MLP for both the projection and prediction
    moving_average_decay = 0.99      # the moving average decay factor for the target encoder, already set at what paper recommends
)
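In the BYOL paper, the projector and predictor are both small MLPs: a linear layer to a hidden width, batch normalization, ReLU, and a final linear layer to the output size. A minimal sketch of such an MLP, with projection_hidden_size and projection_size mapped onto its widths, might look like the following. The MLP class is a hypothetical stand-in, and the input width of 2048 assumes flattened ResNet-50 avgpool features.

```python
import torch
from torch import nn


class MLP(nn.Module):
    # Hypothetical projector/predictor: Linear -> BatchNorm -> ReLU -> Linear
    def __init__(self, dim: int, projection_size: int = 256, hidden_size: int = 4096):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, projection_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


representation = torch.randn(8, 2048)  # stand-in for flattened ResNet-50 avgpool features
projector = MLP(2048, projection_size=256, hidden_size=4096)
predictor = MLP(256, projection_size=256, hidden_size=4096)

projection = projector(representation)  # shape (8, 256)
prediction = predictor(projection)      # shape (8, 256)
```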

By default, this library uses the augmentations from the SimCLR paper (which are also used in the BYOL paper). However, if you would like to specify your own augmentation pipeline, you can simply pass in your own custom augmentation function with the augment_fn keyword.

from torch import nn
import kornia

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn
)
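The augmentation does not have to come from kornia. Assuming augment_fn only needs to be a callable that maps a (batch, channels, height, width) tensor to a tensor of the same shape, as the nn.Sequential examples suggest, a plain PyTorch function can serve as a sketch of a per-sample random horizontal flip; random_hflip is a hypothetical name.

```python
import torch


def random_hflip(images: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    # images: (B, C, H, W); flip each sample independently with probability p
    flip = torch.rand(images.shape[0], device=images.device) < p
    flipped = images.flip(-1)  # reverse the width axis
    return torch.where(flip[:, None, None, None], flipped, images)


images = torch.randn(4, 3, 32, 32)
augmented = random_hflip(images, p=1.0)  # p = 1.0 flips every sample
```

Under that assumption it would be passed as augment_fn = random_hflip.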

In the paper, they appear to ensure that one of the augmentations has a higher Gaussian blur probability than the other. You can also adjust this to your heart's delight.

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

augment_fn2 = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip(),
    kornia.filters.GaussianBlur2d((3, 3), (1.5, 1.5))
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn,
    augment_fn2 = augment_fn2,
)
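The example above blurs the second view every time. To give each view its own blur probability instead (the BYOL paper uses 1.0 for one view and 0.1 for the other), each pipeline can blur only a random subset of the batch. The sketch below does this in plain PyTorch with a depthwise convolution; RandomGaussianBlur here is a hypothetical module, not kornia's class of the same name, and the kernel settings mirror the GaussianBlur2d((3, 3), (1.5, 1.5)) call above.

```python
import torch
import torch.nn.functional as F
from torch import nn


class RandomGaussianBlur(nn.Module):
    # Hypothetical module: blurs each sample with probability p using a fixed Gaussian kernel
    def __init__(self, kernel_size: int = 3, sigma: float = 1.5, p: float = 0.5):
        super().__init__()
        coords = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
        kernel_1d = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        kernel_1d = kernel_1d / kernel_1d.sum()
        self.register_buffer("kernel", torch.outer(kernel_1d, kernel_1d))  # (k, k), sums to 1
        self.p = p

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        b, c, h, w = images.shape
        k = self.kernel.shape[-1]
        weight = self.kernel[None, None].repeat(c, 1, 1, 1)  # one kernel copy per channel
        padded = F.pad(images, [k // 2] * 4, mode="reflect")
        blurred = F.conv2d(padded, weight, groups=c)  # depthwise conv keeps (B, C, H, W)
        apply = torch.rand(b, device=images.device) < self.p
        return torch.where(apply[:, None, None, None], blurred, images)


augment_fn = RandomGaussianBlur(p=1.0)   # view 1: always blurred
augment_fn2 = RandomGaussianBlur(p=0.1)  # view 2: rarely blurred

images = torch.rand(4, 3, 32, 32)
view1 = augment_fn(images)
view2 = augment_fn2(images)
```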

To fetch the embeddings or the projections, you simply have to pass in a return_embedding = True flag to the BYOL learner instance.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

imgs = torch.randn(2, 3, 256, 256)
projection, embedding = learner(imgs, return_embedding = True)

Alternatives

If your downstream task involves segmentation, please look at the following repository, which extends BYOL to 'pixel'-level learning.

https://github.com/lucidrains/pixel-level-contrastive-learning

Citation

@misc{grill2020bootstrap,
    title = {Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning},
    author = {Jean-Bastien Grill and Florian Strub and Florent Altché and Corentin Tallec and Pierre H. Richemond and Elena Buchatskaya and Carl Doersch and Bernardo Avila Pires and Zhaohan Daniel Guo and Mohammad Gheshlaghi Azar and Bilal Piot and Koray Kavukcuoglu and Rémi Munos and Michal Valko},
    year = {2020},
    eprint = {2006.07733},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

@misc{chen2020exploring,
    title = {Exploring Simple Siamese Representation Learning},
    author = {Xinlei Chen and Kaiming He},
    year = {2020},
    eprint = {2011.10566},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • Negative Loss, Transfer Learning/Fine-Tuning Question

    Hi! Thanks for sharing this repo -- really clean and easy to use.

    When training with the PyTorch Lightning script from the repo, my loss is negative (and gets more negative over time). Is this expected?

    I'm curious to know if you've fine-tuned a pretrained model using this BYOL as the README example suggested. If yes, how were the results? Any intuition regarding how many epochs to fine-tune for?

    Thanks!

    opened by rsomani95 13
  • AssertionError: hidden layer never emitted an output with multi-gpu training

    I tried your library with a WideResnet40-2 model and used layer_index=-2.

    The Lightning example works fine on a single GPU, but I got this error with multiple GPUs.

    opened by reactivetype 7
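    The error refers to a hidden layer whose output was never captured. A common way to grab an intermediate output, and plausibly how a hidden_layer name or index gets resolved (an assumption, not confirmed in this thread), is a forward hook registered on one child module. In this sketch HiddenLayerExtractor is a hypothetical class, and it raises the same assertion when the hook never fires.

```python
import torch
from torch import nn


class HiddenLayerExtractor:
    # Hypothetical helper: records the output of one child layer via a forward hook
    def __init__(self, net: nn.Module, layer):
        children = dict(net.named_children())
        module = children[layer] if isinstance(layer, str) else list(children.values())[layer]
        self.net = net
        self.hidden = None
        self.handle = module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        self.hidden = output.flatten(1)  # e.g. (B, C, 1, 1) -> (B, C)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        self.hidden = None
        self.net(x)
        assert self.hidden is not None, "hidden layer never emitted an output"
        return self.hidden


net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
extractor = HiddenLayerExtractor(net, -3)  # index -3 selects the AdaptiveAvgPool2d layer
features = extractor(torch.randn(2, 3, 16, 16))  # shape (2, 8)
```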
  • How to transfer the trained ckpt to pytorch.pth model?

    I used the example script to train a model and got a ckpt file, but how can I extract the trained resnet50.pth instead of the whole SelfSupervisedLearner? Sorry, I am new to the PyTorch Lightning library. What I want is the self-supervised resnet50.pth, because I want to use it to replace the original ImageNet-pretrained one. Thanks a lot.

    opened by knaffe 5
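    PyTorch Lightning checkpoints store the module's weights under a state_dict key, with each name prefixed by the attribute path to the submodule. Assuming the example script keeps the wrapped network at learner.net (an assumption; print the checkpoint's keys to find the real prefix), the backbone weights can be pulled out by filtering on that prefix and stripping it. extract_submodule is a hypothetical helper, and the checkpoint below is a small stand-in dict.

```python
import torch
from torch import nn


def extract_submodule(state_dict: dict, prefix: str) -> dict:
    # Keep only keys under `prefix` and strip it, e.g. "learner.net.conv1.weight" -> "conv1.weight"
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


# Stand-in for torch.load(<your .ckpt>)["state_dict"]; the "learner.net." prefix is assumed
backbone = nn.Linear(4, 2)
lightning_state = {f"learner.net.{k}": v for k, v in backbone.state_dict().items()}
lightning_state["learner.online_encoder.projector.weight"] = torch.zeros(1)  # filtered out

restored = nn.Linear(4, 2)
restored.load_state_dict(extract_submodule(lightning_state, "learner.net."))
# torch.save(restored.state_dict(), "resnet50-byol.pth")  # then save as a plain .pth file
```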
  • Training loss decreased and then increased

    Hi, I used your example on my own data. The training loss decreased and then increased after 100 epochs, which is weird. Have you encountered a similar situation? Is the model hard to train? The batch size is 128/256, lr is 0.1/0.2, and weight_decay is 1e-6.

    opened by easonyang1996 4
  • Can't load ckpt

    I used byol-pytorch-master/examples/lightning/train.py to generate a ckpt locally after training, but when I load the ckpt, I get the following errors. How should I load it? Thanks a lot!

    opened by AndrewTal 4
  • BYOL uses different augmentations for view1 and view2

    opened by OlivierDehaene 4
  • Transferring results on Cifar and other datasets

    Thanks for your open sourcing!

    I notice that BYOL shows a large gap over SimCLR on the downstream transfer datasets: e.g., SimCLR reaches 71.6% on CIFAR-100, while BYOL reaches 78.4%.

    I understand that this might depend on the downstream training protocol. Could you provide sample code for it, especially for the L-BFGS-optimized logistic regressor?

    opened by jacobswan1 4
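    As a sketch under stated assumptions (frozen features already extracted, random stand-in data, and an arbitrary L2 weight of 1e-4; this is not the BYOL authors' evaluation code), torch.optim.LBFGS can fit a logistic-regression probe on top of frozen embeddings:

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Stand-in for frozen BYOL embeddings of a labelled dataset: 200 samples, 64 dims, 3 classes
features = torch.randn(200, 64)
true_w = torch.randn(64, 3)
labels = (features @ true_w).argmax(dim=1)  # linearly separable toy labels

probe = nn.Linear(64, 3)
optimizer = torch.optim.LBFGS(probe.parameters(), lr=1.0, max_iter=100, line_search_fn="strong_wolfe")
weight_decay = 1e-4  # assumed L2 strength; in practice tuned on a validation split


def closure():
    # LBFGS re-evaluates the loss several times per step, so it needs a closure
    optimizer.zero_grad()
    loss = F.cross_entropy(probe(features), labels) + weight_decay * probe.weight.pow(2).sum()
    loss.backward()
    return loss


optimizer.step(closure)

with torch.no_grad():
    accuracy = (probe(features).argmax(dim=1) == labels).float().mean().item()
```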
  • The saved network is same as the initial one?

    Firstly, thank you so much for this clean implementation!!

    The self-supervised training process looks good, but the saved (i.e. improved) model is exactly the same as the initial one on my side. Have you observed the same problem?

    The code I tested:

    import torch
    from net.byol import BYOL
    from torchvision import models
     
           
    resnet = models.resnet50(pretrained=True)
    param_1 = resnet.parameters()
    
    learner = BYOL(
        resnet,
        image_size = 256,
        hidden_layer = 'avgpool'
    )
    
    opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
    
    def sample_unlabelled_images():
        return torch.randn(20, 3, 256, 256)
    
    for _ in range(2):
        images = sample_unlabelled_images()
        loss = learner(images)
        opt.zero_grad()
        loss.backward()
        opt.step()
        learner.update_moving_average() # update moving average of target encoder
    
    # save your improved network
    torch.save(resnet.state_dict(), './checkpoints/improved-net.pt')
    
    # restore the model      
    resnet2 = models.resnet50()
    resnet2.load_state_dict(torch.load('./checkpoints/improved-net.pt'))
    param_2 = resnet2.parameters()
    
    # test whether two models are the same 
    for p1, p2 in zip(param_1, param_2):
        if p1.data.ne(p2.data).sum() > 0:
            print('They are different.')
    print('They are same.')
    
    opened by KimMeen 3
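    Independent of BYOL, the check above cannot detect a difference: resnet.parameters() returns a generator over the live parameter tensors, so param_1 is only consumed after training and reads the trained weights rather than a snapshot of the initial ones. In addition, print('They are same.') runs unconditionally after the loop. A corrected comparison clones the weights before training, as in this sketch with a small stand-in model and a single SGD step in place of BYOL training:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for the ResNet
# Snapshot: clone each tensor so later in-place optimizer updates do not change it
initial = {name: p.detach().clone() for name, p in model.named_parameters()}

opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(8, 4)).pow(2).mean()  # stand-in for the BYOL loss
opt.zero_grad()
loss.backward()
opt.step()

changed = [
    name for name, p in model.named_parameters()
    if not torch.equal(p.detach(), initial[name])
]
print("They are different." if changed else "They are the same.")
```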
  • the maximum batch size can only be set to 32

    When I run the code on a 2080 Ti GPU with 10 GB of memory, the maximum batch size I can use is 32. Is there any place in the code that takes up a lot of GPU memory?

    opened by cuixianheng 3
  • Pretrained network

    Hi, thanks for sharing the code and making it so easy to use. I see in the example you set resnet = models.resnet50(pretrained=True). Is this what is done in the paper? Shouldn't self-supervised-learned networks be trained from scratch?

    Thanks again, P.

    opened by pmorerio 3
  • Singleton Class Members

    Forgive me for my unfamiliarity with software design, but I'm wondering why it is necessary to write a singleton wrapper for projector and target_encoder. Is there any disadvantage of initializing them in __init__?

    opened by wentaoyuan 3
  • Increase EMA-parameter during training

    Hi, I noticed that the EMA parameter (called beta in the code, τ in the paper) is not updated during training. In the paper, they describe increasing τ from its starting value to 1 during training: "Specifically, we set τ = 1 − (1 − τ_base) · (cos(πk/K) + 1)/2 with k the current training step and K the maximum number of training steps." This makes a huge difference to the validation loss at the end of training.

    [Figures: without_tau_update, with_tau_update]

    opened by Benjamin-Hansson 1
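    The quoted schedule can be written as a small function of the step counter. In this sketch cosine_tau is a hypothetical name, and τ_base = 0.996 is the BYOL paper's base value, not necessarily this library's default; the result is what would be assigned to the EMA's beta at step k of K.

```python
import math


def cosine_tau(step: int, total_steps: int, tau_base: float = 0.996) -> float:
    # tau = 1 - (1 - tau_base) * (cos(pi * k / K) + 1) / 2
    return 1 - (1 - tau_base) * (math.cos(math.pi * step / total_steps) + 1) / 2


schedule = [cosine_tau(k, 100) for k in (0, 50, 100)]
# starts at tau_base (k = 0) and rises to 1.0 (k = K)
```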
  • Why the loss is different from BYOL authors'

    I found that the loss differs from the one described in the BYOL paper, which should be an L2 loss, and I couldn't find an explanation. The loss in this repo is a cosine loss, and I just want to know why. BTW, thanks for this great repo!

    opened by Jing-XING 2
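    The two forms agree. The BYOL paper's loss is the squared L2 distance between the L2-normalized prediction and target, and for unit vectors ‖p̂ − ẑ‖² = 2 − 2·cos(p, z). A cosine loss written as 2 − 2·cos therefore has the same value and gradients as the paper's normalized L2 loss. This quick numerical check uses random stand-in tensors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
p = torch.randn(8, 256)  # stand-in predictions
z = torch.randn(8, 256)  # stand-in target projections

# Paper form: squared L2 distance between the L2-normalized vectors
l2_form = (F.normalize(p, dim=-1) - F.normalize(z, dim=-1)).pow(2).sum(dim=-1)
# Cosine form: 2 - 2 * cosine similarity
cosine_form = 2 - 2 * F.cosine_similarity(p, z, dim=-1)
```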
  • How to cluster/predict images?

    Hi, I have trained using the examples provided with PyTorch Lightning, but I couldn't find code for clustering images after training. How can I find out which image falls into which cluster? Is there a predictor API?

    opened by laxmimerit 1
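    The learner returns embeddings rather than cluster assignments, so clustering is a separate step. A minimal sketch, assuming the embeddings have already been collected into one (N, D) tensor (for example from learner(imgs, return_embedding = True)): plain k-means in PyTorch, where kmeans is a hypothetical helper and k is chosen by hand. scikit-learn's KMeans would do the same job.

```python
import torch


def kmeans(x: torch.Tensor, k: int, iters: int = 20, seed: int = 0):
    # Plain k-means (Lloyd's algorithm): returns (cluster index per row, centroids)
    g = torch.Generator().manual_seed(seed)
    centroids = x[torch.randperm(x.shape[0], generator=g)[:k]].clone()  # init from k random rows
    for _ in range(iters):
        assign = torch.cdist(x, centroids).argmin(dim=1)  # nearest centroid per point
        for j in range(k):
            members = x[assign == j]
            if len(members) > 0:  # keep the old centroid if a cluster empties out
                centroids[j] = members.mean(dim=0)
    return assign, centroids


# Stand-in embeddings: two well-separated blobs of 50 points each in 16 dimensions
torch.manual_seed(0)
embeddings = torch.cat([torch.randn(50, 16) - 5, torch.randn(50, 16) + 5])
assign, centroids = kmeans(embeddings, k=2)
```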
  • BN layer weights and biases are not updated

    Thanks for sharing this repo, great work!

    I trained BYOL on my data and noticed that the BN layer weights and biases are not updated in the saved model. I used resnet18 without pretrained weights (resnet = models.resnet50(pretrained=False)). After training for multiple epochs, the saved model has bn1.weight all equal to 1.0 and bn1.bias all equal to 0.0.

    Is this the expected behavior or am I missing something? Appreciate your response!

    opened by kregmi 1
  • Warning: grad and param do not obey the gradient layout contract.

    Has anybody gotten a similar warning when using it?

    Warning: grad and param do not obey the gradient layout contract. This is not an error, but may impair performance. grad.sizes() = [512, 256, 1, 1], strides() = [256, 1, 1, 1] param.sizes() = [512, 256, 1, 1], strides() = [256, 1, 256, 256] (function operator())

    opened by mohaEs 3
Releases: 0.6.0
Owner: Phil Wang (Working with Attention. It's all we need)