PyTorch implementation for the paper Visual Representation Learning with Self-Supervised Attention for Low-Label High-Data Regime

Overview

Created by Prarthana Bhattacharyya.

Disclaimer: This is not an official product and is meant to be a proof-of-concept and for academic/educational use only.

This repository contains the PyTorch implementation for the paper Visual Representation Learning with Self-Supervised Attention for Low-Label High-Data Regime, to be presented at ICASSP-2022.

Self-supervision has shown outstanding results for natural language processing and, more recently, for image recognition. At the same time, vision transformers and their variants have emerged as a promising and scalable alternative to convolutions on various computer vision tasks. In this paper, we are the first to ask whether self-supervised vision transformers (SSL-ViTs) can be adapted to two important computer vision tasks in the low-label, high-data regime: few-shot image classification and zero-shot image retrieval. The motivation is to reduce the number of manual annotations required to train a visual embedder, and to produce generalizable, semantically meaningful and robust embeddings.


Results

  • [Figure: SSL-ViT + few-shot image classification]
  • [Figure: qualitative analysis of the base classes chosen by a supervised CNN and by SSL-ViT for few-shot distribution calibration]
  • [Figure: SSL-ViT + zero-shot image retrieval]

Pretraining Self-Supervised ViT

  • Run DINO with a ViT-Small network on a single node with 4 GPUs for 100 epochs using the following command:
cd dino/
python -m torch.distributed.launch --nproc_per_node=4 main_dino.py --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
  • For mini-ImageNet pretraining, we use the classes listed in ssl-vit-fewshot/data/ImageNetSSLTrainingSplit_mini.txt.
  • For tiered-ImageNet pretraining, we use the classes listed in ssl-vit-fewshot/data/ImageNetSSLTrainingSplit_tiered.txt.
  • For CUB-200, Cars-196 and SOP, we use the pretrained DINO ViT-S/16 model from torch.hub:
import torch
vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
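The split files listed above restrict ImageNet to the pretraining classes. The sketch below shows one possible way to apply them and is not necessarily how the authors did it. It assumes that each non-blank line of a split file starts with an ImageNet class-folder name (a WordNet ID such as n01532829), and it builds a directory of symbolic links that keeps the ImageFolder layout (root/<class>/<images>) used by main_dino.py's --data_path. The helper names load_split_classes and make_subset_dir are hypothetical.

```python
from pathlib import Path


def load_split_classes(split_file):
    """Return the class IDs listed in a split file.

    Assumption: one class folder name per line; only the first
    whitespace-separated token of each non-blank line is used.
    """
    with open(split_file) as f:
        return [line.split()[0] for line in f if line.strip()]


def make_subset_dir(train_root, split_file, out_dir):
    """Symlink the selected ImageNet class folders into out_dir.

    The output keeps the ImageFolder layout (out_dir/<class>/<images>).
    Returns the sorted list of linked class names and raises if any class
    from the split file is missing under train_root.
    """
    wanted = set(load_split_classes(split_file))
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    linked = []
    for class_dir in sorted(Path(train_root).iterdir()):
        if class_dir.is_dir() and class_dir.name in wanted:
            link = out / class_dir.name
            if not link.exists():
                # Link instead of copy so no image data is duplicated.
                link.symlink_to(class_dir.resolve(), target_is_directory=True)
            linked.append(class_dir.name)
    missing = wanted.difference(linked)
    if missing:
        raise FileNotFoundError(
            f"{len(missing)} split classes not found under {train_root}")
    return linked
```

For example, make_subset_dir("/path/to/imagenet/train", "ssl-vit-fewshot/data/ImageNetSSLTrainingSplit_mini.txt", "/path/to/imagenet_mini_ssl") followed by passing /path/to/imagenet_mini_ssl as --data_path. Symlinks avoid copying the images; on Windows, creating them may require extra privileges.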


Dataset Preparation

Please follow the instructions in FRN (few-shot image classification) and RevisitDML (zero-shot image retrieval) to download the datasets, and put them in the ssl-vit-fewshot/data and DIML/data folders, respectively.

Training and Evaluation for few-shot image classification

  • The first step is to extract features for base and novel classes using the pretrained SSL-ViT.
  • get_dino_miniimagenet_feats.ipynb extracts SSL-ViT features for the base and novel classes.
  • Change the data_path setting to use CUB or tiered-ImageNet instead of mini-ImageNet.
  • The SSL-ViT checkpoints for the various datasets are provided below (note: they were trained without labels). We also provide the extracted features, which need to be stored in ssl-vit-fewshot/dino_features_data/.
arch      dataset          checkpoint                      extracted-train  extracted-test
ViT-S/16  mini-ImageNet    mini_imagenet_checkpoint.pth    train.p          test.p
ViT-S/16  tiered-ImageNet  tiered_imagenet_checkpoint.pth  train.p          test.p
ViT-S/16  CUB              cub_checkpoint.pth              train.p          test.p
  • For n-way-k-shot evaluation, we provide miniimagenet_evaluate_dinoDC.ipynb.
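The evaluation notebook's name and the Results section both refer to few-shot distribution calibration (DC). The NumPy sketch below illustrates the core DC step as described by Yang et al. (ICLR 2021, "Free Lunch for Few-shot Learning: Distribution Calibration"); it is not the notebook's code. Each Tukey-transformed support feature borrows the mean and covariance of its k nearest base classes, and extra training features are sampled from the resulting Gaussian. Assumptions: the base features are a dict mapping class label to an (n, d) array (one possible layout for train.p), negative feature values are clipped to zero before the power transform, and k, alpha and the number of samples are illustrative defaults. All function names are hypothetical.

```python
import numpy as np


def tukey_transform(x, lam=0.5):
    """Tukey's ladder-of-powers transform, x -> x**lam (lam > 0 assumed).

    Assumption: negative values are clipped to 0 so the power is defined.
    """
    return np.power(np.clip(x, 0.0, None), lam)


def base_statistics(base_feats):
    """Per-class mean and covariance of the base-class features.

    base_feats: dict mapping class label -> array of shape (n_i, d), n_i >= 2
    (hypothetical layout for the extracted base-class features).
    """
    labels = sorted(base_feats)
    means = np.stack([base_feats[c].mean(axis=0) for c in labels])
    covs = np.stack([np.cov(base_feats[c], rowvar=False) for c in labels])
    return means, covs


def calibrate(x, base_means, base_covs, k=2, alpha=0.21):
    """Calibrated Gaussian (mean, cov) for one transformed support feature x.

    The mean averages x with its k nearest base-class means; the covariance
    averages the matching base covariances and adds alpha to every entry.
    """
    dists = np.linalg.norm(base_means - x, axis=1)
    nearest = np.argsort(dists)[:k]
    mean = np.vstack([base_means[nearest], x[None, :]]).mean(axis=0)
    cov = base_covs[nearest].mean(axis=0) + alpha
    return mean, cov, nearest


def augment_support(support_x, support_y, base_means, base_covs,
                    n_per_example=150, k=2, alpha=0.21, seed=0):
    """Support features plus n_per_example samples drawn per support example."""
    rng = np.random.default_rng(seed)
    xs, ys = [support_x], [np.asarray(support_y)]
    for x, y in zip(support_x, support_y):
        mean, cov, _ = calibrate(x, base_means, base_covs, k, alpha)
        xs.append(rng.multivariate_normal(mean, cov, size=n_per_example))
        ys.append(np.full(n_per_example, y))
    return np.concatenate(xs), np.concatenate(ys)
```

In an N-way K-shot episode, one would apply tukey_transform to the support and query features, call augment_support on the support set, and train a classifier on the augmented features before scoring the queries (the DC paper uses logistic regression for this step).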

Training and Evaluation for zero-shot image retrieval

  • To train the baseline CNN models, run the scripts in DIML/scripts/baselines. The checkpoints are saved in the Training_Results folder. For example:
cd DIML/
CUDA_VISIBLE_DEVICES=0 ./scripts/baselines/cub_runs.sh
  • To train the supervised ViT and self-supervised ViT:
cp -r ssl-vit-retrieval/architectures/* DIML/ssl-vit-retrieval/architectures/
CUDA_VISIBLE_DEVICES=0 ./scripts/baselines/cub_runs.sh --arch vits
CUDA_VISIBLE_DEVICES=0 ./scripts/baselines/cub_runs.sh --arch dino
  • To test the models, first edit the checkpoint paths in test_diml.py, then run
CUDA_VISIBLE_DEVICES=0 ./scripts/diml/test_diml.sh cub200
  • The trained SSL-ViT retrieval checkpoints are listed below:
dataset   loss              SSL-ViT checkpoint
CUB       Margin            cub_ssl-vit-margin.pth
CUB       Proxy-NCA         cub_ssl-vit-proxynca.pth
CUB       Multi-Similarity  cub_ssl-vit-ms.pth
Cars-196  Margin            cars_ssl-vit-margin.pth
Cars-196  Proxy-NCA         cars_ssl-vit-proxynca.pth
Cars-196  Multi-Similarity  cars_ssl-vit-ms.pth
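Zero-shot retrieval on CUB-200 and Cars-196 is commonly reported as Recall@K. As an illustration of that metric only (this section does not describe exactly what test_diml.py computes), the sketch below uses every test image once as a query against all the others and ranks neighbours by cosine similarity. The function name recall_at_k and the choice of cosine similarity are assumptions.

```python
import numpy as np


def recall_at_k(embeddings, labels, ks=(1, 2, 4, 8)):
    """Recall@K where every image is used once as a query against all others.

    A query is a hit at K if at least one of its K nearest neighbours (by
    cosine similarity, excluding the query itself) has the same class label.
    Assumes no embedding is the zero vector.
    """
    emb = np.asarray(embeddings, dtype=float)
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    sims = emb @ emb.T
    np.fill_diagonal(sims, -np.inf)             # sort each query to the end
    order = np.argsort(-sims, axis=1)[:, :-1]   # neighbours, self dropped
    labels = np.asarray(labels)
    hits = labels[order] == labels[:, None]     # (n, n-1) label matches
    return {k: float(hits[:, :k].any(axis=1).mean()) for k in ks}
```

Because the classes at test time are disjoint from the training classes, no classifier is involved: the metric depends only on how well same-class images cluster in the embedding space.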

Acknowledgement

The code is based on:
