{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Demonstration of multi-metric evaluation on cross_val_score and GridSearchCV\n",
    "\n",
    "Multiple metric parameter search can be done by setting the ``scoring``\n",
    "parameter to a list of metric scorer names or a dict mapping the scorer names\n",
    "to the scorer callables.\n",
    "\n",
    "The scores of all the scorers are available in the ``cv_results_`` dict at keys\n",
    "ending in ``'_<scorer_name>'`` (``'mean_test_precision'``,\n",
    "``'rank_test_precision'``, etc...)\n",
    "\n",
    "The ``best_estimator_``, ``best_index_``, ``best_score_`` and ``best_params_``\n",
    "correspond to the scorer (key) that is set to the ``refit`` attribute.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Authors: The scikit-learn developers\n",
    "# SPDX-License-Identifier: BSD-3-Clause\n",
    "\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from sklearn.datasets import make_hastie_10_2\n",
    "from sklearn.metrics import accuracy_score, make_scorer\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.tree import DecisionTreeClassifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running ``GridSearchCV`` using multiple evaluation metrics\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "X, y = make_hastie_10_2(n_samples=8000, random_state=42)\n",
    "\n",
    "# The scorers can be either one of the predefined metric strings or a scorer\n",
    "# callable, like the one returned by make_scorer\n",
    "scoring = {\"AUC\": \"roc_auc\", \"Accuracy\": make_scorer(accuracy_score)}\n",
    "\n",
    "# Setting refit='AUC', refits an estimator on the whole dataset with the\n",
    "# parameter setting that has the best cross-validated AUC score.\n",
    "# That estimator is made available at ``gs.best_estimator_`` along with\n",
    "# parameters like ``gs.best_score_``, ``gs.best_params_`` and\n",
    "# ``gs.best_index_``\n",
    "gs = GridSearchCV(\n",
    "    DecisionTreeClassifier(random_state=42),\n",
    "    param_grid={\"min_samples_split\": range(2, 403, 20)},\n",
    "    scoring=scoring,\n",
    "    refit=\"AUC\",\n",
    "    n_jobs=2,\n",
    "    return_train_score=True,\n",
    ")\n",
    "gs.fit(X, y)\n",
    "results = gs.cv_results_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting the result\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(13, 13))\n",
    "ax.set_title(\"GridSearchCV evaluating using multiple scorers simultaneously\", fontsize=16)\n",
    "\n",
    "ax.set_xlabel(\"min_samples_split\")\n",
    "ax.set_ylabel(\"Score\")\n",
    "\n",
    "ax.set_xlim(0, 402)\n",
    "ax.set_ylim(0.73, 1)\n",
    "\n",
    "# Get the regular numpy array from the MaskedArray\n",
    "X_axis = np.array(results[\"param_min_samples_split\"].data, dtype=float)\n",
    "\n",
    "# One color per scorer; scorers are visited in sorted-name order.\n",
    "for scorer, color in zip(sorted(scoring), [\"g\", \"k\"]):\n",
    "    # Train curves are dashed, test curves are solid.\n",
    "    for sample, style in ((\"train\", \"--\"), (\"test\", \"-\")):\n",
    "        sample_score_mean = results[f\"mean_{sample}_{scorer}\"]\n",
    "        sample_score_std = results[f\"std_{sample}_{scorer}\"]\n",
    "        # Shade +/- one standard deviation around the mean; alpha=0 makes the\n",
    "        # band invisible for the train scores.\n",
    "        ax.fill_between(\n",
    "            X_axis,\n",
    "            sample_score_mean - sample_score_std,\n",
    "            sample_score_mean + sample_score_std,\n",
    "            alpha=0.1 if sample == \"test\" else 0,\n",
    "            color=color,\n",
    "        )\n",
    "        ax.plot(\n",
    "            X_axis,\n",
    "            sample_score_mean,\n",
    "            style,\n",
    "            color=color,\n",
    "            alpha=1 if sample == \"test\" else 0.7,\n",
    "            label=f\"{scorer} ({sample})\",\n",
    "        )\n",
    "\n",
    "    # First parameter setting ranked 1 on the test score for this scorer.\n",
    "    best_index = np.nonzero(results[f\"rank_test_{scorer}\"] == 1)[0][0]\n",
    "    best_score = results[f\"mean_test_{scorer}\"][best_index]\n",
    "\n",
    "    # Plot a dotted vertical line at the best score for that scorer marked by x\n",
    "    ax.plot(\n",
    "        [X_axis[best_index]] * 2,\n",
    "        [0, best_score],\n",
    "        linestyle=\"-.\",\n",
    "        color=color,\n",
    "        marker=\"x\",\n",
    "        markeredgewidth=3,\n",
    "        ms=8,\n",
    "    )\n",
    "\n",
    "    # Annotate the best score for that scorer\n",
    "    ax.annotate(f\"{best_score:0.2f}\", (X_axis[best_index], best_score + 0.005))\n",
    "\n",
    "ax.legend(loc=\"best\")\n",
    "ax.grid(False)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}