Convert scikit-learn models to PyTorch modules

Related tags

Deep Learningsk2torch
Overview

sk2torch

sk2torch converts scikit-learn models into PyTorch modules that can be tuned with backpropagation and even compiled as TorchScript.

Problems solved by this project:

  1. scikit-learn cannot perform inference on a GPU. Models like SVMs have a lot to gain from fast GPU primitives, and converting the models to PyTorch gives immediate access to these primitives.
  2. While scikit-learn supports serialization through pickle, saved models are not reproducible across versions of the library. On the other hand, TorchScript provides a convenient, safe way to save a model with its corresponding implementation. The resulting models can be loaded anywhere that PyTorch is installed, even without importing sk2torch.
  3. While certain models like SVMs and linear classifiers are theoretically end-to-end differentiable, scikit-learn provides no mechanism to compute gradients through trained models. PyTorch provides this functionality mostly for free.

See Usage for a high-level example of using the library. See How it works to see which modules are supported.

For fun, here's a vector field produced by differentiating the probability predictions of a two-class SVM (produced by this script):

A vector field quiver plot with two modes

Usage

First, train a model with scikit-learn as usual:

from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

x, y = create_some_dataset()
model = Pipeline([
    ("center", StandardScaler(with_std=False)),
    ("classify", SGDClassifier()),
])
model.fit(x, y)

Then call sk2torch.wrap on the model to create a PyTorch equivalent:

import sk2torch
import torch

torch_model = sk2torch.wrap(model)
print(torch_model.predict(torch.tensor([[1., 2., 3.]]).double()))

You can save a model with TorchScript:

import torch.jit

torch.jit.script(torch_model).save("path.pt")

# ... sk2torch need not be installed to load the model.
loaded_model = torch.jit.load("path.pt")

For a full example of training a model and using its PyTorch translation, see examples/svm_vector_field.py.

How it works

sk2torch contains PyTorch re-implementations of supported scikit-learn models. For a supported estimator X, a class TorchX in sk2torch will be able to read the attributes of X and convert them to torch.Tensor or simple Python types. TorchX subclasses torch.nn.Module and has a method for each inference API of X (e.g. predict, decision_function, etc.).

Which modules are supported? The easiest way to get an up-to-date list is via the supported_classes() function, which returns all wrap()able scikit-learn classes:

>>> import sk2torch
>>> sk2torch.supported_classes()
[<class 'sklearn.tree._classes.DecisionTreeClassifier'>, <class 'sklearn.tree._classes.DecisionTreeRegressor'>, <class 'sklearn.dummy.DummyClassifier'>, <class 'sklearn.ensemble._gb.GradientBoostingClassifier'>, <class 'sklearn.preprocessing._label.LabelBinarizer'>, <class 'sklearn.svm._classes.LinearSVC'>, <class 'sklearn.svm._classes.LinearSVR'>, <class 'sklearn.neural_network._multilayer_perceptron.MLPClassifier'>, <class 'sklearn.kernel_approximation.Nystroem'>, <class 'sklearn.pipeline.Pipeline'>, <class 'sklearn.linear_model._stochastic_gradient.SGDClassifier'>, <class 'sklearn.preprocessing._data.StandardScaler'>, <class 'sklearn.svm._classes.SVC'>, <class 'sklearn.svm._classes.NuSVC'>, <class 'sklearn.svm._classes.SVR'>, <class 'sklearn.svm._classes.NuSVR'>, <class 'sklearn.compose._target.TransformedTargetRegressor'>]

Comparison to sklearn-onnx

sklearn-onnx is an open source package for converting trained scikit-learn models into ONNX. Like sk2torch, sklearn-onnx re-implements inference functions for various models, meaning that it can also provide serialization and GPU acceleration for supported modules.

Naturally, neither library will support modules that aren't manually ported. As a result, the two libraries support different subsets of all available models/methods. For example, sk2torch supports the SVC probability prediction methods predict_proba and predict_log_prob, whereas sklearn-onnx does not.

While sklearn-onnx exports models to ONNX, sk2torch exports models to Python objects with familiar method names that can be fine-tuned, backpropagated through, and serialized in a user-friendly way. PyTorch is strictly more general than ONNX, since PyTorch models can be converted to ONNX if desired.

Owner
Alex Nichol
Web developer, math geek, and AI enthusiast.
Alex Nichol
CVPR 2021

Smoothing the Disentangled Latent Style Space for Unsupervised Image-to-image Translation [Paper] | [Poster] | [Codes] Yahui Liu1,3, Enver Sangineto1,

Yahui Liu 37 Sep 12, 2022
GBK-GNN: Gated Bi-Kernel Graph Neural Networks for Modeling Both Homophily and Heterophily

GBK-GNN: Gated Bi-Kernel Graph Neural Networks for Modeling Both Homophily and Heterophily Abstract Graph Neural Networks (GNNs) are widely used on a

10 Dec 20, 2022
Neural Geometric Level of Detail: Real-time Rendering with Implicit 3D Shapes (CVPR 2021 Oral)

Neural Geometric Level of Detail: Real-time Rendering with Implicit 3D Surfaces Official code release for NGLOD. For technical details, please refer t

659 Dec 27, 2022
Supplementary code for SIGGRAPH 2021 paper: Discovering Diverse Athletic Jumping Strategies

SIGGRAPH 2021: Discovering Diverse Athletic Jumping Strategies project page paper demo video Prerequisites Important Notes We suspect there are bugs i

54 Dec 06, 2022
Models Supported: AlbUNet [18, 34, 50, 101, 152] (1D and 2D versions for Single and Multiclass Segmentation, Feature Extraction with supports for Deep Supervision and Guided Attention)

AlbUNet-1D-2D-Tensorflow-Keras This repository contains 1D and 2D Signal Segmentation Model Builder for AlbUNet and several of its variants developed

Sakib Mahmud 1 Nov 15, 2021
This is the repository for The Machine Learning Workshops, published by AI DOJO

This is the repository for The Machine Learning Workshops, published by AI DOJO. It contains all the workshop's code with supporting project files necessary to work through the code.

AI Dojo 12 May 06, 2022
Code of Puregaze: Purifying gaze feature for generalizable gaze estimation, AAAI 2022.

PureGaze: Purifying Gaze Feature for Generalizable Gaze Estimation Description Our work is accpeted by AAAI 2022. Picture: We propose a domain-general

39 Dec 05, 2022
AI virtual gym is an AI program which can be used to exercise and can be used to see if we are doing the exercises

AI virtual gym is an AI program which can be used to exercise and can be used to see if we are doing the exercises

4 Feb 13, 2022
I-BERT: Integer-only BERT Quantization

I-BERT: Integer-only BERT Quantization HuggingFace Implementation I-BERT is also available in the master branch of HuggingFace! Visit the following li

Sehoon Kim 139 Dec 27, 2022
Deep Learning and Logical Reasoning from Data and Knowledge

Logic Tensor Networks (LTN) Logic Tensor Network (LTN) is a neurosymbolic framework that supports querying, learning and reasoning with both rich data

171 Dec 29, 2022
TensorFlow implementation of PHM (Parameterization of Hypercomplex Multiplication)

Parameterization of Hypercomplex Multiplications (PHM) This repository contains the TensorFlow implementation of PHM (Parameterization of Hypercomplex

Aston Zhang 9 Oct 26, 2022
The code for our NeurIPS 2021 paper "Kernelized Heterogeneous Risk Minimization".

Kernelized-HRM Jiashuo Liu, Zheyuan Hu The code for our NeurIPS 2021 paper "Kernelized Heterogeneous Risk Minimization"[1]. This repo contains the cod

Liu Jiashuo 8 Nov 20, 2022
Bayes-Newton—A Gaussian process library in JAX, with a unifying view of approximate Bayesian inference as variants of Newton's algorithm.

Bayes-Newton Bayes-Newton is a library for approximate inference in Gaussian processes (GPs) in JAX (with objax), built and actively maintained by Wil

AaltoML 165 Nov 27, 2022
Asymmetric metric learning for knowledge transfer

Asymmetric metric learning This is the official code that enables the reproduction of the results from our paper: Asymmetric metric learning for knowl

20 Dec 06, 2022
Simulating an AI playing 2048 using the Expectimax algorithm

2048-expectimax Simulating an AI playing 2048 using the Expectimax algorithm The base game engine uses code from here. The AI player is modeled as a m

Subha Ramesh 2 Jan 31, 2022
Offline Reinforcement Learning with Implicit Q-Learning

Offline Reinforcement Learning with Implicit Q-Learning This repository contains the official implementation of Offline Reinforcement Learning with Im

Ilya Kostrikov 125 Dec 31, 2022
Computationally Efficient Optimization of Plackett-Luce Ranking Models for Relevance and Fairness

Computationally Efficient Optimization of Plackett-Luce Ranking Models for Relevance and Fairness This repository contains the code used for the exper

H.R. Oosterhuis 28 Nov 29, 2022
Sibur challange 2021 competition - 6 place

sibur challange 2021 Решение на 6 место: https://sibur.ai-community.com/competitions/5/tasks/13 Скор 1.4066/1.4159 public/private. Архитектура - однос

Ivan 5 Jan 11, 2022
Official Implementation of SimIPU: Simple 2D Image and 3D Point Cloud Unsupervised Pre-Training for Spatial-Aware Visual Representations

Official Implementation of SimIPU SimIPU: Simple 2D Image and 3D Point Cloud Unsupervised Pre-Training for Spatial-Aware Visual Representations Since

Zhyever 37 Dec 01, 2022
We utilize deep reinforcement learning to obtain favorable trajectories for visual-inertial system calibration.

Unified Data Collection for Visual-Inertial Calibration via Deep Reinforcement Learning Update: The lastest code will be updated in this branch. Pleas

ETHZ ASL 27 Dec 29, 2022