Compare neural networks by their feature similarity

Overview

PyTorch Model Compare

A tiny package to compare two neural networks in PyTorch. There are many ways to compare two neural networks, but one robust and scalable way is using the Centered Kernel Alignment (CKA) metric, where the features of the networks are compared.

Centered Kernel Alignment

Centered Kernel Alignment (CKA) is a representation similarity metric that is widely used for understanding the representations learned by neural networks. Specifically, CKA takes two feature maps / representations X and Y as input and computes their normalized similarity (in terms of the Hilbert-Schmidt Independence Criterion (HSIC)) as

CKA original version

Where K and L are similarity matrices of X and Y respectively. However, the above formula is not scalable against deep architectures and large datasets. Therefore, a minibatch version can be constructed that uses an unbiased estimator of the HSIC as

alt text

alt text

The above form of CKA is from the 2021 ICLR paper by Nguyen T., Raghu M, Kornblith S.

Getting Started

Installation

pip install torch_cka

Usage

from torch_cka import CKA
model1 = resnet18(pretrained=True)  # Or any neural network of your choice
model2 = resnet34(pretrained=True)

dataloader = DataLoader(your_dataset, 
                        batch_size=batch_size, # according to your device memory
                        shuffle=False)  # Don't forget to seed your dataloader

cka = CKA(model1, model2,
          model1_name="ResNet18",   # good idea to provide names to avoid confusion
          model2_name="ResNet34",   
          model1_layers=layer_names_resnet18, # List of layers to extract features from
          model2_layers=layer_names_resnet34, # extracts all layer features by default
          device='cuda')

cka.compare(dataloader) # secondary dataloader is optional

results = cka.export()  # returns a dict that contains model names, layer names
                        # and the CKA matrix

Examples

torch_cka can be used with any pytorch model (subclass of nn.Module) and can be used with pretrained models available from popular sources like torchHub, timm, huggingface etc. Some examples of where this package can come in handy are illustrated below.

Comparing the effect of Depth

A simple experiment is to analyse the features learned by two architectures of the same family - ResNets but of different depths. Taking two ResNets - ResNet18 and ResNet34 - pre-trained on the Imagenet dataset, we can analyse how they produce their features on, say CIFAR10 for simplicity. This comparison is shown as a heatmap below.

alt text

We see high degree of similarity between the two models in lower layers as they both learn similar representations from the data. However at higher layers, the similarity reduces as the deeper model (ResNet34) learn higher order features which the is elusive to the shallower model (ResNet18). Yet, they do indeed have certain similarity in their last fc layer which acts as the feature classifier.

Comparing Two Similar Architectures

Another way of using CKA is in ablation studies. We can go further than those ablation studies that only focus on resultant performance and employ CKA to study the internal representations. Case in point - ResNet50 and WideResNet50 (k=2). WideResNet50 has the same architecture as ResNet50 except having wider residual bottleneck layers (by a factor of 2 in this case).

alt text

We clearly notice that the learned features are indeed different after the first few layers. The width has a more pronounced effect in deeper layers as compared to the earlier layers as both networks seem to learn similar features in the initial layers.

As a bonus, here is a comparison between ViT and the latest SOTA model Swin Transformer pretrained on ImageNet22k.

alt text

Comparing quite different architectures

CNNs have been analysed a lot over the past decade since AlexNet. We somewhat know what sort of features they learn across their layers (through visualizations) and we have put them to good use. One interesting approach is to compare these understandable features with newer models that don't permit easy visualizations (like recent vision transformer architectures) and study them. This has indeed been a hot research topic (see Raghu et.al 2021).

alt text

Comparing Datasets

Yet another application is to compare two datasets - preferably two versions of the data. This is especially useful in production where data drift is a known issue. If you have an updated version of a dataset, you can study how your model will perform on it by comparing the representations of the datasets. This can be more telling about actual performance than simply comparing the datasets directly.

This can also be quite useful in studying the performance of a model on downstream tasks and fine-tuning. For instance, if the CKA score is high for some features on different datasets, then those can be frozen during fine-tuning. As an example, the following figure compares the features of a pretrained Resnet50 on the Imagenet test data and the VOC dataset. Clearly, the pretrained features have little correlation with the VOC dataset. Therefore, we have to resort to fine-tuning to get at least satisfactory results.

alt text

Tips

  • If your model is large (lots of layers or large feature maps), try to extract from select layers. This is to avoid out of memory issues.
  • If you still want to compare the entire feature map, you can run it multiple times with few layers at each iteration and export your data using cka.export(). The exported data can then be concatenated to produce the full CKA matrix.
  • Give proper model names to avoid confusion when interpreting the results. The code automatically extracts the model name for you by default, but it is good practice to label the models according to your use case.
  • When providing your dataloader(s) to the compare() function, it is important that they are seeded properly for reproducibility.
  • When comparing datasets, be sure to set drop_last=True when building the dataloader. This resolves shape mismatch issues - especially in differently sized datasets.

Citation

If you use this repo in your project or research, please cite as -

@software{subramanian2021torch_cka,
    author={Anand Subramanian},
    title={torch_cka},
    url={https://github.com/AntixK/PyTorch-Model-Compare},
    year={2021}
}
Owner
Anand Krishnamoorthy
Research Engineer
Anand Krishnamoorthy
Source code for our paper "Molecular Mechanics-Driven Graph Neural Network with Multiplex Graph for Molecular Structures"

Molecular Mechanics-Driven Graph Neural Network with Multiplex Graph for Molecular Structures Code for the Multiplex Molecular Graph Neural Network (M

shzhang 59 Dec 10, 2022
Source code for paper "Deep Superpixel-based Network for Blind Image Quality Assessment"

DSN-IQA Source code for paper "Deep Superpixel-based Network for Blind Image Quality Assessment" Requirements Python =3.8.0 Pytorch =1.7.1 Usage wit

7 Oct 13, 2022
Solver for Large-Scale Rank-One Semidefinite Relaxations

STRIDE: spectrahedral proximal gradient descent along vertices A Solver for Large-Scale Rank-One Semidefinite Relaxations About STRIDE is designed for

48 Dec 20, 2022
Easy-to-use,Modular and Extendible package of deep-learning based CTR models .

DeepCTR DeepCTR is a Easy-to-use,Modular and Extendible package of deep-learning based CTR models along with lots of core components layers which can

浅梦 6.6k Jan 08, 2023
Semantic Bottleneck Scene Generation

SB-GAN Semantic Bottleneck Scene Generation Coupling the high-fidelity generation capabilities of label-conditional image synthesis methods with the f

Samaneh Azadi 41 Nov 28, 2022
A GUI to automatically create a TOPAS-readable MLC simulation file

Python script to create a TOPAS-readable simulation file descriring a Multi-Leaf-Collimator. Builds the MLC using the data from a 3D .stl file.

Sebastian Schäfer 0 Jun 19, 2022
⚓ Eurybia monitor model drift over time and securize model deployment with data validation

View Demo · Documentation · Medium article 🔍 Overview Eurybia is a Python library which aims to help in : Detecting data drift and model drift Valida

MAIF 172 Dec 27, 2022
A Large-Scale Dataset for Spinal Vertebrae Segmentation in Computed Tomography

A Large-Scale Dataset for Spinal Vertebrae Segmentation in Computed Tomography

ICT.MIRACLE lab 75 Dec 26, 2022
A best practice for tensorflow project template architecture.

A best practice for tensorflow project template architecture.

Mahmoud Gamal Salem 3.6k Dec 22, 2022
Spectrum Surveying: Active Radio Map Estimation with Autonomous UAVs

Spectrum Surveying: The Python code in this repository implements the simulations and plots the figures described in the paper “Spectrum Surveying: Ac

Universitetet i Agder 2 Dec 06, 2022
Only valid pull requests will be allowed. Use python only and readme changes will not be accepted.

❌ This repo is excluded from hacktoberfest This repo is for python beginners and contains lot of beginner python projects for practice. You can also s

Prajjwal Pathak 50 Dec 28, 2022
[RSS 2021] An End-to-End Differentiable Framework for Contact-Aware Robot Design

DiffHand This repository contains the implementation for the paper An End-to-End Differentiable Framework for Contact-Aware Robot Design (RSS 2021). I

Jie Xu 60 Jan 04, 2023
TensorFlow-based implementation of "Pyramid Scene Parsing Network".

PSPNet_tensorflow Important Code is fine for inference. However, the training code is just for reference and might be only used for fine-tuning. If yo

HsuanKung Yang 323 Dec 20, 2022
The dataset of tweets pulling from Twitters with keyword: Hydroxychloroquine, location: US, Time: 2020

HCQ_Tweet_Dataset: FREE to Download. Keywords: HCQ, hydroxychloroquine, tweet, twitter, COVID-19 This dataset is associated with the paper "Understand

2 Mar 16, 2022
nextPARS, a novel Illumina-based implementation of in-vitro parallel probing of RNA structures.

nextPARS, a novel Illumina-based implementation of in-vitro parallel probing of RNA structures. Here you will find the scripts necessary to produce th

Jesse Willis 0 Jan 20, 2022
Fair Recommendation in Two-Sided Platforms

Fair Recommendation in Two-Sided Platforms

gourabgggg 1 Nov 10, 2021
YOLOv4-v3 Training Automation API for Linux

This repository allows you to get started with training a state-of-the-art Deep Learning model with little to no configuration needed! You provide your labeled dataset or label your dataset using our

BMW TechOffice MUNICH 626 Dec 31, 2022
Website which uses Deep Learning to generate horror stories.

Creepypasta - Text Generator Website which uses Deep Learning to generate horror stories. View Demo · View Website Repo · Report Bug · Request Feature

Dhairya Sharma 5 Oct 14, 2022
Moer Grounded Image Captioning by Distilling Image-Text Matching Model

Moer Grounded Image Captioning by Distilling Image-Text Matching Model Requirements Python 3.7 Pytorch 1.2 Prepare data Please use git clone --recurse

YE Zhou 60 Dec 16, 2022
Docker containers of baseline agents for the Crafter environment

Crafter Baselines This repository contains Docker containers for running various baselines on the Crafter environment. Reward Agents DreamerV2 based o

Danijar Hafner 17 Sep 25, 2022