Usable Implementation of "Bootstrap Your Own Latent" self-supervised learning, from Deepmind, in Pytorch

Overview

Bootstrap Your Own Latent (BYOL), in Pytorch

Practical implementation of an astoundingly simple method for self-supervised learning that achieves a new state of the art (surpassing SimCLR) without contrastive learning or the need to designate negative pairs.

This repository offers a module with which you can easily wrap any image-based neural network (residual network, discriminator, policy network) and immediately start benefiting from unlabelled image data.

Update 1: There is now new evidence that batch normalization is key to making this technique work well

Update 2: A new paper has successfully replaced batch norm with group norm + weight standardization, refuting that batch statistics are needed for BYOL to work

Update 3: Finally, we have some analysis for why this works

Yannic Kilcher's excellent explanation

Now go save your organization from having to pay for labels :)

Install

$ pip install byol-pytorch

Usage

Simply plug in your neural network, specifying (1) the image dimensions and (2) the name (or index) of the hidden layer whose output is used as the latent representation for self-supervised training.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')

That's pretty much it. After much training, the residual network should now perform better on its downstream supervised tasks.
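
For intuition, the update_moving_average call in the loop above keeps the target encoder as an exponential moving average of the online encoder. The sketch below illustrates that idea on a toy nn.Linear. The helper ema_update is a hypothetical name written for this example and is not the library's internal code.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def ema_update(target: nn.Module, online: nn.Module, decay: float = 0.99) -> None:
    """Hypothetical helper: target <- decay * target + (1 - decay) * online."""
    for t, o in zip(target.parameters(), online.parameters()):
        t.mul_(decay).add_(o, alpha=1.0 - decay)


online = nn.Linear(4, 2)            # toy stand-in for the online encoder
target = copy.deepcopy(online)      # target starts as a copy of the online network
for p in target.parameters():
    p.requires_grad_(False)         # the target is never trained by the optimizer

# pretend an optimizer step shifted every online weight by +1
with torch.no_grad():
    for p in online.parameters():
        p.add_(1.0)

before = [p.clone() for p in target.parameters()]
ema_update(target, online, decay=0.99)

# with decay = 0.99 the target moves 1% of the way toward the online weights
shifts = [(p - b) for p, b in zip(target.parameters(), before)]
```

A decay close to 1 makes the target change slowly, which is the role the moving_average_decay keyword plays in the Advanced section.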

BYOL → SimSiam

A new paper from Kaiming He suggests that BYOL does not even need the target encoder to be an exponential moving average of the online encoder. I've decided to build in this option so that you can easily use that variant for training, simply by setting the use_momentum flag to False. If you go this route, you no longer need to invoke update_moving_average, as shown in the example below.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    use_momentum = False       # turn off momentum in the target encoder
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')

Advanced

While the hyperparameters have already been set to what the paper has found optimal, you can change them with extra keyword arguments to the base wrapper class.

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    projection_size = 256,           # the projection size
    projection_hidden_size = 4096,   # the hidden dimension of the MLP for both the projection and prediction
    moving_average_decay = 0.99      # the moving average decay factor for the target encoder, already set at what paper recommends
)
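
According to one of the comments further down, the decay is kept constant during training, whereas the paper anneals it toward 1 with the cosine schedule τ = 1 − (1 − τbase) · (cos(πk/K) + 1)/2, where k is the current step and K the total number of steps. A minimal sketch of that formula follows. tau_at_step is a hypothetical helper, not part of byol_pytorch, and how the value would be fed back into the learner is left open here.

```python
import math


def tau_at_step(k: int, total_steps: int, tau_base: float = 0.99) -> float:
    """Cosine schedule quoted from the paper: rises from tau_base (k = 0) to 1 (k = K)."""
    return 1.0 - (1.0 - tau_base) * (math.cos(math.pi * k / total_steps) + 1.0) / 2.0


total_steps = 100
schedule = [tau_at_step(k, total_steps) for k in range(total_steps + 1)]

# starts at tau_base, ends at 1, and never decreases in between
start, middle, end = schedule[0], schedule[total_steps // 2], schedule[-1]
```

With tau_base = 0.99 the schedule starts at 0.99, passes 0.995 at the halfway point, and reaches 1 at the final step.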

By default, this library uses the augmentations from the SimCLR paper (which are also used in the BYOL paper). If you would like to specify your own augmentation pipeline, pass your own custom augmentation function with the augment_fn keyword.

import kornia
from torch import nn

# custom augmentation pipeline: a random horizontal flip only
augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn
)

In the paper, they seem to ensure that one of the augmentations has a higher Gaussian blur probability than the other. You can also adjust this to your heart's delight.

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

augment_fn2 = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip(),
    kornia.filters.GaussianBlur2d((3, 3), (1.5, 1.5))
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn,
    augment_fn2 = augment_fn2,
)

To fetch the embeddings or the projections, simply pass the return_embedding = True flag when calling the BYOL learner instance.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

imgs = torch.randn(2, 3, 256, 256)
projection, embedding = learner(imgs, return_embedding = True)

Alternatives

If your downstream task involves segmentation, please look at the following repository, which extends BYOL to 'pixel'-level learning.

https://github.com/lucidrains/pixel-level-contrastive-learning

Citation

@misc{grill2020bootstrap,
    title = {Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning},
    author = {Jean-Bastien Grill and Florian Strub and Florent Altché and Corentin Tallec and Pierre H. Richemond and Elena Buchatskaya and Carl Doersch and Bernardo Avila Pires and Zhaohan Daniel Guo and Mohammad Gheshlaghi Azar and Bilal Piot and Koray Kavukcuoglu and Rémi Munos and Michal Valko},
    year = {2020},
    eprint = {2006.07733},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{chen2020exploring,
    title={Exploring Simple Siamese Representation Learning}, 
    author={Xinlei Chen and Kaiming He},
    year={2020},
    eprint={2011.10566},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
Comments
  • Negative Loss, Transfer Learning/Fine-Tuning Question

    Hi! Thanks for sharing this repo -- really clean and easy to use.

    When training with the PyTorch Lightning script from the repo, my loss is negative (and gets more negative over time). Is this expected?

    I'm curious to know if you've fine-tuned a pretrained model using this BYOL as the README example suggested. If yes, how were the results? Any intuition regarding how many epochs to fine-tune for?

    Thanks!

    opened by rsomani95 13
  • AssertionError: hidden layer never emitted an output with multi-gpu training

    I tried your library with a WideResnet40-2 model and used layer_index=-2.

    The Lightning example works fine on a single GPU, but I got the error with multiple GPUs.

    opened by reactivetype 7
  • How to transfer the trained ckpt to pytorch.pth model?

    I used the example script to train a model and got a ckpt file, but how can I extract the trained resnet50.pth instead of the whole SelfSupervisedLearner? Sorry, I am new to the PyTorch Lightning library. What I want is the self-supervised resnet50.pth, because I want it to replace the original ImageNet-pretrained one. Thank you a lot.

    opened by knaffe 5
  • Training loss decreased and then increased

    Hi, I used your example on my own data. The training loss decreased and then increased after 100 epochs, which is weird. Have you seen anything similar? Is the model hard to train? The batch size is 128/256, lr is 0.1/0.2, and weight_decay is 1e-6.

    opened by easonyang1996 4
  • Can't load ckpt

    I used byol-pytorch-master/examples/lightning/train.py to generate a ckpt locally after training, but loading the ckpt fails with the errors in the attached screenshot (not preserved here). How should I load it? Thanks a lot!

    opened by AndrewTal 4
  • BYOL uses different augmentations for view1 and view2

    opened by OlivierDehaene 4
  • Transferring results on Cifar and other datasets

    Thanks for open-sourcing this!

    I notice that BYOL shows a large gap on the downstream transfer datasets: e.g., SimCLR reaches 71.6% on CIFAR-100, while BYOL reaches 78.4%.

    I understand that this might depend on the downstream training protocols. Could you provide sample code for that, especially for the LBFGS-optimized logistic regressor?

    opened by jacobswan1 4
  • The saved network is same as the initial one?

    Firstly, thank you so much for this clean implementation!!

    The self-supervised training process looks good, but the saved (i.e. improved) model is exactly the same as the initial one on my side. Have you observed the same problem?

    The code I tested:

    import torch
    from net.byol import BYOL
    from torchvision import models
     
           
    resnet = models.resnet50(pretrained=True)
    param_1 = resnet.parameters()
    
    learner = BYOL(
        resnet,
        image_size = 256,
        hidden_layer = 'avgpool'
    )
    
    opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
    
    def sample_unlabelled_images():
        return torch.randn(20, 3, 256, 256)
    
    for _ in range(2):
        images = sample_unlabelled_images()
        loss = learner(images)
        opt.zero_grad()
        loss.backward()
        opt.step()
        learner.update_moving_average() # update moving average of target encoder
    
    # save your improved network
    torch.save(resnet.state_dict(), './checkpoints/improved-net.pt')
    
    # restore the model      
    resnet2 = models.resnet50()
    resnet2.load_state_dict(torch.load('./checkpoints/improved-net.pt'))
    param_2 = resnet2.parameters()
    
    # test whether two models are the same 
    for p1, p2 in zip(param_1, param_2):
        if p1.data.ne(p2.data).sum() > 0:
            print('They are different.')
    print('They are same.')
    
    opened by KimMeen 3
  • the maximum batch size can only be set to 32

    When I run the code with a 2080ti GPU with 10G memory, the maximum batch size can only be set to 32. Is there any place in the code that takes up a lot of video memory?

    opened by cuixianheng 3
  • Pretrained network

    Hi, thanks for sharing the code and making it so easy to use. I see in the example you set resnet = models.resnet50(pretrained=True). Is this what is done in the paper? Shouldn't self-supervised-learned networks be trained from scratch?

    Thanks again, P.

    opened by pmorerio 3
  • Singleton Class Members

    Forgive me for my unfamiliarity with software design, but I'm wondering why it is necessary to write a singleton wrapper for projector and target_encoder. Is there any disadvantage of initializing them in __init__?

    opened by wentaoyuan 3
  • Increase EMA-parameter during training

    Hi, I noticed that the EMA-parameter (called beta in the code, τ in the paper) is not updated during training. In the paper they describe that they increase τ from the start value to 1 during training: "Specifically, we set τ = 1 − (1 − τbase) · (cos(πk/K) + 1)/2 with k the current training step and K the maximum number of training steps." This makes a huge difference to the validation loss at the end of the training.

    (Attached plots, not preserved here: without_tau_update, with_tau_update.)

    opened by Benjamin-Hansson 1
  • Why the loss is different from BYOL authors'

    I found that the loss here differs from the one described in the BYOL paper, which should be an L2 loss, and I didn't find an explanation. The loss in this repo is a cosine loss, and I just want to know why. BTW, thanks for this great repo!

    opened by Jing-XING 2
  • How to cluster/predict images?

    Hi, I have trained using the examples given with pytorch-lightning, but I couldn't find code for clustering images after training. How can I find which image falls in which cluster? Is there a predictor API? I want to do something like the attached image (not preserved here).

    opened by laxmimerit 1
  • BN layer weights and biases are not updated

    Thanks for sharing this repo, great work!

    I trained BYOL on my data and noticed that the weights and biases of the BN layers are not updated in the saved model. I used resnet18 without pretrained weights: resnet = models.resnet50(pretrained=False). After training for multiple epochs, the saved model has bn1.weight all equal to 1.0 and bn1.bias all equal to 0.0.

    Is this the expected behavior or am I missing something? Appreciate your response!

    opened by kregmi 1
  •  Warning: grad and param do not obey the gradient layout contract.

    Has anybody gotten a similar warning when using it?

    Warning: grad and param do not obey the gradient layout contract. This is not an error, but may impair performance. grad.sizes() = [512, 256, 1, 1], strides() = [256, 1, 1, 1] param.sizes() = [512, 256, 1, 1], strides() = [256, 1, 256, 256] (function operator())

    opened by mohaEs 3
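
A note on the check in "The saved network is same as the initial one?" above: in that snippet, resnet.parameters() returns a generator over the live parameters rather than a copy, so it sees the trained values, and the final print runs unconditionally. The sketch below shows one way to compare against a real snapshot taken before training. It uses a toy nn.Linear and a plain SGD step as stand-ins for the resnet and the BYOL loop, and changed_parameters is a hypothetical helper written for this example.

```python
import copy

import torch
from torch import nn


def changed_parameters(before: dict, model: nn.Module) -> list:
    """Hypothetical helper: names of state_dict entries that differ from the snapshot."""
    return [name for name, value in model.state_dict().items()
            if not torch.equal(before[name], value)]


net = nn.Linear(4, 2)                        # toy stand-in for the resnet
before = copy.deepcopy(net.state_dict())     # a real copy, taken BEFORE training

# one plain SGD step stands in for the BYOL training loop
opt = torch.optim.SGD(net.parameters(), lr=0.1)
loss = net(torch.ones(3, 4)).sum()
opt.zero_grad()
loss.backward()
opt.step()

changed = changed_parameters(before, net)
print(changed)  # ['weight', 'bias']
```

An empty list would mean that nothing was updated. Any non-empty list names the tensors that training actually changed.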
Releases(0.6.0)
Owner
Phil Wang
Working with Attention. It's all we need