Feature extraction made simple with torchextractor

Overview

torchextractor: PyTorch Intermediate Feature Extraction

PyPI - Python Version PyPI Read the Docs Upload Python Package GitHub

Introduction

Too many times some model definitions get remorselessly copy-pasted just because the forward function does not return what the person expects. You provide module names and torchextractor takes care of the extraction for you.It's never been easier to extract feature, add an extra loss or plug another head to a network. Ler us know what amazing things you build with torchextractor!

Installation

pip install torchextractor  # stable
pip install git+https://github.com/antoinebrl/torchextractor.git  # latest

Requirements:

  • Python >= 3.6+
  • torch >= 1.4.0

Usage

import torch
import torchvision
import torchextractor as tx

model = torchvision.models.resnet18(pretrained=True)
model = tx.Extractor(model, ["layer1", "layer2", "layer3", "layer4"])
dummy_input = torch.rand(7, 3, 224, 224)
model_output, features = model(dummy_input)
feature_shapes = {name: f.shape for name, f in features.items()}
print(feature_shapes)

# {
#   'layer1': torch.Size([1, 64, 56, 56]),
#   'layer2': torch.Size([1, 128, 28, 28]),
#   'layer3': torch.Size([1, 256, 14, 14]),
#   'layer4': torch.Size([1, 512, 7, 7]),
# }

See more examples Binder Open In Colab

Read the documentation

FAQ

• How do I know the names of the modules?

You can print all module names like this:

tx.list_module_names(model)

# OR

for name, module in model.named_modules():
    print(name)

• Why do some operations not get listed?

It is not possible to add hooks if operations are not defined as modules. Therefore, F.relu cannot be captured but nn.Relu() can.

• How can I avoid listing all relevant modules?

You can specify a custom filtering function to hook the relevant modules:

# Hook everything !
module_filter_fn = lambda module, name: True

# Capture of all modules inside first layer
module_filter_fn = lambda module, name: name.startswith("layer1")

# Focus on all convolutions
module_filter_fn = lambda module, name: isinstance(module, torch.nn.Conv2d)

model = tx.Extractor(model, module_filter_fn=module_filter_fn)

• Is it compatible with ONNX?

tx.Extractor is compatible with ONNX! This means you can also access intermediate features maps after the export.

Pro-tip: name the output nodes by using output_names when calling torch.onnx.export.

• Is it compatible with TorchScript?

Not yet, but we are working on it. Compiling registered hook of a module was just recently added in PyTorch v1.8.0.

• "One more thing!" 😉

By default we capture the latest output of the relevant modules, but you can specify your own custom operations.

For example, to accumulate features over 10 forward passes you can do the following:

import torch
import torchvision
import torchextractor as tx

model = torchvision.models.resnet18(pretrained=True)

def capture_fn(module, input, output, module_name, feature_maps):
    if module_name not in feature_maps:
        feature_maps[module_name] = []
    feature_maps[module_name].append(output)

extractor = tx.Extractor(model, ["layer3", "layer4"], capture_fn=capture_fn)

for i in range(20):
    for i in range(10):
        x = torch.rand(7, 3, 224, 224)
        model(x)
    feature_maps = extractor.collect()

    # Do your stuffs here

    # Discard collected elements
    extractor.clear_placeholder()

Contributing

All feedbacks and contributions are welcomed. Feel free to report an issue or to create a pull request!

If you want to get hands-on:

  1. (Fork and) clone the repo.
  2. Create a virtual environment: virtualenv -p python3 .venv && source .venv/bin/activate
  3. Install dependencies: pip install -r requirements.txt && pip install -r requirements-dev.txt
  4. Hook auto-formatting tools: pre-commit install
  5. Hack as much as you want!
  6. Run tests: python -m unittest discover -vs ./tests/
  7. Share your work and create a pull request.

To Build documentation:

cd docs
pip install requirements.txt
make html
You might also like...
Deep Image Search is an AI-based image search engine that includes deep transfor learning features Extraction and tree-based vectorized search.
Deep Image Search is an AI-based image search engine that includes deep transfor learning features Extraction and tree-based vectorized search.

Deep Image Search - AI-Based Image Search Engine Deep Image Search is an AI-based image search engine that includes deep transfer learning features Ex

Cross-media Structured Common Space for Multimedia Event Extraction (ACL2020)
Cross-media Structured Common Space for Multimedia Event Extraction (ACL2020)

Cross-media Structured Common Space for Multimedia Event Extraction Table of Contents Overview Requirements Data Quickstart Citation Overview The code

Source code for paper "Document-Level Relation Extraction with Adaptive Thresholding and Localized Context Pooling", AAAI 2021

ATLOP Code for AAAI 2021 paper Document-Level Relation Extraction with Adaptive Thresholding and Localized Context Pooling. If you make use of this co

Training data extraction on GPT-2

Training data extraction from GPT-2 This repository contains code for extracting training data from GPT-2, following the approach outlined in the foll

This repository contains the code for our fast polygonal building extraction from overhead images pipeline.
This repository contains the code for our fast polygonal building extraction from overhead images pipeline.

Polygonal Building Segmentation by Frame Field Learning We add a frame field output to an image segmentation neural network to improve segmentation qu

Adversarial Robustness Toolbox (ART) - Python Library for Machine Learning Security - Evasion, Poisoning, Extraction, Inference - Red and Blue Teams
Adversarial Robustness Toolbox (ART) - Python Library for Machine Learning Security - Evasion, Poisoning, Extraction, Inference - Red and Blue Teams

Adversarial Robustness Toolbox (ART) is a Python library for Machine Learning Security. ART provides tools that enable developers and researchers to defend and evaluate Machine Learning models and applications against the adversarial threats of Evasion, Poisoning, Extraction, and Inference. ART supports all popular machine learning frameworks (TensorFlow, Keras, PyTorch, MXNet, scikit-learn, XGBoost, LightGBM, CatBoost, GPy, etc.), all data types (images, tables, audio, video, etc.) and machine learning tasks (classification, object detection, speech recognition, generation, certification, etc.).

Implementation for our AAAI2021 paper (Entity Structure Within and Throughout: Modeling Mention Dependencies for Document-Level Relation Extraction).
Implementation for our AAAI2021 paper (Entity Structure Within and Throughout: Modeling Mention Dependencies for Document-Level Relation Extraction).

SSAN Introduction This is the pytorch implementation of the SSAN model (see our AAAI2021 paper: Entity Structure Within and Throughout: Modeling Menti

An Efficient Implementation of Analytic Mesh Algorithm for 3D Iso-surface Extraction from Neural Networks
An Efficient Implementation of Analytic Mesh Algorithm for 3D Iso-surface Extraction from Neural Networks

AnalyticMesh Analytic Marching is an exact meshing solution from neural networks. Compared to standard methods, it completely avoids geometric and top

[ACL 20] Probing Linguistic Features of Sentence-level Representations in Neural Relation Extraction

REval Table of Contents Introduction Overview Requirements Installation Probing Usage Citation License 🎓 Introduction REval is a simple framework for

Comments
  • Only extracting part of the intermediate feature with DataParallel

    Only extracting part of the intermediate feature with DataParallel

    Hi @antoinebrl,

    I am using torch.nn.DataParallel on a 2-GPU machine with a batch size of N. Data parallel training will split the input data batch into 2 pieces sequentially and sends them to GPUs.

    When using torchextractor to obtain the intermediate feature, the input data size and the output size are both N as expected, but the feature size becomes N/2. Does this mean we only extract the features of one GPU? I'm not sure because I didn't find an exact match.

    Can you please explain why this happens? Maybe the normal behavior is returning features from all GPUs or from a specified one?

    A minimal example to reproduce:

    import torch
    import torchvision
    import torchextractor as tx
    
    model = torchvision.models.resnet18(pretrained=True)
    model_gpu = torch.nn.DataParallel(torchvision.models.resnet18(pretrained=True))
    model_gpu.cuda()
    
    model = tx.Extractor(model, ["layer1"])
    model_gpu = tx.Extractor(model_gpu, ["module.layer1"])
    dummy_input = torch.rand(8, 3, 224, 224)
    _, features = model(dummy_input)
    _, features_gpu = model_gpu(dummy_input)
    feature_shapes = {name: f.shape for name, f in features.items()}
    print(feature_shapes)
    feature_shapes_gpu = {name: f.shape for name, f in features_gpu.items()}
    print(feature_shapes_gpu)
    
    # {'layer1': torch.Size([8, 64, 56, 56])}
    # {'module.layer1': torch.Size([4, 64, 56, 56])}
    
    opened by wydwww 5
Releases(v0.3.0)
A style-based Quantum Generative Adversarial Network

Style-qGAN A style based Quantum Generative Adversarial Network (style-qGAN) model for Monte Carlo event generation. Tutorial We have prepared a noteb

9 Nov 24, 2022
Go from graph data to a secure and interactive visual graph app in 15 minutes. Batteries-included self-hosting of graph data apps with Streamlit, Graphistry, RAPIDS, and more!

✔️ Linux ✔️ OS X ❌ Windows (#39) Welcome to graph-app-kit Turn your graph data into a secure and interactive visual graph app in 15 minutes! Why This

Graphistry 107 Jan 02, 2023
A Pytorch Implementation of [Source data‐free domain adaptation of object detector through domain

A Pytorch Implementation of Source data‐free domain adaptation of object detector through domain‐specific perturbation Please follow Faster R-CNN and

1 Dec 25, 2021
An all-in-one application to visualize multiple different local path planning algorithms

Table of Contents Table of Contents Local Planner Visualization Project (LPVP) Features Installation/Usage Local Planners Probabilistic Roadmap (PRM)

Abdur Javaid 47 Dec 30, 2022
Approximate Nearest Neighbors in C++/Python optimized for memory usage and loading/saving to disk

Annoy Annoy (Approximate Nearest Neighbors Oh Yeah) is a C++ library with Python bindings to search for points in space that are close to a given quer

Spotify 10.6k Jan 04, 2023
Implementation of ToeplitzLDA for spatiotemporal stationary time series data.

Code for the ToeplitzLDA classifier proposed in here. The classifier conforms sklearn and can be used as a drop-in replacement for other LDA classifiers. For in-depth usage refer to the learning from

Jan Sosulski 5 Nov 07, 2022
A library for using chemistry in your applications

Chemistry in python Resources Used The following items are not made by me! Click the words to go to the original source Periodic Tab Json - Used in -

Tech Penguin 28 Dec 17, 2021
Official Pytorch Implementation of 3DV2021 paper: SAFA: Structure Aware Face Animation.

SAFA: Structure Aware Face Animation (3DV2021) Official Pytorch Implementation of 3DV2021 paper: SAFA: Structure Aware Face Animation. Getting Started

QiulinW 122 Dec 23, 2022
Pytorch Implementation of DiffSinger: Diffusion Acoustic Model for Singing Voice Synthesis (TTS Extension)

DiffSinger - PyTorch Implementation PyTorch implementation of DiffSinger: Diffusion Acoustic Model for Singing Voice Synthesis (TTS Extension). Status

Keon Lee 152 Jan 02, 2023
LLVM-based compiler for LightGBM gradient-boosted trees. Speeds up prediction by ≥10x.

LLVM-based compiler for LightGBM gradient-boosted trees. Speeds up prediction by ≥10x.

Simon Boehm 183 Jan 02, 2023
POT : Python Optimal Transport

POT: Python Optimal Transport This open source Python library provide several solvers for optimization problems related to Optimal Transport for signa

Python Optimal Transport 1.7k Dec 31, 2022
CrossNorm and SelfNorm for Generalization under Distribution Shifts (ICCV 2021)

CrossNorm (CN) and SelfNorm (SN) (Accepted at ICCV 2021) This is the official PyTorch implementation of our CNSN paper, in which we propose CrossNorm

100 Dec 28, 2022
Supplementary materials to "Spin-optomechanical quantum interface enabled by an ultrasmall mechanical and optical mode volume cavity" by H. Raniwala, S. Krastanov, M. Eichenfield, and D. R. Englund, 2022

Supplementary materials to "Spin-optomechanical quantum interface enabled by an ultrasmall mechanical and optical mode volume cavity" by H. Raniwala,

Stefan Krastanov 1 Jan 17, 2022
Let Python optimize the best stop loss and take profits for your TradingView strategy.

TradingView Machine Learning TradeView is a free and open source Trading View bot written in Python. It is designed to support all major exchanges. It

Robert Roman 473 Jan 09, 2023
R-Drop: Regularized Dropout for Neural Networks

R-Drop: Regularized Dropout for Neural Networks R-drop is a simple yet very effective regularization method built upon dropout, by minimizing the bidi

756 Dec 27, 2022
Joint-task Self-supervised Learning for Temporal Correspondence (NeurIPS 2019)

Joint-task Self-supervised Learning for Temporal Correspondence Project | Paper Overview Joint-task Self-supervised Learning for Temporal Corresponden

Sifei Liu 167 Dec 14, 2022
Frequency Domain Image Translation: More Photo-realistic, Better Identity-preserving

Frequency Domain Image Translation: More Photo-realistic, Better Identity-preserving This is the source code for our paper Frequency Domain Image Tran

Mu Cai 52 Dec 23, 2022
Cross-modal Retrieval using Transformer Encoder Reasoning Networks (TERN). With use of Metric Learning and FAISS for fast similarity search on GPU

Cross-modal Retrieval using Transformer Encoder Reasoning Networks This project reimplements the idea from "Transformer Reasoning Network for Image-Te

Minh-Khoi Pham 5 Nov 05, 2022
Convert ONNX model graph to Keras model format.

Convert ONNX model graph to Keras model format.

Grigory Malivenko 175 Dec 28, 2022
JupyterLite demo deployed to GitHub Pages 🚀

JupyterLite Demo JupyterLite deployed as a static site to GitHub Pages, for demo purposes. ✨ Try it in your browser ✨ ➡️ https://jupyterlite.github.io

JupyterLite 223 Jan 04, 2023