PyTorch Implementation for Deep Metric Learning Pipelines

Easily Extendable Basic Deep Metric Learning Pipeline

Karsten Roth, Biagio Brattoli

When using this repo in any academic work, please provide a reference to:

@misc{roth2020revisiting,
    title={Revisiting Training Strategies and Generalization Performance in Deep Metric Learning},
    author={Karsten Roth and Timo Milbich and Samarth Sinha and Prateek Gupta and Björn Ommer and Joseph Paul Cohen},
    year={2020},
    eprint={2002.08473},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

Based on an extended version of this repo, we have created a thorough comparison and evaluation of Deep Metric Learning:

https://arxiv.org/abs/2002.08473

The newly released code can be found here: https://github.com/Confusezius/Revisiting_Deep_Metric_Learning_PyTorch

It contains more criteria, miners, metrics and logging options!


For usage, see section 3; for results, see section 4.

1. Overview

This repository contains a full, easily extendable pipeline to test and implement current and new deep metric learning methods. For referencing and testing, this repo contains implementations/dataloaders for:

Loss Functions

Sampling Methods

Datasets

Architectures

NOTE: PKU Vehicle-ID is marked as optional because the dataset requires special licensing and cannot be downloaded directly. However, if the dataset is available in the structure shown in section 2.2, it can be used directly.


1.1 Related Repos:


2. Repo & Dataset Structure

2.1 Repo Structure

Repository
│   ### General Files
│   README.md
│   requirements.txt
│   installer.sh
│
│   ### Main Scripts
│   Standard_Training.py   (main training script)
│   losses.py              (collection of loss and sampling impl.)
│   datasets.py            (dataloaders for all datasets)
│
│   ### Utility Scripts
│   auxiliaries.py         (set of useful utilities)
│   evaluate.py            (set of evaluation functions)
│
│   ### Network Scripts
│   netlib.py              (contains impl. for ResNet50)
│   googlenet.py           (contains impl. for GoogLeNet)
│
│
└───Training Results (generated during Training)
│    │   e.g. cub200/Training_Run_Name
│    │   e.g. cars196/Training_Run_Name
│
│
└───Datasets (should be added, if one does not want to set paths)
│    │   cub200
│    │   cars196
│    │   online_products
│    │   in-shop
│    │   vehicle_id

2.2 Dataset Structures

CUB200-2011/CARS196

cub200/cars196
└───images
|    └───001.Black_footed_Albatross
|           │   Black_Footed_Albatross_0001_796111
|           │   ...
|    ...
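
As a rough illustration of how this folder layout can be consumed, the following sketch maps every class folder under images/ to its image files. It is not the loader from datasets.py; the function name collect_image_paths is hypothetical and only the one-folder-per-class layout shown above is assumed.

```python
import os
from collections import defaultdict


def collect_image_paths(image_root):
    """Map each class folder name under `image_root` to a sorted list of file paths."""
    class_to_paths = defaultdict(list)
    for class_name in sorted(os.listdir(image_root)):
        class_dir = os.path.join(image_root, class_name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files lying next to the class folders
        for file_name in sorted(os.listdir(class_dir)):
            class_to_paths[class_name].append(os.path.join(class_dir, file_name))
    # Class folders without any files do not appear in the result.
    return dict(class_to_paths)
```

For example, collect_image_paths("cub200/images") would return one entry per class folder, assuming the dataset was placed at that relative path.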

Online Products

online_products
└───images
|    └───bicycle_final
|           │   111085122871_0.jpg
|    ...
|
└───Info_Files
|    │   bicycle.txt
|    │   ...

In-Shop Clothes

in-shop
└─img
|    └─MEN
|         └─Denim
|               └─id_00000080
|                  │   01_1_front.jpg
|                  │   ...
|               ...
|         ...
|    ...
|
└─Eval
|  │   list_eval_partition.txt

PKU Vehicle ID

vehicle_id
└───image
|     │   <img>.jpg
|     |   ...
|     
└───train_test_split
|     |   test_list_800.txt
|     |   ...

3. Using the Pipeline

[1.] Requirements

The pipeline is built around Python 3 (e.g. installed via Miniconda, https://conda.io/miniconda.html) and PyTorch 1.0.0/1. It has been tested with CUDA 8 and CUDA 9.

To install the required libraries, either directly check requirements.txt or create a conda environment:

conda create -n <Env_Name> python=3.6

Activate it

conda activate <Env_Name>

and run

bash installer.sh

Note that k-means and nearest-neighbour computations use the faiss library, which allows these computations to be moved to the GPU if more speed is needed. In most cases, however, faiss is fast enough that computing the evaluation metrics is not a bottleneck.
NOTE: To use standard sklearn instead of faiss, simply import auxiliaries_nofaiss.py in place of auxiliaries.py.

[2.] Exemplary Runs

The main script is Standard_Training.py. When run without input arguments, it trains ResNet50 on CUB200-2011 with margin loss and distance sampling.
Otherwise, the following flags suffice to train with different losses, sampling methods, architectures and datasets:

python Standard_Training.py --dataset <dataset> --loss <loss> --sampling <sampling> --arch <arch> --k_vals <k_vals> --embed_dim <embed_dim>

The flags accept the following values:

  • <dataset> <- cub200, cars196, online_products, in-shop, vehicle_id
  • <loss> <- marginloss, triplet, npair, proxynca
  • <sampling> <- distance, semihard, random, npair
  • <arch> <- resnet50, googlenet
  • <k_vals> <- List of Recall @ k values to evaluate on, e.g. 1 2 4 8
  • <embed_dim> <- Network embedding dimension. Default: 128 for ResNet50, 512 for GoogLeNet.
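
To make the <k_vals> flag more concrete, here is a minimal sketch of how Recall @ k is commonly computed: a query counts as a hit if at least one of its k nearest neighbours (excluding itself) shares its class. This is a brute-force NumPy illustration under that standard definition, not the faiss-based code in auxiliaries.py; the function name recall_at_k is hypothetical.

```python
import numpy as np


def recall_at_k(embeddings, labels, k_vals):
    """Fraction of queries with a same-class sample among their k nearest neighbours."""
    embeddings = np.asarray(embeddings, dtype=np.float64)
    labels = np.asarray(labels)
    # Pairwise squared Euclidean distances between all embeddings.
    sq_norms = (embeddings ** 2).sum(axis=1)
    dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * embeddings @ embeddings.T
    np.fill_diagonal(dists, np.inf)  # a sample must not retrieve itself
    max_k = max(k_vals)
    # Indices of the max_k nearest neighbours per query, closest first.
    neighbours = np.argsort(dists, axis=1)[:, :max_k]
    hits = labels[neighbours] == labels[:, None]
    return {k: float(hits[:, :k].any(axis=1).mean()) for k in k_vals}
```

Calling recall_at_k(embeddings, labels, [1, 2, 4, 8]) returns a dictionary with one value per k. The brute-force distance matrix grows quadratically with the number of samples, so this is only practical for small test sets.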

For all other training-specific arguments (e.g. batch size, number of training epochs, ...), simply refer to the input arguments in Standard_Training.py.

NOTE: To use a different learning rate for the final linear embedding layer, set the flag --fc_lr_mul to a value other than zero (e.g. 10, as is done in various implementations).
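
The idea behind a learning-rate multiplier for the embedding layer can be sketched with PyTorch parameter groups. This is an illustration only, not the code in Standard_Training.py: the helper build_param_groups, the layer name last_linear, the toy model and the base learning rate of 1e-5 are all assumptions made for this example.

```python
import torch


def build_param_groups(model, embed_layer_name, lr, fc_lr_mul):
    """Split parameters so the embedding layer can train with lr * fc_lr_mul."""
    if fc_lr_mul == 0:
        # A multiplier of zero means: one learning rate for everything.
        return [{"params": list(model.parameters()), "lr": lr}]
    embed_params, other_params = [], []
    for name, param in model.named_parameters():
        if name.startswith(embed_layer_name + "."):
            embed_params.append(param)
        else:
            other_params.append(param)
    return [
        {"params": other_params, "lr": lr},
        {"params": embed_params, "lr": lr * fc_lr_mul},
    ]


# Toy stand-in for a backbone followed by a linear embedding layer (hypothetical names).
model = torch.nn.Sequential()
model.add_module("backbone", torch.nn.Linear(16, 8))
model.add_module("last_linear", torch.nn.Linear(8, 4))

optimizer = torch.optim.Adam(
    build_param_groups(model, "last_linear", lr=1e-5, fc_lr_mul=10)
)
```

With these values, the backbone is updated with a learning rate of 1e-5 and the embedding layer with 1e-4.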

Finally, to decide the GPU to use and the name of the training folder in which network weights, sample recoveries and metrics are stored, set:

python Standard_Training.py --gpu <gpu_id> --savename <name_of_training_run>

If --savename is not set, a default name based on the starting date will be chosen.

To use standard parameters and get close to literature results (the exact values depend on seeds and the overall training schedule), refer to sample_training_runs.sh, which contains a list of executable one-liners.

[3.] Implementation Notes regarding Extendability:

To extend or test other sampling or loss methods, simply do:

For Batch-based Sampling:
In losses.py, add the sampling method, which should act on a batch (and the resp. set of labels), e.g.:

def new_sampling(self, batch, label, **additional_parameters): ...

If this function needs to run with existing losses, it should return a list of tuples containing indices with respect to the batch, e.g. for sampling methods returning triplets:

return [(anchor_idx, positive_idx, negative_idx) for anchor_idx, positive_idx, negative_idx in zip(anchor_idxs, positive_idxs, negative_idxs)]

Also, don't forget to add a handle in Sampler.__init__().
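
A minimal sketch of such a sampling method, assuming plain random triplet sampling, could look as follows. It is written as a free function for brevity (in losses.py it would be a method taking self), and the name random_triplet_sampling and the seed parameter are hypothetical.

```python
import random


def random_triplet_sampling(batch, labels, seed=None):
    """Return one (anchor, positive, negative) index triplet per usable anchor.

    `batch` is unused here and only kept to mirror the expected signature.
    `labels` is assumed to be a sequence of integer class labels.
    """
    rng = random.Random(seed)
    labels = list(labels)
    triplets = []
    for anchor_idx, anchor_label in enumerate(labels):
        positives = [i for i, l in enumerate(labels) if l == anchor_label and i != anchor_idx]
        negatives = [i for i, l in enumerate(labels) if l != anchor_label]
        if not positives or not negatives:
            continue  # this anchor has no valid partner in the batch
        triplets.append((anchor_idx, rng.choice(positives), rng.choice(negatives)))
    return triplets
```

All returned indices refer to positions in the batch, which is the format described above for triplet-based losses.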

For Data-specific Sampling:
To influence the data samples used to generate the batches, in datasets.py edit BaseTripletDataset.

For New Loss Functions:
Simply add a new class inheriting from torch.nn.Module. Refer to other loss variants to see how to do so. In general, include an instance of the Sampler-class, which will provide sampled data tuples during a forward()-pass, by calling self.sampler_instance.give(batch, labels, **additional_parameters).
Finally, include the loss function in the loss_select()-function. Parameters can be passed through the dictionary-notation (see other examples) and if learnable parameters are added, include them in the to_optim-list.
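
The structure described above can be sketched as follows. This is an illustrative triplet-style loss, not one of the implementations in losses.py: AllTripletsSampler is a hypothetical stand-in for the repo's Sampler class that only mimics the give() call, and the margin value and the use of squared Euclidean distances are assumptions.

```python
import torch


class AllTripletsSampler:
    """Hypothetical stand-in for the Sampler class; only mimics give()."""

    def give(self, batch, labels):
        labels = labels.tolist()
        n = len(labels)
        return [
            (a, p, neg)
            for a in range(n) for p in range(n) for neg in range(n)
            if a != p and labels[a] == labels[p] and labels[a] != labels[neg]
        ]


class TripletLossSketch(torch.nn.Module):
    def __init__(self, margin=0.2, sampler_instance=None):
        super().__init__()
        self.margin = margin
        self.sampler_instance = sampler_instance or AllTripletsSampler()

    def forward(self, batch, labels):
        # The sampler returns index tuples with respect to the batch.
        triplets = self.sampler_instance.give(batch, labels)
        if not triplets:
            return batch.sum() * 0.0  # zero loss that is still connected to the graph
        anchors, positives, negatives = (
            torch.tensor(idx, device=batch.device) for idx in zip(*triplets)
        )
        # Squared Euclidean distances for anchor-positive and anchor-negative pairs.
        d_pos = (batch[anchors] - batch[positives]).pow(2).sum(dim=1)
        d_neg = (batch[anchors] - batch[negatives]).pow(2).sum(dim=1)
        return torch.nn.functional.relu(d_pos - d_neg + self.margin).mean()
```

This sketch has no learnable parameters, so nothing would need to be added to the to_optim list; a loss with a learnable margin, for example, would.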

[4.] Stored Data:

By default, the following files are saved:

Name_of_Training_Run
|  checkpoint.pth.tar    -> Contains network state-dict.
|  hypa.pkl              -> Contains all network parameters as pickle.
|                           Can be used directly to recreate the network.
|  log_train_Base.csv    -> Logged training data as CSV.
|  log_val_Base.csv      -> Logged test metrics as CSV.
|  Parameter_Info.txt    -> All parameters stored as a readable text file.
|  InfoPlot_Base.svg     -> Graphical summary of training/testing metrics progression.
|  sample_recoveries.png -> Sample recoveries for best validation weights.
|                           Acts as a sanity test.

Note on sample recoveries: Red denotes query images, while green denotes the respective nearest neighbours.

Note on the summary plot: The header shows the best testing metrics over the whole run.
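
For inspecting a finished run, a sketch like the following may be useful. It assumes that hypa.pkl can be read with the standard pickle module and that log_val_Base.csv starts with a header row; neither is verified against the repo, and the function name load_run_summary is hypothetical. As with any pickle, only load files from sources you trust.

```python
import csv
import os
import pickle


def load_run_summary(run_dir):
    """Load the pickled parameters and the logged validation rows of one run."""
    with open(os.path.join(run_dir, "hypa.pkl"), "rb") as handle:
        hyperparameters = pickle.load(handle)
    # Assumption: the first CSV row holds the column names.
    with open(os.path.join(run_dir, "log_val_Base.csv"), newline="") as handle:
        val_rows = list(csv.DictReader(handle))
    return hyperparameters, val_rows
```

If the pickled object is an instance of a custom class, the module defining that class must be importable when loading.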

[5.] Additional Notes:

Finally, several flags may be of interest when examining the respective runs:

--dist_measure: If set, the ratio of mean intraclass-distances over mean interclass distances
                (by measure of center-of-mass distances) is computed after each epoch and stored/plotted.
--grad_measure: If set, the average (absolute) gradients from the embedding layer to the last
                conv. layer are stored in a Pickle-File. This can be used to examine the change of features during each iteration.

For more details, refer to the respective classes in auxiliaries.py.
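
One possible reading of the --dist_measure quantity is sketched below: the mean distance of samples to their own class center, divided by the mean distance between class centers. This interpretation of "center-of-mass distances" is an assumption, as is the function name intra_inter_ratio; the actual computation is defined in auxiliaries.py.

```python
import numpy as np


def intra_inter_ratio(embeddings, labels):
    """Ratio of mean intraclass to mean interclass distance, using class centers."""
    embeddings = np.asarray(embeddings, dtype=np.float64)
    labels = np.asarray(labels)
    classes = np.unique(labels)  # at least two classes are required
    centers = np.stack([embeddings[labels == c].mean(axis=0) for c in classes])
    # Mean distance of each sample to its own class center, averaged over classes.
    intra = np.mean([
        np.linalg.norm(embeddings[labels == c] - center, axis=1).mean()
        for c, center in zip(classes, centers)
    ])
    # Mean distance between all distinct pairs of class centers.
    diffs = centers[:, None, :] - centers[None, :, :]
    center_dists = np.linalg.norm(diffs, axis=2)
    upper = np.triu_indices(len(classes), k=1)
    inter = center_dists[upper].mean()
    return float(intra / inter)
```

Under this reading, smaller values indicate tighter classes relative to the spacing between them.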


4. Results

These results are performance estimates obtained by running the respective commands in sample_training_runs.sh. Note that the learning rate scheduling might not be fully optimised, so these values should only serve as a reference, not as the best that can be achieved with more tweaking.

Note also that the results depend noticeably on the random seed used.

CUB200

Architecture  Loss/Sampling     NMI   F1    Recall@1  Recall@2  Recall@4  Recall@8
ResNet50      Margin/Distance   68.2  38.7  63.4      74.9      86.0      90.4
ResNet50      Triplet/Softhard  66.2  35.5  61.2      73.2      82.4      89.5
ResNet50      NPair/None        65.4  33.8  59.0      71.3      81.1      88.8
ResNet50      ProxyNCA/None     68.1  38.1  64.0      75.4      84.2      90.5

Cars196

Architecture  Loss/Sampling     NMI   F1    Recall@1  Recall@2  Recall@4  Recall@8
ResNet50      Margin/Distance   67.2  37.6  79.3      87.1      92.1      95.4
ResNet50      Triplet/Softhard  64.4  32.4  75.4      84.2      90.1      94.1
ResNet50      NPair/None        62.3  30.1  69.5      80.2      87.3      92.1
ResNet50      ProxyNCA/None     66.3  35.8  80.0      87.2      91.8      95.1

Online Products

Architecture  Loss/Sampling     NMI   F1    Recall@1  Recall@10  Recall@100  Recall@1000
ResNet50      Margin/Distance   89.6  34.9  76.1      88.7       95.1        98.3
ResNet50      Triplet/Softhard  89.1  33.7  74.3      87.6       94.9        98.5
ResNet50      NPair/None        88.8  31.1  70.9      85.2       93.8        98.2

In-Shop Clothes

Architecture  Loss/Sampling     NMI   F1    Recall@1  Recall@10  Recall@20  Recall@30  Recall@50
ResNet50      Margin/Distance   88.2  27.7  84.5      96.1       97.4       97.9       98.5
ResNet50      Triplet/Semihard  89.0  30.8  83.9      96.3       97.6       98.4       98.8
ResNet50      NPair/None        88.0  27.6  80.9      95.0       96.6       97.5       98.2

NOTE:

  1. Regarding Vehicle-ID: Due to the number of test sets, size of the training set and little public accessibility, results are not included for the time being.
  2. Regarding ProxyNCA for Online Products and In-Shop Clothes: Due to the high number of classes, the number of proxies required is too high for useful training (>10000 proxies).

ToDo:

  • Fix Version in requirements.txt
  • Add Results for Implementations
  • Finalize Comments
  • Add Inception-BN
  • Add Lifted Structure Loss