An intuitive library to add plotting functionality to scikit-learn objects.

Overview

Welcome to Scikit-plot

Single line functions for detailed visualizations

The quickest and easiest way to go from analysis...

[Figure: ROC curves]

...to this.

Scikit-plot is the result of an unartistic data scientist's dreadful realization that visualization is one of the most crucial components in the data science process, not just a mere afterthought.

Gaining insights is simply a lot easier when you're looking at a colored heatmap of a confusion matrix complete with class labels rather than a single-line dump of numbers enclosed in brackets. Besides, if you ever need to present your results to someone (virtually any time anybody hires you to do data science), you show them visualizations, not a bunch of numbers in Excel.

That said, there are a number of visualizations that frequently pop up in machine learning. Scikit-plot is a humble attempt to give aesthetically challenged programmers (such as myself) the opportunity to generate quick and beautiful graphs and plots with as little boilerplate as possible.

Okay then, prove it. Show us an example.

Say we use Naive Bayes in multi-class classification and decide we want to visualize the results of a common classification metric, the Area under the Receiver Operating Characteristic curve. Since the ROC curve is only defined for binary classification, we want to show the ROC curve of each class as if it were the positive class. As an added bonus, let's show the micro-averaged and macro-averaged curves in the plot as well.

Let's use scikit-plot with the sample digits dataset from scikit-learn.

# The usual train-test split mumbo-jumbo
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
nb = GaussianNB()
nb.fit(X_train, y_train)
predicted_probas = nb.predict_proba(X_test)

# The magic happens here
import matplotlib.pyplot as plt
import scikitplot as skplt
skplt.metrics.plot_roc(y_test, predicted_probas)
plt.show()

[Figure: ROC curves]

Pretty.

And... that's it. Captured in that small example is the entire philosophy of Scikit-plot: single-line functions for detailed visualization. You simply browse the plots available in the documentation and call the function with the necessary arguments. Scikit-plot tries to stay out of your way as much as possible. No unnecessary bells and whistles. And when you do need the bells and whistles, each function offers a myriad of parameters for customizing various elements in your plots.

Finally, compare this with the non-scikit-plot way of plotting the multi-class ROC curve. Which one would you rather do?
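To make concrete what the multi-class ROC plot summarizes, here is a NumPy-only sketch of the one-vs-rest idea described earlier: each class in turn is treated as the positive class, the micro average pools every (label, score) pair into a single binary problem, and the macro average here is simply the mean of the per-class AUCs. This is an illustration, not scikit-plot's implementation; scikit-plot's macro-averaged curve may be computed differently (for example by averaging interpolated curves), so numbers can differ slightly. The helpers `roc_points`, `trapezoid_auc`, and `one_vs_rest_auc` are hypothetical, and every class needs at least one positive and one negative sample.

```python
import numpy as np


def roc_points(y_bin, scores):
    """Return (fpr, tpr) arrays for a binary problem where y_bin is 1 for the positive class."""
    order = np.argsort(-scores, kind="mergesort")
    y_sorted = y_bin[order]
    s_sorted = scores[order]
    # Last index of each run of tied scores: one ROC point per distinct threshold.
    cut = np.r_[np.flatnonzero(np.diff(s_sorted)), y_sorted.size - 1]
    tps = np.cumsum(y_sorted)[cut]   # true positives at each threshold
    fps = (cut + 1) - tps            # remaining predicted positives are false positives
    tpr = np.r_[0.0, tps / tps[-1]]
    fpr = np.r_[0.0, fps / fps[-1]]
    return fpr, tpr


def trapezoid_auc(x, y):
    """Area under a piecewise-linear curve using the trapezoid rule."""
    return float(np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2.0))


def one_vs_rest_auc(y_true, probas):
    """Per-class, micro-averaged and (simplified) macro-averaged ROC AUC."""
    classes = np.unique(y_true)  # column i of probas is assumed to belong to classes[i]
    y_bin = (y_true[:, None] == classes[None, :]).astype(int)
    per_class = {
        c.item(): trapezoid_auc(*roc_points(y_bin[:, i], probas[:, i]))
        for i, c in enumerate(classes)
    }
    # Micro average: pool every (label, score) pair into one binary problem.
    micro = trapezoid_auc(*roc_points(y_bin.ravel(), probas.ravel()))
    # Macro average (simplified): unweighted mean of the per-class AUCs.
    macro = float(np.mean(list(per_class.values())))
    return per_class, micro, macro


y_true = np.array([0, 1, 2, 0, 1, 2])
probas = np.full((6, 3), 0.1)
probas[np.arange(6), y_true] = 0.8  # a perfect (toy) classifier
per_class, micro, macro = one_vs_rest_auc(y_true, probas)
```

With a perfect toy classifier every AUC is 1.0; with identical scores for every class, all of them drop to 0.5.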

Maximum flexibility. Compatibility with non-scikit-learn objects.

Although Scikit-plot is loosely based around the scikit-learn interface, you don't actually need Scikit-learn objects to use the available functions. As long as you provide the functions what they're asking for, they'll happily draw the plots for you.

Here's a quick example to generate the precision-recall curves of a Keras classifier on a sample dataset.

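# Import what's needed for the Functions API
import matplotlib.pyplot as plt
import scikitplot as skplt

# This is a Keras classifier, assumed to be already compiled with a softmax output layer.
# X_train, y_train, X_test, y_test are assumed to be prepared; y_test should hold class labels.
# In Keras 2, the epoch count is `epochs` (formerly `nb_epoch`), and `predict` returns the
# softmax probabilities (recent tf.keras versions no longer provide `predict_proba`).
keras_clf.fit(X_train, y_train, batch_size=64, epochs=10, verbose=2)
probas = keras_clf.predict(X_test, batch_size=64)

# Now plot.
skplt.metrics.plot_precision_recall_curve(y_test, probas)
plt.show()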

[Figure: precision-recall curves]

You can see clearly here that skplt.metrics.plot_precision_recall_curve needs only the ground truth y-values and the predicted probabilities to generate the plot. This lets you use anything you want as the classifier, from Keras NNs to NLTK Naive Bayes to that groundbreaking classifier algorithm you just wrote.

The possibilities are endless.
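Because only the ground truth and an `(n_samples, n_classes)` probability array are needed, the classifier can be anything that produces such an array. As a minimal sketch, here is a hypothetical nearest-centroid classifier written in plain NumPy (not part of scikit-plot or scikit-learn). Its `predict_proba` columns follow the sorted order of `np.unique(y)`, which matches scikit-learn's convention for `classes_` and the order scikit-plot assumes when it pairs columns with classes.

```python
import numpy as np


class NearestCentroidProba:
    """Hypothetical toy classifier: softmax over negative squared distances to class centroids."""

    def fit(self, X, y):
        self.classes_ = np.unique(y)  # sorted, so column i belongs to classes_[i]
        self.centroids_ = np.stack([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def predict_proba(self, X):
        # Squared Euclidean distance from each sample to each centroid: (n_samples, n_classes).
        d2 = ((X[:, None, :] - self.centroids_[None, :, :]) ** 2).sum(axis=2)
        logits = -d2
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability before exp
        e = np.exp(logits)
        return e / e.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 1, (50, 2)), rng.normal(4, 1, (50, 2))])
y = np.repeat([0, 1], 50)
probas = NearestCentroidProba().fit(X, y).predict_proba(X)
```

The resulting `probas` could then be handed to `skplt.metrics.plot_precision_recall_curve(y, probas)` in the same way as the Keras probabilities.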

Installation

Installation is simple! First, make sure you have the dependencies Scikit-learn and Matplotlib installed.

Then just run:

pip install scikit-plot

Or if you want the latest development version, clone this repo and run

python setup.py install

at the root folder.

If using conda, you can install Scikit-plot by running:

conda install -c conda-forge scikit-plot

Documentation and Examples

Explore the full features of Scikit-plot.

You can find detailed documentation here.

Examples are found in the examples folder of this repo.

Contributing to Scikit-plot

Reporting a bug? Suggesting a feature? Want to add your own plot to the library? Visit our contributor guidelines.

Citing Scikit-plot

Are you using Scikit-plot in an academic paper? You should be! Reviewers love eye candy.

If so, please consider citing Scikit-plot with its DOI, 10.5281/zenodo.293191.

APA

Reiichiro Nakano. (2018). reiinakano/scikit-plot: 0.3.7 [Data set]. Zenodo. http://doi.org/10.5281/zenodo.293191

IEEE

[1] Reiichiro Nakano, “reiinakano/scikit-plot: 0.3.7”. Zenodo, 19-Feb-2017.

ACM

[1] Reiichiro Nakano. 2018. reiinakano/scikit-plot: 0.3.7. Zenodo.

Happy plotting!

Comments
  • Improve handling of unbalanced confusion matrices

    Here I have made a few changes that make it easier to plot confusion matrices where the true and predicted sets of labels are not the same. This is a case that can occur when doing something like applying "new" categories to a dataset with an older set of categories.

    The changes included are the following:

    Fix an issue where NaN values showed up when unbalanced confusion matrices were normalized: rows with no entries sum to zero, so normalizing each cell in those rows divided by zero.

    Add options to limit the labels displayed on the true and predicted axes, since in an unbalanced confusion matrix some labels can occur only among the true labels or only among the predicted labels.

    You can see the effect of the new options here:

    import numpy as np
    import matplotlib.pyplot as plt
    import scikitplot as sciplt
    
    y_true = np.array(["A", "A", "B", "B", "B", "C", "D"])
    y_pred = np.array(["A", "A", "Ba", "Bb", "Ba", "C", "D"])
    
    print(y_true.shape)
    print(y_pred.shape)
    
    true_labels = np.unique(y_true)
    pred_labels = np.unique(y_pred)
    
    labels = np.sort(np.unique(np.concatenate([true_labels, pred_labels])))
    
    true_label_indexes = np.where(np.isin(labels, true_labels))
    pred_label_indexes = np.where(np.isin(labels, pred_labels))
    
    sciplt.plotters.plot_confusion_matrix(y_true, y_pred, hide_zeros=True, normalize=True, true_label_indexes=true_label_indexes, pred_label_indexes=pred_label_indexes, labels=labels)
    plt.show()
    

    [Figure: resulting confusion matrix]
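The two fixes described in this PR can be sketched in plain NumPy: build the confusion matrix over the sorted union of true and predicted labels, normalize each row while leaving all-zero rows at zero instead of dividing by zero, then keep only the rows for labels that occur in `y_true` and the columns for labels that occur in `y_pred`. The helper names are hypothetical and this is not the code merged into scikit-plot; it reuses the `y_true`/`y_pred` arrays from the example above.

```python
import numpy as np


def confusion_matrix_union(y_true, y_pred):
    """Confusion matrix over the sorted union of labels (rows: true, columns: predicted)."""
    labels = np.union1d(y_true, y_pred)
    index = {lab: i for i, lab in enumerate(labels)}
    cm = np.zeros((labels.size, labels.size), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[index[t], index[p]] += 1
    return labels, cm


def normalize_rows(cm):
    """Row-normalize; rows that sum to zero stay zero instead of becoming NaN."""
    cm = cm.astype(float)
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)


y_true = np.array(["A", "A", "B", "B", "B", "C", "D"])
y_pred = np.array(["A", "A", "Ba", "Bb", "Ba", "C", "D"])
labels, cm = confusion_matrix_union(y_true, y_pred)
cm_norm = normalize_rows(cm)
rows = np.isin(labels, y_true)  # keep labels that occur as true labels
cols = np.isin(labels, y_pred)  # keep labels that occur as predictions
trimmed = cm_norm[np.ix_(rows, cols)]
```

Here `trimmed` has rows A, B, C, D and columns A, Ba, Bb, C, D, and the empty Ba/Bb rows of the full matrix contain zeros rather than NaN.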

    opened by ExcaliburZero 12
  • 0.2.3 to 0.2.6 update failed

    I've just tried to upgrade the package, but it gave the following error:

    Collecting scikit-plot
      Using cached scikit-plot-0.2.6.tar.gz
        Complete output from command python setup.py egg_info:
        Traceback (most recent call last):
          File "<string>", line 1, in <module>
          File "/tmp/pip-build-7wut1485/scikit-plot/setup.py", line 9, in <module>
            import scikitplot
          File "/tmp/pip-build-7wut1485/scikit-plot/scikitplot/__init__.py", line 5, in <module>
            from scikitplot.classifiers import classifier_factory
          File "/tmp/pip-build-7wut1485/scikit-plot/scikitplot/classifiers.py", line 5, in <module>
            import matplotlib.pyplot as plt
          File "/home/paulo/Programs/repos/pyenv/versions/3.6.1/envs/work-3.6.1/lib/python3.6/site-packages/matplotlib/pyplot.py", line 115, in <module>
            _backend_mod, new_figure_manager, draw_if_interactive, _show = pylab_setup()
          File "/home/paulo/Programs/repos/pyenv/versions/3.6.1/envs/work-3.6.1/lib/python3.6/site-packages/matplotlib/backends/__init__.py", line 32, in pylab_setup
            globals(),locals(),[backend_name],0)
          File "/home/paulo/Programs/repos/pyenv/versions/3.6.1/envs/work-3.6.1/lib/python3.6/site-packages/matplotlib/backends/backend_tkagg.py", line 6, in <module>
            from six.moves import tkinter as Tk
          File "/home/paulo/Programs/repos/pyenv/versions/3.6.1/envs/work-3.6.1/lib/python3.6/site-packages/six.py", line 92, in __get__
            result = self._resolve()
          File "/home/paulo/Programs/repos/pyenv/versions/3.6.1/envs/work-3.6.1/lib/python3.6/site-packages/six.py", line 115, in _resolve
            return _import_module(self.mod)
          File "/home/paulo/Programs/repos/pyenv/versions/3.6.1/envs/work-3.6.1/lib/python3.6/site-packages/six.py", line 82, in _import_module
            __import__(name)
          File "/home/paulo/Programs/repos/pyenv/versions/3.6.1/lib/python3.6/tkinter/__init__.py", line 36, in <module>
            import _tkinter # If this fails your Python may not be configured for Tk
        ModuleNotFoundError: No module named '_tkinter'
        
        ----------------------------------------
    Command "python setup.py egg_info" failed with error code 1 in /tmp/pip-build-7wut1485/scikit-plot/
    

    Unfortunately, I don't know how to debug this problem. If you need some info, please don't hesitate to ask!

    opened by paulochf 10
  • Adding a parameter to plot_confusion_matrix() to hide overlaid counts

    Hi @reiinakano,

    Thank you for this great repo! I am using plot_confusion_matrix() but my counts are quite large so the overlaid counts end up overlapping each other and result in a cluttered plot. I was wondering if I could submit a pull request to update this function to add a hide_counts parameter to give the option to not plot the counts? I've already forked and created a branch with the changes. Thank you!

    opened by echan5 7
  • Plot ONLY one class

    Hello, I have a precision-recall curve that I plot as follows:

    skplt.metrics.plot_precision_recall_curve(y_test, y_probas, curves=['each_class'])
    

    I have two classes in the data (one positive and one negative class, with labels 1 and -1 respectively). Question: how can I plot ONLY the positive class?

    Thank you
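One way to restrict the plot to the positive class is to pick the probability column that belongs to label `1`. Columns are assumed to follow the sorted order of `np.unique(y_test)` (so `-1` comes first and `1` second), which is how scikit-plot pairs columns with classes and how scikit-learn orders `predict_proba`. The v0.3.5 release notes list a `classes_to_plot` parameter for the newer `plot_precision_recall`, which may cover this directly; the sketch here instead computes the positive-class precision-recall points by hand. `precision_recall_points` and `average_precision` are hypothetical helpers, and the data are made up for illustration.

```python
import numpy as np


def precision_recall_points(y_bin, scores):
    """Precision and recall at each distinct score threshold (y_bin is 1 for the positive class)."""
    order = np.argsort(-scores, kind="mergesort")
    y_sorted = y_bin[order]
    s_sorted = scores[order]
    cut = np.r_[np.flatnonzero(np.diff(s_sorted)), y_sorted.size - 1]
    tps = np.cumsum(y_sorted)[cut]
    precision = tps / (cut + 1)  # cut + 1 samples are predicted positive at each threshold
    recall = tps / tps[-1]
    return precision, recall


def average_precision(precision, recall):
    """Step-wise area: precision at each threshold weighted by the recall it adds."""
    return float(np.sum(np.diff(np.r_[0.0, recall]) * precision))


y_test = np.array([1, -1, 1, 1, -1, -1])
y_probas = np.array([[0.2, 0.8], [0.7, 0.3], [0.4, 0.6],
                     [0.1, 0.9], [0.6, 0.4], [0.9, 0.1]])

classes = np.unique(y_test)                     # array([-1, 1])
pos_col = int(np.flatnonzero(classes == 1)[0])  # column of the positive class
precision, recall = precision_recall_points((y_test == 1).astype(int), y_probas[:, pos_col])
```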

    enhancement help wanted 
    opened by foo123 7
  • Plot precision-recall curve for support vector machine classifier

    Hello, I want to plot a precision-recall curve for SVC (support vector machine classifier), but the scikit-learn SVM classifier does not implement a predict_proba method by default. How can I do that in scikit-plot? As far as I can see in the documentation, it accepts prediction probabilities to plot the curve.

    Note that the scikit-learn documentation has an example of a precision-recall curve for SVC.

    Thank you, Nikos
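Precision-recall curves only need a score that ranks samples, so a true probability is not strictly required. Two options: scikit-learn's `SVC(probability=True)` enables `predict_proba` (fitted internally with Platt scaling and cross-validation, which makes training slower), or the output of `decision_function` can be reshaped into the two-column array scikit-plot expects. This sketch does the latter with a logistic squashing that keeps the ranking but is not a calibrated probability. For a binary `SVC`, positive decision values favour `clf.classes_[1]`, so that score goes in column 1. `decision_scores_to_probas` is a hypothetical helper and the scores are made up.

```python
import numpy as np


def decision_scores_to_probas(scores):
    """Map binary decision_function scores to a two-column (negative, positive) array."""
    s = np.asarray(scores, dtype=float)
    # Logistic sigmoid written via tanh, which avoids overflow for large |s|.
    p_pos = 0.5 * (1.0 + np.tanh(0.5 * s))
    return np.column_stack([1.0 - p_pos, p_pos])


scores = np.array([-2.3, 0.4, 1.7, -0.1, 3.0])  # e.g. from clf.decision_function(X_test)
probas = decision_scores_to_probas(scores)
```

The result has one row per sample and rows summing to 1, so it has the shape that the precision-recall plotting functions ask for.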

    opened by foo123 6
  • Add Jupyter notebook examples

    It would be nice to have Jupyter notebooks in the "examples" folder showing the different plots as used in a Jupyter notebook. It could contain the same exact code as the examples in the .py files, but adjusted for size (Jupyter notebook plots tend to come out much smaller).

    easy 
    opened by reiinakano 6
  • Update to plot_confusion_matrix (figsize argument and to work if Seaborn is used)

    Using the confusion matrix in a Jupyter notebook returns a plot that is quite small. If Seaborn is also used, some values in the plot are hard to read (white text on white lines).

    I have added a figsize argument to plot_confusion_matrix and changed the way values are displayed (now with a neutral background box). For larger plots, all text elements scale with the figsize argument.

    opened by frankherfert 6
  • Adding argument to allow the user to specify which roc_curve are plotted

    You tagged issue #14 as help wanted, so I thought I'd pitch in. Feel free to edit if it doesn't match your style.

    I added a little bit of code to allow the user to pass a list to the roc curve plotting functions to allow them to suppress/show each of the three types of curves: class-specific curves, micro averages, and macro averages.
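A minimal sketch of how such a list could be validated and turned into on/off switches for the three curve types. The helper `select_curves` and its error messages are hypothetical; the names `micro`, `macro`, and `each_class` are assumed to be the values accepted by the `curves` argument, as in the `curves=['each_class']` call elsewhere in this thread.

```python
VALID_CURVES = ("micro", "macro", "each_class")


def select_curves(curves=VALID_CURVES):
    """Hypothetical helper: map a `curves` list to on/off flags for each curve type."""
    curves = list(curves)
    if not curves:
        raise ValueError("curves must name at least one of " + ", ".join(VALID_CURVES))
    unknown = sorted(set(curves) - set(VALID_CURVES))
    if unknown:
        raise ValueError(f"unknown curve type(s) {unknown}; expected {VALID_CURVES}")
    return {name: name in curves for name in VALID_CURVES}


flags = select_curves(["each_class", "micro"])
```

Rejecting unknown names early keeps a typo such as `"macros"` from silently hiding a curve.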

    opened by doug-friedman 5
  • Error installing No module named sklearn.metrics

    Hi there, I am getting an error when installing it:

    pip install scikit-plot
    Collecting scikit-plot
      Downloading scikit-plot-0.2.1.tar.gz
        Complete output from command python setup.py egg_info:
        Traceback (most recent call last):
          File "<string>", line 1, in <module>
          File "c:\users\arthur\.babun\cygwin\tmp\pip-build-yrgynz\scikit-plot\setup.py", line 9, in <module>
            import scikitplot
          File "c:\users\arthur\.babun\cygwin\tmp\pip-build-yrgynz\scikit-plot\scikitplot\__init__.py", line 5, in <module>
            from scikitplot.classifiers import classifier_factory
          File "c:\users\arthur\.babun\cygwin\tmp\pip-build-yrgynz\scikit-plot\scikitplot\classifiers.py", line 7, in <module>
            from scikitplot import plotters
          File "c:\users\arthur\.babun\cygwin\tmp\pip-build-yrgynz\scikit-plot\scikitplot\plotters.py", line 9, in <module>
            from sklearn.metrics import confusion_matrix
        ImportError: No module named sklearn.metrics
    
        ----------------------------------------
    Command "python setup.py egg_info" failed with error code 1 in c:\users\arthur\.babun\cygwin\tmp\pip-build-yrgynz\scikit-plot\
    
    opened by ArthurZ 5
  • Throws error "IndexError: too many indices for array" when trying to plot ROC for binary classification

    For binary classification, when I pass numpy arrays holding the test labels and test probabilities, it throws the following error:

    
    y_true = np.array(ytest)
    y_probas = np.array(p_test)
    skplt.metrics.plot_roc_curve(y_true,y_probas)
    plt.show()
    
    IndexError                                Traceback (most recent call last)
    <ipython-input-49-1b02f082006a> in <module>()
    ----> 1 skplt.metrics.plot_roc_curve(y_true,y_probas)
          2 plt.show()
    
    
    /Users/tarun/anaconda/envs/gl-env/lib/python2.7/site-packages/scikitplot/metrics.pyc in plot_roc_curve(y_true, y_probas, title, curves, ax, figsize, cmap, title_fontsize, text_fontsize)
        247     roc_auc = dict()
        248     for i in range(len(classes)):
    --> 249         fpr[i], tpr[i], _ = roc_curve(y_true, probas[:, i],
        250                                       pos_label=classes[i])
        251         roc_auc[i] = auc(fpr[i], tpr[i])
    
    IndexError: too many indices for array
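The traceback shows `plot_roc_curve` indexing `probas[:, i]`, so it expects a 2-D array of shape `(n_samples, n_classes)`. A likely cause (an assumption, since `p_test` is not shown) is that `p_test` holds only the positive-class probability, i.e. a 1-D array. A minimal fix is to stack the complementary column in front, assuming labels 0 and 1 so that `np.unique(y_true)` puts the negative class first. `as_two_column_probas` is a hypothetical helper and the values are made up.

```python
import numpy as np


def as_two_column_probas(p_positive):
    """Turn positive-class probabilities into an (n_samples, 2) array; 2-D input is returned as-is."""
    p = np.asarray(p_positive, dtype=float)
    if p.ndim == 2:
        return p  # already (n_samples, n_classes)
    return np.column_stack([1.0 - p, p])  # column 0: negative class, column 1: positive class


p_test = [0.9, 0.2, 0.65, 0.1]
y_probas = as_two_column_probas(p_test)
```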
    
    opened by TarunTater 4
  • Class mismatch in skplt.plot_confusion_matrix when test has fewer classes than training

    Hello, I have an issue when trying to plot a confusion matrix with fewer classes in my test set than in training. The class with 12,000+ occurrences in my sample should be labelled 'O'. Is it possible to get around this, or to pass the label set manually as an input?

    It's not a big issue, but it would be nice if we could fix it. Thanks for your help.
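A common workaround is to fix the label set explicitly, so classes that are absent from the test set still get their own (empty) row and column in a stable order. scikit-learn's `confusion_matrix` accepts a `labels` argument for this, and the PR earlier in this thread passes `labels=` to `plot_confusion_matrix`. The NumPy sketch below shows the idea; `confusion_matrix_with_labels` is a hypothetical helper and the label names are made up.

```python
import numpy as np


def confusion_matrix_with_labels(y_true, y_pred, labels):
    """Rows are true labels and columns are predicted labels, in the order given by `labels`."""
    position = {lab: i for i, lab in enumerate(labels)}
    rows = [position[t] for t in y_true]
    cols = [position[p] for p in y_pred]
    cm = np.zeros((len(labels), len(labels)), dtype=int)
    np.add.at(cm, (rows, cols), 1)  # unbuffered add, so repeated (row, col) pairs all count
    return cm


train_labels = ["O", "PER", "LOC"]  # hypothetical full label set seen in training
y_true = ["PER", "PER", "LOC"]      # "O" never occurs as a true label in this test set
y_pred = ["PER", "O", "LOC"]
cm = confusion_matrix_with_labels(y_true, y_pred, train_labels)
```

The "O" row is all zeros but stays in place, so the remaining rows keep their intended labels.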

    opened by ArmandGiraud 4
  • Regarding the scikit-plot.metrics.plot_roc function

    In your code, I noticed that if we pass classes by their actual names instead of (0, 1, 2, ...), in an order such as (c, b, a), then classes = np.unique(y_true) sorts the classes alphabetically, which changes the position of the classes relative to the order the model was trained on:

    fpr_dict[i], tpr_dict[i], _ = roc_curve(y_true, probas[:, i], pos_label=classes[i])

    So if you could add a class_labels parameter to the function,

    def plot_roc_multi(y_true, y_probas,class_labels, title='ROC Curves', plot_micro=True, plot_macro=True, classes_to_plot=None, ax=None, figsize=None, cmap='nipy_spectral', title_fontsize="large", text_fontsize="medium"):

    where class_labels is an array such as [a, b, c], I think it would be much easier.
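For scikit-learn estimators, `predict_proba` columns already follow `clf.classes_`, which is sorted, so they line up with `np.unique(y_true)` as long as every class appears in `y_true`. For a model that orders its outputs differently, one workaround (an assumption, not a scikit-plot feature) is to permute the probability columns into sorted order before plotting. `reorder_columns` is a hypothetical helper and the values are made up.

```python
import numpy as np


def reorder_columns(probas, model_classes, target_classes):
    """Permute probability columns from the model's class order into target_classes order."""
    col_of = {c: i for i, c in enumerate(model_classes)}
    return probas[:, [col_of[c] for c in target_classes]]


model_classes = ["c", "b", "a"]  # hypothetical column order used by the model
probas = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.3, 0.6]])
sorted_probas = reorder_columns(probas, model_classes, np.unique(model_classes))
```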

    opened by Akshay1-6180 1
  • add class prediction error plot

    This is just a heads-up for programmers that this is a dead library and its functionality may break at any time. There has not been a single update since 2018, and most probably there will be no bug fixes or new features.

    Either implement your own plotting functions or look at other, actively maintained modules such as yellowbrick.

    scikit-plot is dead, and it might break at any time if you use it in production.

    opened by bhishanpdl 0
  • Adding class_names option to gain and lift plots

    I'd like to be able to set the names of the classes used in the legend in the plot_lift_curve() and plot_cumulative_gain() functions.

    I've made a pull request (https://github.com/reiinakano/scikit-plot/pull/109) that adds an optional arg class_names to each of these functions to accomplish this.

    This differs from issue https://github.com/reiinakano/scikit-plot/issues/78 in that I am trying to change the labels in the plot's legend, not omit one of the classes from the plot.
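For readers unfamiliar with these two plots, here is a NumPy sketch of the quantities they show for a single class: sort samples by predicted score, best first; the cumulative gain at a given fraction of the population is the share of all positives captured so far, and the lift is that gain divided by the fraction. The helper names are hypothetical, and tie handling may differ from scikit-plot's own implementation.

```python
import numpy as np


def cumulative_gain(y_true, scores, pos_label=1):
    """Fraction of samples examined vs. fraction of all positives found, best scores first."""
    order = np.argsort(-np.asarray(scores, dtype=float), kind="mergesort")
    hits = (np.asarray(y_true)[order] == pos_label).astype(float)
    gains = np.r_[0.0, np.cumsum(hits) / hits.sum()]
    percentages = np.arange(hits.size + 1) / hits.size
    return percentages, gains


def lift(percentages, gains):
    """Gain divided by the fraction examined (the 0% point is skipped to avoid 0/0)."""
    return gains[1:] / percentages[1:]


y_true = np.array([1, 0, 1, 0])
scores = np.array([0.9, 0.8, 0.7, 0.1])
percentages, gains = cumulative_gain(y_true, scores)
lifts = lift(percentages, gains)
```

A lift of 2 at the first quarter means the top-scored 25% of samples contain twice as many positives as a random 25% would.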

    opened by MichaelFishmanOD 0
Releases(v0.3.7)
  • v0.3.7(Aug 19, 2018)

  • v0.3.5(May 12, 2018)

    New features:

    • plot_precision_recall_curve and plot_roc_curve have been deprecated for plot_precision_recall and plot_roc, respectively. The major difference is the deletion of the curves parameter and the use of plot_macro, plot_micro, and classes_to_plot to choose which curves should be plotted. Thanks to @lugq1990 for this change.
  • v0.3.4(Feb 5, 2018)

  • v0.3.3(Oct 26, 2017)

  • v0.3.2(Oct 25, 2017)

    New Features

    • Gain Chart and Lift Chart added to scikitplot.metrics module #71
    • Updated Jupyter notebook examples for v0.3.x by @ljvmiranda921 #69

    Bugfix

    • Changed deprecated spectral colormap to nipy_spectral by @emredjan #66
  • v0.3.1(Sep 17, 2017)

  • v0.3.0(Sep 13, 2017)

    New features:

    • plot_learning_curve has new parameter scoring to allow custom scoring functions. By @jengelman
    • New plotting function plot_calibration_curves

    Deprecations

    • The Factory API has been deprecated and will be removed in v0.4.0
    • scikitplot.plotters has been deprecated and the functions in the Functions API have been distributed to various new modules. See documentation for more details.
  • v0.2.8(Sep 8, 2017)

    Features

    • New option hide_zeros for plot_confusion_matrix by @ExcaliburZero. #39
    • New option to plot only certain labels in plot_confusion_matrix by @ExcaliburZero. #41
    • New options to set colormaps for plot_pca_2d_projection, plot_silhouette, plot_precision_recall_curve, plot_roc_curve, and plot_confusion_matrix. #50

    Bugfix:

    • Fixed bug with nan values in confusion matrices by @ExcaliburZero (#42)
  • v0.2.7(Jul 9, 2017)

  • v0.2.6(May 17, 2017)

  • v0.2.5(Apr 30, 2017)

  • v0.2.4(Apr 25, 2017)

  • v0.2.3(Mar 19, 2017)

    New features:

    • plot_precision_recall_curve and plot_ks_statistic now have a new curves argument that allows the user to choose which curves should be plotted. Thanks to @doug-friedman for this PR.
    • Jupyter notebook examples are now available thanks to @lstmemery
  • v0.2.2(Feb 26, 2017)

    New features:

    • plot_pca_2d_projection function
    • plot_pca_component_variance function
    • plots now have figsize, title_fontsize, and text_fontsize parameters that let the user customize the size of the plot. This is particularly crucial for Jupyter notebook users, where the default settings come out too small. Thanks to @frankherfert for this idea.
  • v0.2.1(Feb 19, 2017)

  • v0.2.0(Feb 18, 2017)

  • v0.1.0(Feb 17, 2017)

Owner
Reiichiro Nakano
I like working on awesome things with awesome people!