{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Visualizations with Display Objects\n\n.. currentmodule:: sklearn.metrics\n\nIn this example, we will construct display objects,\n:class:`ConfusionMatrixDisplay`, :class:`RocCurveDisplay`, and\n:class:`PrecisionRecallDisplay` directly from their respective metrics. This\nis an alternative to using their corresponding plot functions when\na model's predictions are already computed or expensive to compute. Note that\nthis is advanced usage, and in general we recommend using their respective\nplot functions.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load data and train model\nFor this example, we load a blood transfusion service center dataset from\n[OpenML](https://www.openml.org/d/1464). This is a binary classification\nproblem where the target is whether an individual donated blood. The data is\nthen split into training and test sets, and a logistic regression model is\nfitted on the training set.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.datasets import fetch_openml\nfrom sklearn.linear_model import LogisticRegression\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.pipeline import make_pipeline\nfrom sklearn.preprocessing import StandardScaler\n\nX, y = fetch_openml(data_id=1464, return_X_y=True)\nX_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)\n\nclf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0))\nclf.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create :class:`ConfusionMatrixDisplay`\nWith the fitted model, we compute the predictions of the model on the test\ndataset. 
These predictions are used to compute the confusion matrix, which\nis plotted with :class:`ConfusionMatrixDisplay`.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix\n\ny_pred = clf.predict(X_test)\ncm = confusion_matrix(y_test, y_pred)\n\ncm_display = ConfusionMatrixDisplay(cm).plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create :class:`RocCurveDisplay`\nThe ROC curve requires either the probabilities or the non-thresholded\ndecision values from the estimator. Since the logistic regression provides\na decision function, we will use it to plot the ROC curve:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.metrics import RocCurveDisplay, roc_curve\n\ny_score = clf.decision_function(X_test)\n\nfpr, tpr, _ = roc_curve(y_test, y_score, pos_label=clf.classes_[1])\nroc_display = RocCurveDisplay(fpr=fpr, tpr=tpr).plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create :class:`PrecisionRecallDisplay`\nSimilarly, the precision-recall curve can be plotted using `y_score` from\nthe previous section.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve\n\nprec, recall, _ = precision_recall_curve(y_test, y_score, pos_label=clf.classes_[1])\npr_display = PrecisionRecallDisplay(precision=prec, recall=recall).plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Combining the display objects into a single plot\nThe display objects store the computed values that were passed as arguments.\nThis allows for the visualizations to be easily combined using matplotlib's\nAPI. 
In the following example, we place the displays next to each other in a\nrow.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))\n\nroc_display.plot(ax=ax1)\npr_display.plot(ax=ax2)\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }