{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# ROC Curve with Visualization API\nScikit-learn defines a simple API for creating visualizations for machine\nlearning. The key feature of this API is that it allows quick plotting and\nvisual adjustments without recalculation. In this example, we demonstrate\nhow to use the visualization API by comparing ROC curves.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Data and Train an SVC\nFirst, we load the wine dataset and convert it to a binary classification\nproblem (class 2 versus the rest). Then, we train a support vector classifier\non the training split.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n\nfrom sklearn.datasets import load_wine\nfrom sklearn.ensemble import RandomForestClassifier\nfrom sklearn.metrics import RocCurveDisplay\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.svm import SVC\n\nX, y = load_wine(return_X_y=True)\n# Binarize the target: True for wine class 2, False otherwise\ny = y == 2\n\n# Hold out a test split for evaluating the ROC curve\nX_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\nsvc = SVC(random_state=42)\nsvc.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting the ROC Curve\nNext, we plot the ROC curve with a single call to\n`sklearn.metrics.RocCurveDisplay.from_estimator`. 
The returned\n`svc_disp` object allows us to continue using the already computed ROC curve\nfor the SVC in future plots.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "svc_disp = RocCurveDisplay.from_estimator(svc, X_test, y_test)\nplt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training a Random Forest and Plotting the ROC Curve\nWe train a random forest classifier and create a plot comparing it to the SVC\nROC curve. Notice how `svc_disp` uses\n`sklearn.metrics.RocCurveDisplay.plot` to plot the SVC ROC curve\nwithout recomputing the values of the curve itself. Furthermore, we\npass `alpha=0.8` to both plotting calls to adjust the opacity of the\ncurves.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "rfc = RandomForestClassifier(n_estimators=10, random_state=42)\nrfc.fit(X_train, y_train)\n# Draw both curves on the same axes\nax = plt.gca()\nrfc_disp = RocCurveDisplay.from_estimator(rfc, X_test, y_test, ax=ax, alpha=0.8)\n# Reuse the values stored in svc_disp; the SVC curve is not recomputed\nsvc_disp.plot(ax=ax, alpha=0.8)\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }