A python library to build Model Trees with Linear Models at the leaves.

Overview

linear-tree

Linear Model Trees combine the learning ability of Decision Trees with the predictive and explanatory power of Linear Models. As in other tree-based algorithms, the data are split according to simple decision rules. The quality of each split is evaluated in terms of the gain obtained by fitting Linear Models in the nodes. As a result, the models in the leaves are linear rather than the constant approximations used in classical Decision Trees.
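
To make the split criterion concrete, here is a minimal pure-Python sketch for a single feature. It is an illustration of the idea only, not the library's implementation: the helper names (fit_line, sse, best_split) and the min_samples_leaf argument are hypothetical, and the loss is assumed to be the sum of squared errors. Every candidate threshold is scored by fitting one line on each side and comparing the combined loss with the loss of a single line fitted on all the data.

```python
# Illustrative sketch only: a one-feature model-tree split search in pure Python.
# fit_line, sse and best_split are hypothetical helpers, not linear-tree API.

def fit_line(xs, ys):
    """Ordinary least squares for y = a * x + b on one feature."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    var_x = sum((x - mean_x) ** 2 for x in xs)
    if var_x == 0:
        return 0.0, mean_y  # degenerate case: fall back to a constant
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    a = cov_xy / var_x
    return a, mean_y - a * mean_x


def sse(xs, ys):
    """Sum of squared errors of the best-fitting line on (xs, ys)."""
    a, b = fit_line(xs, ys)
    return sum((y - (a * x + b)) ** 2 for x, y in zip(xs, ys))


def best_split(xs, ys, min_samples_leaf=2):
    """Return (threshold, gain) for the split that most reduces the loss."""
    pairs = sorted(zip(xs, ys))
    xs = [p[0] for p in pairs]
    ys = [p[1] for p in pairs]
    parent_loss = sse(xs, ys)
    best = (None, 0.0)
    for i in range(min_samples_leaf, len(xs) - min_samples_leaf + 1):
        if xs[i] == xs[i - 1]:
            continue  # no threshold can separate equal values
        loss = sse(xs[:i], ys[:i]) + sse(xs[i:], ys[i:])
        gain = parent_loss - loss
        if gain > best[1]:
            best = ((xs[i - 1] + xs[i]) / 2, gain)
    return best


# Piecewise-linear toy data: y = x for x <= 4 and y = 9 - x for x >= 5.
xs = [float(x) for x in range(10)]
ys = [x if x <= 4 else 9 - x for x in xs]
threshold, gain = best_split(xs, ys)
print(threshold, gain)
```

On this toy data a single line cannot follow the kink, while two lines fit each side exactly, so the split between x = 4 and x = 5 yields the largest gain. A classical decision tree would need many constant leaves to approximate the same shape.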

linear-tree is designed to integrate fully with scikit-learn. LinearTreeRegressor and LinearTreeClassifier are provided as scikit-learn BaseEstimator objects. They are wrappers that build a decision tree on the data, fitting a linear estimator from sklearn.linear_model in the nodes. All the models available in sklearn.linear_model can be used as linear estimators.

Installation

pip install linear-tree

The module depends on NumPy, SciPy and Scikit-Learn (>=0.23.0). Python 3.6 or above is supported.

Usage

Regression

from sklearn.linear_model import LinearRegression
from lineartree import LinearTreeRegressor
from sklearn.datasets import make_regression
X, y = make_regression(n_samples=100, n_features=4,
                       n_informative=2, n_targets=1,
                       random_state=0, shuffle=False)
regr = LinearTreeRegressor(base_estimator=LinearRegression())
regr.fit(X, y)

Classification

from sklearn.linear_model import RidgeClassifier
from lineartree import LinearTreeClassifier
from sklearn.datasets import make_classification
X, y = make_classification(n_samples=100, n_features=4,
                           n_informative=2, n_redundant=0,
                           random_state=0, shuffle=False)
clf = LinearTreeClassifier(base_estimator=RidgeClassifier())
clf.fit(X, y)

More examples in the notebooks folder.

Check the API Reference to see the parameter configurations and the available methods.

Examples

Show the model tree structure:

[Image: plot tree]

Linear Tree Regressor at work:

[Image: linear tree regressor]

Linear Tree Classifier at work:

[Image: linear tree classifier]

Extract and examine coefficients at the leaves:

[Image: leaf coefficients]

Comments
  • finding breakpoint

    Hello,

    thank you for your nice tool. I am using LinearTreeRegressor to fit a continuous piecewise linear function. It works well. Is it possible to show the location (the coordinates) of the breakpoints?

    thank you

    opened by ZhengLiu1119 5
  • Allow the hyperparameter "max_depth = 0".

    Thanks for the good library.

    When using LinearTreeRegressor, I think that max_depth is often optimized by cross-validation.

    This library allows max_depth in the range 1-20. However, depending on the dataset, simple linear regression may be suitable. Even in such a dataset, max_depth is forced to be 1 or more, so Simple Linear Regression cannot be applied properly with LinearTreeRegressor.

    • Of course, it is appropriate to use sklearn.linear_model.LinearRegression for such datasets.

    My suggestion is to change the code so that it fits base_estimator directly when "max_depth = 0". With this change, LinearTreeRegressor could handle both segmented regression and simple regression just by changing hyperparameters.

    opened by jckkvs 4
  • Error when running with multiple jobs: unexpected keyword argument 'target_offload'

    I have been using your library for quite a while and am super happy with it. So first, thanks a lot!

    Lately, I used my framework (which also uses your library) on a modern many-core server with many jobs, and it worked fine. Now I have updated everything via pip, and with 8 jobs on my MacBook I get the following error.

    This error does not occur when using only a single job (I pass the number of jobs to n_jobs).

    I cannot nail down the actual problem, but since it occurred right after the upgrade, I assume the upgrade might be the reason?

    Am I doing something wrong here?

    """
    Traceback (most recent call last):
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/externals/loky/process_executor.py", line 436, in _process_worker
        r = call_item()
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/externals/loky/process_executor.py", line 288, in __call__
        return self.fn(*self.args, **self.kwargs)
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/_parallel_backends.py", line 595, in __call__
        return self.func(*args, **kwargs)
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 263, in __call__
        for func, args, kwargs in self.items]
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 263, in <listcomp>
        for func, args, kwargs in self.items]
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/lineartree/_classes.py", line 56, in __call__
        with config_context(**self.config):
      File "/Users/martin/opt/anaconda3/lib/python3.7/contextlib.py", line 239, in helper
        return _GeneratorContextManager(func, args, kwds)
      File "/Users/martin/opt/anaconda3/lib/python3.7/contextlib.py", line 82, in __init__
        self.gen = func(*args, **kwds)
    TypeError: config_context() got an unexpected keyword argument 'target_offload'
    """
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "compression_selection_pipeline.py", line 41, in <module>
        model_pipeline.learn_runtime_models(calibration_result_dir)
      File "/Users/martin/Programming/compression_selection_v3/hyrise_calibration/model_pipeline.py", line 670, in learn_runtime_models
        non_splitting_models("table_scan", table_scans)
      File "/Users/martin/Programming/compression_selection_v3/hyrise_calibration/model_pipeline.py", line 590, in non_splitting_models
        fitted_model = model_dict["model"].fit(X_train, y_train)
      File "/Users/martin/Programming/compression_selection_v3/hyrise_calibration/model_pipeline.py", line 209, in fit
        return self.regression.fit(X, y)
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/lineartree/lineartree.py", line 187, in fit
        self._fit(X, y, sample_weight)
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/lineartree/_classes.py", line 576, in _fit
        self._grow(X, y, sample_weight)
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/lineartree/_classes.py", line 387, in _grow
        loss=loss)
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/lineartree/_classes.py", line 285, in _split
        for feat in split_feat)
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 1056, in __call__
        self.retrieve()
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 935, in retrieve
        self._output.extend(job.get(timeout=self.timeout))
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/_parallel_backends.py", line 542, in wrap_future_result
        return future.result(timeout=timeout)
      File "/Users/martin/opt/anaconda3/lib/python3.7/concurrent/futures/_base.py", line 435, in result
        return self.__get_result()
      File "/Users/martin/opt/anaconda3/lib/python3.7/concurrent/futures/_base.py", line 384, in __get_result
        raise self._exception
    TypeError: config_context() got an unexpected keyword argument 'target_offload'
    

    PS: I have already left a star. :D

    opened by Bouncner 3
  • Option to specify features to use for splitting and for leaf models

    Added two additional parameters:

    • split_features: Indices of features that can be used for splitting. Default all.
    • linear_features: Indices of features that are used by the linear models in the leaves. Default all except for categorical features

    This implements a feature requested in https://github.com/cerlymarco/linear-tree/issues/2

    Potential performance improvement: Currently the code still computes bins for all features and not only for those used for splitting.

    opened by JonasRauch 3
  • Rationale for rounding during _parallel_binning_fit and _grow

    I noticed that the implementations of _parallel_binning_fit and _grow internally round loss values to 5 decimal places. This makes the regression results dependent on the scale of the labels, as data with a lower natural loss value will result in many different splits of the data having the same loss when rounded to 5 decimal places. Is there a reason why this is the case?

    This behavior can be observed by fitting a LinearTreeRegressor using the default loss function and multiplying the scale of the labels by a small number (like 1e-9). This will result in the regressor no longer learning any splits.

    opened by session-id 2
  • ValueError: Invalid parameter linearforestregression for estimator Pipeline

    Great work! I'm new to ML and stuck with this. I'm trying to combine pipeline and GridSearch to search for best possible hyperparameters for a model.

    image

    I got the following error:

    image

    Kindly help : )

    opened by NousMei 2
  • Performance and possibility to split only on subset of features

    Hey, I have been playing around a lot with your linear trees. Like them very much. Thanks!

    Nevertheless, I am somewhat disappointed by the runtime performance. Compared to XGBoost Regressors (I know it's not a fair comparison) or linear regressions (also not fair), the linear tree is reeeeeaally slow. 50k observations, 80 features: 2s for linear regression, 27s for XGBoost, and 300s for the linear tree. Have you seen similar runtimes or might I be using it wrong?

    Another aspect that's interesting to me is whether it is possible to limit the features that are used for splits. I haven't found it in the code. Any chance of seeing it in the future?

    opened by Bouncner 2
  • export to graphviz  -AttributeError: 'LinearTreeRegressor' object has no attribute 'n_features_'

    Hi

    thanks for writing this great package!

    When I try to display the decision tree with graphviz, I get this error:

    AttributeError: 'LinearTreeRegressor' object has no attribute 'n_features_'

    from lineartree import LinearTreeRegressor
    from sklearn.linear_model import LinearRegression

    reg = LinearTreeRegressor(base_estimator=LinearRegression())
    reg.fit(train[x_cols], train["y"])

    from graphviz import Source
    from sklearn import tree

    graph = Source(tree.export_graphviz(reg, out_file=None, feature_names=train.columns))

    opened by ricmarchao 2
  • numpy deprecation warning

    /lineartree/_classes.py:338: DeprecationWarning:

    the interpolation= argument to quantile was renamed to method=, which has additional options. Users of the modes 'nearest', 'lower', 'higher', or 'midpoint' are encouraged to review the method they used. (Deprecated NumPy 1.22)

    Seems like a quick update here would get this warning to stop showing up, right? I can always ignore it, but figured I would mention it in case it is actually an error on my side.

    Also, sorry, I don't actually know what the best open source etiquette is. If I'm supposed to create a pull request with a proposed fix instead of just mentioning it, then feel free to correct me.

    opened by paul-brenner 1
  • How to gridsearch tree and regression parameters?

    Hi, I am wondering how to run a GridSearchCV to find the best parameters for both the tree and the regression model. For now I am able to tune the tree component of my model:

    from sklearn.linear_model import ElasticNet
    from sklearn.model_selection import GridSearchCV, RepeatedKFold
    from lineartree import LinearForestRegressor

    param_grid = {
        'n_estimators': [50, 100, 500, 700],
        'max_depth': [10, 20, 30, 50],
        'min_samples_split': [2, 4, 8, 16, 32],
        'max_features': ['sqrt', 'log2', None]
    }
    cv = RepeatedKFold(n_repeats=3,
                       n_splits=3,
                       random_state=1)

    model = GridSearchCV(
        LinearForestRegressor(ElasticNet(random_state=0), random_state=42),
        param_grid=param_grid,
        n_jobs=-1,
        cv=cv,
        scoring='neg_root_mean_squared_error'
    )

    opened by zuzannakarwowska 1
  • Potential bug in LinearForestClassifier 'predict_proba'

    Hello! Thank you for the useful package!

    I think I might have found a potential bug in LinearForestClassifier.

    I expected 'predict_proba' to use 'self.decision_function', as 'predict' does, so that it includes predictions from both estimators (base + forest). Is that a potential bug, or am I wrong here?

    https://github.com/cerlymarco/linear-tree/blob/8d5beca8d492cb8c57e6618e3fb770860f28b550/lineartree/lineartree.py#L1560

    opened by PiotrKaszuba 1
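
The rounding question raised in the comments above can be illustrated without the library. The sketch below only assumes what that comment reports, namely that loss values are rounded to 5 decimal places before being compared; the residual values and the squared_loss helper are made up for the illustration. Scaling the labels by 1e-9 scales a squared-error loss by 1e-18, so two candidate splits that are clearly different at the original scale become indistinguishable after rounding.

```python
# Illustrative only: how rounding a loss to 5 decimals can hide the difference
# between two candidate splits once the labels are scaled down.
# The residuals and the squared_loss helper are hypothetical.

def squared_loss(residuals):
    """Sum of squared residuals."""
    return sum(r ** 2 for r in residuals)


residuals_split_a = [0.5, -0.4, 0.3]   # a poorer candidate split
residuals_split_b = [0.1, -0.2, 0.1]   # a better candidate split

scale = 1e-9  # multiplying the labels by this factor scales residuals too

loss_a = squared_loss(residuals_split_a)
loss_b = squared_loss(residuals_split_b)
scaled_a = squared_loss([r * scale for r in residuals_split_a])
scaled_b = squared_loss([r * scale for r in residuals_split_b])

# At the original scale the rounded losses still differ.
distinguishable_before = round(loss_a, 5) != round(loss_b, 5)
# After scaling, both losses round to 0.0 and look identical.
distinguishable_after = round(scaled_a, 5) != round(scaled_b, 5)

print(distinguishable_before, distinguishable_after)
```

Under that assumption, no split appears better than any other on the scaled labels, which is consistent with the reported behavior of the regressor no longer learning any splits.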
Releases(0.3.5)
Owner
Marco Cerliani
Statistician Hacker & Data Scientist