PyTorch implementation of SwAV (Swapping Assignments between Views)

Overview

Unsupervised Learning of Visual Features by Contrasting Cluster Assignments

This code provides a PyTorch implementation and pretrained models for SwAV (Swapping Assignments between Views), as described in the paper Unsupervised Learning of Visual Features by Contrasting Cluster Assignments.

SwAV Illustration

SwAV is an efficient and simple method for pre-training convnets without using annotations. Like contrastive approaches, SwAV learns representations by comparing transformations of an image, but unlike contrastive methods, it does not require computing pairwise feature comparisons. This makes our framework more efficient, since it does not require a large memory bank or an auxiliary momentum network. Specifically, our method simultaneously clusters the data while enforcing consistency between cluster assignments produced for different augmentations (or “views”) of the same image, instead of comparing features directly. Simply put, we use a “swapped” prediction mechanism where we predict the cluster assignment of a view from the representation of another view. Our method can be trained with large and small batches and can scale to unlimited amounts of data.
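
The swapped prediction described above can be illustrated with a small pure-Python sketch. It is not the training code of this repository: the function names, the temperature of 0.1, the toy scores, and the averaging over the two directions are all assumptions made for illustration. Each view's prototype scores are turned into probabilities, and the cluster assignment (code) of the other view is used as the target.

```python
import math


def softmax(scores, temperature=0.1):
    """Turn prototype scores into probabilities (temperature is an assumed value)."""
    scaled = [s / temperature for s in scores]
    top = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(code, probs):
    """Cross-entropy between a target cluster assignment and a prediction."""
    return -sum(q * math.log(p) for q, p in zip(code, probs) if q > 0)


def swapped_loss(scores_s, scores_t, code_s, code_t, temperature=0.1):
    """Predict the code of each view from the scores of the other view."""
    p_s = softmax(scores_s, temperature)
    p_t = softmax(scores_t, temperature)
    # Swap: view t's code is the target for view s, and vice versa.
    return 0.5 * (cross_entropy(code_t, p_s) + cross_entropy(code_s, p_t))


# Hypothetical scores of two views against 3 prototypes.
scores_s, scores_t = [0.9, 0.1, -0.2], [0.8, 0.0, 0.1]

# Two views assigned to the same prototype give a small loss ...
agree = swapped_loss(scores_s, scores_t, [1, 0, 0], [1, 0, 0])
# ... while views assigned to different prototypes give a large one.
disagree = swapped_loss(scores_s, scores_t, [1, 0, 0], [0, 1, 0])
print(f"agree={agree:.4f}  disagree={disagree:.4f}")
```

Note that no feature of one view is ever compared directly with a feature of another image: the only comparison is between a view's prediction and the other view's code.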

Model Zoo

We release several models pre-trained with SwAV in the hope that other researchers might also benefit from replacing the ImageNet supervised network with a SwAV backbone. To load our best SwAV pre-trained ResNet-50 model, run:

import torch
model = torch.hub.load('facebookresearch/swav:main', 'resnet50')

We provide several baseline SwAV pre-trained models with the ResNet-50 architecture in torchvision format. We also provide models pre-trained with DeepCluster-v2 and SeLa-v2, obtained by applying improvements from the self-supervised community to DeepCluster and SeLa (see details in the appendix of our paper).

| method | epochs | batch-size | multi-crop | ImageNet top-1 acc. | url | args |
|---|---|---|---|---|---|---|
| SwAV | 800 | 4096 | 2x224 + 6x96 | 75.3 | model | script |
| SwAV | 400 | 4096 | 2x224 + 6x96 | 74.6 | model | script |
| SwAV | 200 | 4096 | 2x224 + 6x96 | 73.9 | model | script |
| SwAV | 100 | 4096 | 2x224 + 6x96 | 72.1 | model | script |
| SwAV | 200 | 256 | 2x224 + 6x96 | 72.7 | model | script |
| SwAV | 400 | 256 | 2x224 + 6x96 | 74.3 | model | script |
| SwAV | 400 | 4096 | 2x224 | 70.1 | model | script |
| DeepCluster-v2 | 800 | 4096 | 2x224 + 6x96 | 75.2 | model | script |
| DeepCluster-v2 | 400 | 4096 | 2x160 + 4x96 | 74.3 | model | script |
| DeepCluster-v2 | 400 | 4096 | 2x224 | 70.2 | model | script |
| SeLa-v2 | 400 | 4096 | 2x160 + 4x96 | 71.8 | model | - |
| SeLa-v2 | 400 | 4096 | 2x224 | 67.2 | model | - |

Larger architectures

We provide SwAV models with ResNet-50 networks whose width is multiplied by a factor of ×2, ×4, or ×5. To load the corresponding backbone, use:

import torch
rn50w2 = torch.hub.load('facebookresearch/swav:main', 'resnet50w2')
rn50w4 = torch.hub.load('facebookresearch/swav:main', 'resnet50w4')
rn50w5 = torch.hub.load('facebookresearch/swav:main', 'resnet50w5')

| network | parameters | epochs | ImageNet top-1 acc. | url | args |
|---|---|---|---|---|---|
| RN50-w2 | 94M | 400 | 77.3 | model | script |
| RN50-w4 | 375M | 400 | 77.9 | model | script |
| RN50-w5 | 586M | 400 | 78.5 | model | - |

Running times

We provide the running times for some of our runs:

| method | batch-size | multi-crop | scripts | time per epoch |
|---|---|---|---|---|
| SwAV | 4096 | 2x224 + 6x96 | * * * * | 3min40s |
| SwAV | 256 | 2x224 + 6x96 | * * | 52min10s |
| DeepCluster-v2 | 4096 | 2x160 + 4x96 | * | 3min13s |

Running SwAV unsupervised training

Requirements

Single-node training

SwAV is very simple to implement and experiment with. Our implementation consists of a main_swav.py file that imports the dataset definition from src/multicropdataset.py, the model architecture from src/resnet50.py, and some miscellaneous training utilities from src/utils.py.

For example, to train the SwAV baseline on a single node with 8 GPUs for 400 epochs, run:

python -m torch.distributed.launch --nproc_per_node=8 main_swav.py \
--data_path /path/to/imagenet/train \
--epochs 400 \
--base_lr 0.6 \
--final_lr 0.0006 \
--warmup_epochs 0 \
--batch_size 32 \
--size_crops 224 96 \
--nmb_crops 2 6 \
--min_scale_crops 0.14 0.05 \
--max_scale_crops 1. 0.14 \
--use_fp16 true \
--freeze_prototypes_niters 5005 \
--queue_length 3840 \
--epoch_queue_starts 15
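
To make the flags in the command above easier to read, here is a small sketch that derives a few quantities from them. It rests on assumptions that the command itself does not state: that --batch_size is the batch size per process (so 8 processes give the 256 listed in the Model Zoo table), that the per-crop flags are paired position by position, and that the learning rate follows a cosine decay from --base_lr to --final_lr. The helper name cosine_lr is hypothetical.

```python
import math

# Values copied from the single-node command above.
nproc_per_node = 8
batch_size_per_process = 32
size_crops = [224, 96]
nmb_crops = [2, 6]
min_scale_crops = [0.14, 0.05]
max_scale_crops = [1.0, 0.14]
base_lr, final_lr, epochs = 0.6, 0.0006, 400

# Assumption: --batch_size is per process, so the total batch is 8 * 32.
global_batch_size = nproc_per_node * batch_size_per_process

# Assumption: the i-th entry of each crop flag describes the same crop type.
crops = []
for size, count, lo, hi in zip(size_crops, nmb_crops, min_scale_crops, max_scale_crops):
    crops.extend([(size, lo, hi)] * count)


def cosine_lr(epoch):
    """Assumed cosine decay from base_lr to final_lr, with no warmup."""
    progress = epoch / epochs
    return final_lr + 0.5 * (base_lr - final_lr) * (1 + math.cos(math.pi * progress))


print("global batch size:", global_batch_size)
print("views per image:", len(crops))
for size, lo, hi in sorted(set(crops), reverse=True):
    print(f"  {crops.count((size, lo, hi))} crops of {size}px, scale in [{lo}, {hi}]")
for epoch in (0, 200, 400):
    print(f"epoch {epoch}: lr = {cosine_lr(epoch):.4f}")
```

Under these assumptions, each image yields 2 crops at 224px and 6 crops at 96px, which is the "2x224 + 6x96" multi-crop setting from the tables.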

Multi-node training

Distributed training is available via Slurm. We provide several SBATCH scripts to reproduce our SwAV models. For example, to train SwAV on 8 nodes and 64 GPUs with a batch size of 4096 for 800 epochs, run:

sbatch ./scripts/swav_800ep_pretrain.sh

Note that you might need to remove the copyright header from the sbatch file to launch it.

Setting the dist_url parameter: we refer the user to the PyTorch distributed documentation (env, file, or tcp) for setting the distributed initialization method (parameter dist_url) correctly. In the provided sbatch files, we use the tcp init method (see * for example).
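
As a rough illustration of the three initialization methods mentioned above, the sketch below builds the URL string for each one. The helper build_dist_url is hypothetical and not part of this repository, and the host name, port, and file path are placeholders. The URL shapes follow the PyTorch distributed documentation (env://, file:// followed by a path on a shared file system, tcp:// followed by host and port).

```python
def build_dist_url(method, host=None, port=None, path=None):
    """Return an init-method URL for the env, file, or tcp scheme (hypothetical helper)."""
    if method == "env":
        # Rank, world size, and master address are read from environment variables.
        return "env://"
    if method == "file":
        if not path:
            raise ValueError("the file method needs a path on a shared file system")
        return "file://" + path
    if method == "tcp":
        if not host or port is None:
            raise ValueError("the tcp method needs the host and port of the rank 0 process")
        return f"tcp://{host}:{port}"
    raise ValueError(f"unknown init method: {method!r}")


# Placeholder values, to be replaced with those of your cluster.
print(build_dist_url("env"))
print(build_dist_url("file", path="/shared/fs/swav_init"))
print(build_dist_url("tcp", host="node001", port=40000))
```

The resulting string is what the dist_url parameter expects. With tcp, every process must be able to reach the chosen host and port.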

Evaluating models

Evaluate models: Linear classification on ImageNet

To train a supervised linear classifier on frozen features/weights on a single node with 8 GPUs, run:

python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py \
--data_path /path/to/imagenet \
--pretrained /path/to/checkpoints/swav_800ep_pretrain.pth.tar

The resulting linear classifier can be downloaded here.

Evaluate models: Semi-supervised learning on ImageNet

To reproduce our results and fine-tune a network with 1% or 10% of ImageNet labels on a single node with 8 GPUs, run:

  • 10% labels:

python -m torch.distributed.launch --nproc_per_node=8 eval_semisup.py \
--data_path /path/to/imagenet \
--pretrained /path/to/checkpoints/swav_800ep_pretrain.pth.tar \
--labels_perc "10" \
--lr 0.01 \
--lr_last_layer 0.2

  • 1% labels:

python -m torch.distributed.launch --nproc_per_node=8 eval_semisup.py \
--data_path /path/to/imagenet \
--pretrained /path/to/checkpoints/swav_800ep_pretrain.pth.tar \
--labels_perc "1" \
--lr 0.02 \
--lr_last_layer 5

Evaluate models: Transferring to Detection with DETR

DETR is a recent object detection framework that reaches competitive performance with Faster R-CNN while being conceptually simpler and trainable end-to-end. We evaluate our SwAV ResNet-50 backbone on object detection on the COCO dataset using the DETR framework with full fine-tuning. Here are the instructions for reproducing our experiments:

  1. Install detr and prepare the COCO dataset following these instructions.

  2. Apply the changes highlighted in this gist to the detr backbone file in order to load the SwAV backbone instead of ImageNet supervised weights.

  3. Launch training from the detr repository with run_with_submitit.py.

python run_with_submitit.py --batch_size 4 --nodes 2 --lr_backbone 5e-5

Common Issues

For help or issues using SwAV, please submit a GitHub issue.

The loss does not decrease and is stuck at ln(nmb_prototypes) (8.006 for 3000 prototypes).

It sometimes happens that the system collapses at the beginning and does not manage to converge. We have found the following empirical workarounds to improve convergence and avoid collapsing at the beginning:

  • use a lower epsilon value (--epsilon 0.03 instead of the default 0.05)
  • carefully tune the hyper-parameters
  • freeze the prototypes during the first iterations (freeze_prototypes_niters argument)
  • switch to hard assignment
  • remove the batch-normalization layer from the projection head
  • reduce the difficulty of the problem (fewer crops or softer data augmentation)

We now analyze the collapsing problem: it happens when all examples are mapped to the same unique representation. In other words, the convnet always has the same output regardless of its input: it is a constant function. All examples get the same cluster assignment because they are identical, and the only valid assignment that satisfies the equipartition constraint in this case is the uniform assignment (1/K, where K is the number of prototypes). In turn, this uniform assignment is trivial to predict since it is the same for all examples, which is why the loss stays at ln(K). Reducing the epsilon parameter (see Eq. (3) of our paper) encourages the assignments Q to be sharper (i.e. less uniform), which strongly helps avoid collapse. However, using too low a value for epsilon may lead to numerical instability.
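
The analysis above can be checked numerically with a small pure-Python sketch of a Sinkhorn-Knopp style normalization. This is an illustrative re-implementation under stated assumptions, not the code of this repository: the function names, the number of iterations (3), and the toy scores are hypothetical. It shows that identical rows of scores yield the uniform assignment 1/K, and that a lower epsilon yields sharper (lower-entropy) assignments when the scores differ.

```python
import math


def sinkhorn(scores, epsilon=0.05, n_iters=3):
    """Soft assignments Q (B x K) with prototypes used equally often (sketch)."""
    B, K = len(scores), len(scores[0])
    # Subtracting the max keeps exp() from overflowing when epsilon is small.
    top = max(max(row) for row in scores)
    Q = [[math.exp((s - top) / epsilon) for s in row] for row in scores]
    for _ in range(n_iters):
        # Each prototype (column) receives a total mass of 1/K.
        col = [sum(Q[b][k] for b in range(B)) for k in range(K)]
        Q = [[Q[b][k] / (col[k] * K) for k in range(K)] for b in range(B)]
        # Each example (row) receives a total mass of 1/B.
        row = [sum(r) for r in Q]
        Q = [[q / (row[b] * B) for q in Q[b]] for b in range(B)]
    # Rescale so that every row is a probability distribution over prototypes.
    return [[q * B for q in r] for r in Q]


def mean_entropy(Q):
    """Average entropy of the rows of Q; lower means sharper assignments."""
    return sum(-sum(q * math.log(q) for q in r if q > 0) for r in Q) / len(Q)


# Collapse: every example has the same scores, so Q is uniform (1/K everywhere).
collapsed = sinkhorn([[0.3, 0.1, -0.2]] * 4)
print("collapsed row:", [round(q, 4) for q in collapsed[0]])

# Distinct examples: a lower epsilon gives sharper assignments.
scores = [[0.9, 0.1], [0.8, 0.2], [0.1, 0.9], [0.2, 0.7]]
for eps in (0.5, 0.05):
    print(f"epsilon={eps}: mean entropy = {mean_entropy(sinkhorn(scores, eps)):.6f}")

# Loss value at collapse for 3000 prototypes.
print(f"ln(3000) = {math.log(3000):.3f}")
```

The max subtraction in the sketch is one common way to limit the numerical instability mentioned above; with a very small epsilon, some entries can still underflow to zero.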

Training gets unstable when using the queue.

The queue is composed of feature representations from the previous batches. These lines discard the oldest feature representations from the queue and save the newest ones (i.e. from the current batch) through a round-robin mechanism. This way, the assignment problem is performed on more samples: without the queue, we assign B examples to nmb_prototypes clusters, where B is the total batch size, while with the queue we assign (B + queue_length) examples to nmb_prototypes clusters. This is especially useful when working with small batches because it improves the precision of the assignment.
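
The update described above can be sketched with a plain Python list standing in for the feature queue. This is a simplified illustration, not the code of this repository: the helper update_queue, the toy queue length of 6, and the string placeholders for features are assumptions. The newest batch is written at the front and the oldest entries fall off the end, so the queue length stays fixed.

```python
def update_queue(queue, batch_features):
    """Return a queue with the batch at the front and the oldest entries dropped."""
    bs = len(batch_features)
    if bs > len(queue):
        raise ValueError("the batch is larger than the queue")
    return list(batch_features) + queue[: len(queue) - bs]


# Toy example: a queue of 6 slots fed with batches of 2 features.
queue = [None] * 6
for step in range(4):
    batch = [f"feat_{step}_{i}" for i in range(2)]
    queue = update_queue(queue, batch)
    print(step, queue)

# With the values of the single-node command (8 processes x 32 images, queue of 3840):
total_batch_size = 8 * 32
queue_length = 3840
print("examples per assignment problem:", total_batch_size + queue_length)
```

After the fourth step, the features of step 0 have been discarded. With the values of the single-node command, the assignment problem covers 256 + 3840 = 4096 examples, which matches the batch size of the large-batch runs.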

If you start using the queue too early, or if the queue is too large, training can be considerably disturbed because the queue members are too inconsistent. After introducing the queue, the loss should be lower than it was without the queue. On the following loss curve (first 30 epochs of this script), we introduced the queue at epoch 15. We observe that the loss decreased further once the queue was introduced.

SwAV training loss batch_size=256 during the first 30 epochs

If the loss goes up when you introduce the queue and does not decrease afterwards, you should stop training and change the queue parameters. We recommend (i) using a smaller queue and (ii) starting the queue later in training.

License

See the LICENSE file for more details.

See also

PyTorch Lightning Bolts: Implementation by the Lightning team.

SwAV-TF: A TensorFlow re-implementation.

Citation

If you find this repository useful in your research, please cite:

@inproceedings{caron2020unsupervised,
  title={Unsupervised Learning of Visual Features by Contrasting Cluster Assignments},
  author={Caron, Mathilde and Misra, Ishan and Mairal, Julien and Goyal, Priya and Bojanowski, Piotr and Joulin, Armand},
  booktitle={Proceedings of Advances in Neural Information Processing Systems (NeurIPS)},
  year={2020}
}