Official implementation of the NeurIPS 2021 paper "Contextual Similarity Aggregation with Self-attention for Visual Re-ranking"

CSA: Contextual Similarity Aggregation with Self-attention for Visual Re-ranking

PyTorch training code for CSA (Contextual Similarity Aggregation). We propose a visual re-ranking method based on contextual similarity aggregation with a Transformer, obtaining 80.3 mAP on ROxf under the Medium evaluation protocol. Inference takes 50 lines of PyTorch.

What it is. Unlike traditional visual re-ranking techniques, CSA represents each image by its similarities to a set of anchor images; this representation is called the affinity feature. The method consists of a Transformer encoder architecture trained with two losses: a contrastive loss that forces relevant images to have larger cosine similarity (and irrelevant images to have smaller), and an MSE loss that preserves the information of the original affinity features. Given the ranking list returned by the first-round retrieval, CSA first chooses the top-L images in the list as the anchor images and calculates the affinity features of the top-K candidates, then dynamically refines the affinity features of the different candidates in parallel. Because of this parallelism, CSA is fast and efficient.
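The affinity-feature idea described above can be illustrated with a minimal NumPy sketch. It is not the repository's implementation: the function names, the use of cosine similarity on L2-normalized descriptors, and the choice to include the query as the first row are assumptions made for illustration, and the Transformer refinement step is left out entirely.

```python
import numpy as np

def l2_normalize(x, eps=1e-12):
    """Scale vectors along the last axis to unit L2 norm."""
    norm = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norm, eps)

def affinity_features(query, database, top_k=8, top_l=4):
    """Build affinity features for one query (hypothetical helper).

    query:    (d,) global descriptor of the query image
    database: (n, d) global descriptors of the database images
    Returns the ids of the top-K candidates and a (K + 1, L) matrix whose
    row 0 describes the query and rows 1..K describe the candidates.
    """
    database = l2_normalize(database)
    query = l2_normalize(query)
    sims = database @ query              # first-round cosine similarities
    ranking = np.argsort(-sims)          # best match first
    candidates = ranking[:top_k]         # images to be re-ranked
    anchors = ranking[:top_l]            # top-L images act as anchors
    items = np.vstack([query[None, :], database[candidates]])
    feats = items @ database[anchors].T  # similarity of each item to each anchor
    return candidates, feats

def rerank(candidates, feats):
    """Re-order candidates by cosine similarity between affinity features.

    In CSA the features would first be refined by the Transformer encoder;
    here they are used unrefined, purely to show the data flow.
    """
    feats = l2_normalize(feats)
    scores = feats[1:] @ feats[0]
    return candidates[np.argsort(-scores, kind="stable")]
```

With K candidates and L anchors, every image is described by an L-dimensional vector regardless of the dimensionality of the original descriptor, so all K + 1 rows can be processed together as one sequence.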

About the code. CSA is very simple to implement and experiment with, and we provide a Notebook showing how to do inference with CSA in only a few lines of PyTorch code. Training code follows this idea - it is not a library, but simply a train.py importing model and criterion definitions with standard training loops.

mAP performance of the proposed model

We provide results of baseline CSA and CSA trained with data augmentation. mAP is computed with the Medium and Hard evaluation protocols. Models will come soon.

Requirements

  • Python 3
  • PyTorch tested on 1.7.1+, torchvision 0.8.2+
  • numpy
  • matplotlib

Usage - Visual Re-ranking

There are no extra compiled components in CSA and package dependencies are minimal, so the code is very simple to use. We provide instructions on how to install the dependencies via conda. Install PyTorch 1.7.1+ and torchvision 0.8.2+:

conda install -c pytorch pytorch torchvision

Data preparation

Before going further, please check out Filip Radenovic's great repository on image retrieval. We use his code and model to extract features for training images. If you use this code in your research, please also cite their work and see their license.

Download and extract rSfm120k train and val images with annotations from http://cmp.felk.cvut.cz/cnnimageretrieval/.

Download the ROxf and RPar datasets with annotations. Prepare features for the testing and training images with Filip Radenovic's model and code. We expect the directory structure to be the following:

path/to/data/
  ├─ annotations # annotation pkl files
  │   ├─ retrieval-SfM-120k.pkl
  │   ├─ roxford5k
  │   │   ├─ gnd_roxford5k.mat
  │   │   └─ gnd_roxford5k.pkl
  │   └─ rparis6k
  │       ├─ gnd_rparis6k.mat
  │       └─ gnd_rparis6k.pkl
  ├─ test # test features
  │   ├─ r1m
  │   │   ├─ gl18-tl-resnet101-gem-w.pkl
  │   │   └─ rSfM120k-tl-resnet101-gem-w.pkl
  │   ├─ roxford5k
  │   │   ├─ gl18-tl-resnet101-gem-w.pkl
  │   │   └─ rSfM120k-tl-resnet101-gem-w.pkl
  │   └─ rparis6k
  │       ├─ gl18-tl-resnet101-gem-w.pkl
  │       └─ rSfM120k-tl-resnet101-gem-w.pkl
  └─ train # train features
      ├─ gl18-tl-resnet50-gem-w.pkl
      ├─ gl18-tl-resnet101-gem-w.pkl
      └─ gl18-tl-resnet152-gem-w.pkl
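
Before starting a long run it can help to confirm that the layout above is in place. The following is a small, hypothetical helper (it is not part of the repository) that lists which of the expected files are missing under a given data root; the file names are taken from the tree above.

```python
from pathlib import Path

# Relative paths taken from the directory tree in this README.
EXPECTED_FILES = [
    "annotations/retrieval-SfM-120k.pkl",
    *[f"annotations/{name}/gnd_{name}.{ext}"
      for name in ("roxford5k", "rparis6k") for ext in ("mat", "pkl")],
    *[f"test/{name}/{prefix}-tl-resnet101-gem-w.pkl"
      for name in ("r1m", "roxford5k", "rparis6k")
      for prefix in ("gl18", "rSfM120k")],
    *[f"train/gl18-tl-resnet{depth}-gem-w.pkl" for depth in (50, 101, 152)],
]

def missing_files(data_root, expected=EXPECTED_FILES):
    """Return the expected relative paths that do not exist under data_root."""
    root = Path(data_root)
    return [rel for rel in expected if not (root / rel).is_file()]

# Example: report anything missing under the placeholder root used above.
for rel in missing_files("path/to/data"):
    print("missing:", rel)
```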

Training

To train baseline CSA on a single node with 4 GPUs for 100 epochs, run:

sh experiment_rSfm120k.sh

A single epoch takes 10 minutes, so 100-epoch training takes around 17 hours on a single machine with four 2080Ti cards. To ease reproduction of our results, we provide results and training logs for a 200-epoch schedule (34 hours on a single machine).
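
The time estimates above follow directly from the per-epoch cost. As a quick check of the arithmetic (the helper name is ours, and the 10-minute figure applies only to the hardware described above):

```python
def training_hours(epochs, minutes_per_epoch=10):
    """Total wall-clock training time in hours."""
    return epochs * minutes_per_epoch / 60

print(round(training_hours(100), 1))  # 16.7, i.e. around 17 hours
print(round(training_hours(200), 1))  # 33.3, i.e. roughly 34 hours
```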

We train CSA with SGD, setting the learning rate in the transformer to 0.1. The transformer is trained with a dropout of 0.1, and the whole model is trained with gradient clipping at 1.0. To train CSA with data augmentation on a single node with 4 GPUs for 100 epochs, run:

sh experiment_augrSfm120k.sh
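
To make the optimizer settings concrete, here is a NumPy sketch of one plain SGD update with learning rate 0.1 preceded by global-norm gradient clipping at 1.0. It only illustrates the arithmetic; we assume the clipping is by global norm (as PyTorch's `torch.nn.utils.clip_grad_norm_` does), and details such as momentum or weight decay in the actual training script are not covered.

```python
import numpy as np

def clip_grad_norm(grads, max_norm=1.0, eps=1e-6):
    """Rescale all gradients so that their joint L2 norm is at most max_norm."""
    total_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm

def sgd_step(params, grads, lr=0.1, max_norm=1.0):
    """One SGD update applied after gradient clipping."""
    clipped, _ = clip_grad_norm(grads, max_norm)
    return [p - lr * g for p, g in zip(params, clipped)]

# A gradient of norm 5 is scaled down to norm 1 before the update.
params = [np.zeros(2)]
grads = [np.array([3.0, 4.0])]
print(sgd_step(params, grads)[0])
```

Gradients whose joint norm is already below the threshold pass through unchanged, so clipping only affects the occasional large step.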

Evaluation

To evaluate CSA on ROxf and RPar with a single GPU, run:

sh test.sh

This produces output like the following:

>> Test Dataset: roxford5k *** fist-stage >>
>> gl18-tl-resnet101-gem-w: mAP Medium: 67.3, Hard: 44.24
>> gl18-tl-resnet101-gem-w: mP@k[1, 5, 10] Medium: [95.71 90.29 84.57], Hard: [87.14 69.71 59.86]

>> Test Dataset: roxford5k *** rerank-topk1024 >>
>> gl18-tl-resnet101-gem-w: mAP Medium: 77.92, Hard: 58.43
>> gl18-tl-resnet101-gem-w: mP@k[1, 5, 10] Medium: [94.29 93.14 89.71], Hard: [87.14 83.43 73.14]

>> Test Dataset: rparis6k *** fist-stage >>
>> gl18-tl-resnet101-gem-w: mAP Medium: 80.57, Hard: 61.46
>> gl18-tl-resnet101-gem-w: mP@k[1, 5, 10] Medium: [100.    98.    96.86], Hard: [97.14 93.14 90.57]

>> Test Dataset: rparis6k *** query-rerank-1024 >>
>> gl18-tl-resnet101-gem-w: mAP Medium: 87.2, Hard: 74.41
>> gl18-tl-resnet101-gem-w: mP@k[1, 5, 10] Medium: [100.    97.14  96.57], Hard: [95.71 92.86 90.14]
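
For readers unfamiliar with the reported metrics, the sketch below computes average precision and precision at rank k for a single query in plain Python. It uses the textbook definition and simply skips "junk" images, which is how we understand the Medium and Hard protocols to treat ignored images; the benchmark's official evaluation code may differ in details such as interpolation, so this is not expected to reproduce the numbers above exactly. mAP is the mean of the per-query values, reported as a percentage.

```python
def average_precision(ranked_ids, positives, junk=()):
    """Textbook AP for one query; junk images are removed from the ranking."""
    positives, junk = set(positives), set(junk)
    hits, precision_sum, rank = 0, 0.0, 0
    for item in ranked_ids:
        if item in junk:
            continue                 # ignored images do not occupy a rank
        rank += 1
        if item in positives:
            hits += 1
            precision_sum += hits / rank
    return precision_sum / len(positives) if positives else 0.0

def precision_at_k(ranked_ids, positives, k, junk=()):
    """Fraction of the first k non-junk results that are positives."""
    positives, junk = set(positives), set(junk)
    kept = [item for item in ranked_ids if item not in junk][:k]
    return sum(item in positives for item in kept) / k if k else 0.0

ranking = [1, 2, 3, 4]
print(round(average_precision(ranking, positives=[1, 3]), 4))  # 0.8333
print(precision_at_k(ranking, positives=[1, 3], k=2))          # 0.5
```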

Qualitative examples

Selected qualitative examples of our re-ranking method. The figure shows the top-10 results. It is divided into four groups, each consisting of the result of the initial retrieval and the result of our re-ranking method. The first two groups are successful cases and the other two groups are failure cases. The images on the left with orange bounding boxes are the queries. Green bounding boxes denote true positives and red bounding boxes denote false positives.

License

CSA is released under the MIT license. Please see the LICENSE file for more information.

Owner
Hui Wu
Department of Electronic Engineering and Information Science University of Science and Technology of China