{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Balance model complexity and cross-validated score\n",
    "\n",
    "This example balances model complexity and cross-validated score by\n",
    "finding a decent accuracy within 1 standard deviation of the best accuracy\n",
    "score while minimising the number of PCA components [1].\n",
    "\n",
    "The figure shows the trade-off between cross-validated score and the number\n",
    "of PCA components. The balanced case is when `n_components=10` and\n",
    "accuracy=0.88, which falls into the range within 1 standard deviation of the\n",
    "best accuracy score.\n",
    "\n",
    "[1] Hastie, T., Tibshirani, R., & Friedman, J. (2001). Model Assessment and\n",
    "Selection. *The Elements of Statistical Learning* (pp. 219-260). New York,\n",
    "NY, USA: Springer New York Inc.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Authors: The scikit-learn developers\n",
    "# SPDX-License-Identifier: BSD-3-Clause\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.svm import LinearSVC"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Selection rule\n",
    "\n",
    "`lower_bound` computes the threshold: the best mean test score minus its\n",
    "standard deviation. `best_low_complexity` returns the index of a candidate\n",
    "with the fewest PCA components whose mean test score reaches that threshold.\n",
    "It is passed to `GridSearchCV` as `refit`, so its return value becomes\n",
    "`best_index_`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def lower_bound(cv_results):\n",
    "    \"\"\"\n",
    "    Calculate the lower bound within 1 standard deviation\n",
    "    of the best `mean_test_score`.\n",
    "\n",
    "    Candidates with a NaN mean score (e.g. fits that failed under\n",
    "    ``error_score=np.nan``) are ignored when locating the best score.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    cv_results : dict of numpy(masked) ndarrays\n",
    "        See attribute cv_results_ of `GridSearchCV`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        Lower bound within 1 standard deviation of the\n",
    "        best `mean_test_score`.\n",
    "    \"\"\"\n",
    "    # nanargmax rather than argmax: argmax would pick a NaN score as the\n",
    "    # \"best\", making the threshold NaN so that no candidate could pass it.\n",
    "    best_score_idx = np.nanargmax(cv_results[\"mean_test_score\"])\n",
    "\n",
    "    return (\n",
    "        cv_results[\"mean_test_score\"][best_score_idx]\n",
    "        - cv_results[\"std_test_score\"][best_score_idx]\n",
    "    )\n",
    "\n",
    "\n",
    "def best_low_complexity(cv_results):\n",
    "    \"\"\"\n",
    "    Balance model complexity with cross-validated score.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    cv_results : dict of numpy(masked) ndarrays\n",
    "        See attribute cv_results_ of `GridSearchCV`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    int\n",
    "        Index of a model that has the fewest PCA components\n",
    "        while having its mean test score within 1 standard deviation of\n",
    "        the best `mean_test_score`.\n",
    "    \"\"\"\n",
    "    threshold = lower_bound(cv_results)\n",
    "    # NaN scores compare False, so failed fits are never candidates.\n",
    "    candidate_idx = np.flatnonzero(cv_results[\"mean_test_score\"] >= threshold)\n",
    "    best_idx = candidate_idx[\n",
    "        cv_results[\"param_reduce_dim__n_components\"][candidate_idx].argmin()\n",
    "    ]\n",
    "    return best_idx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Grid search over the number of PCA components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "pipe = Pipeline(\n",
    "    [\n",
    "        (\"reduce_dim\", PCA(random_state=42)),\n",
    "        (\"classify\", LinearSVC(random_state=42, C=0.01)),\n",
    "    ]\n",
    ")\n",
    "\n",
    "param_grid = {\"reduce_dim__n_components\": [6, 8, 10, 12, 14]}\n",
    "\n",
    "grid = GridSearchCV(\n",
    "    pipe,\n",
    "    cv=10,\n",
    "    n_jobs=1,\n",
    "    param_grid=param_grid,\n",
    "    scoring=\"accuracy\",\n",
    "    refit=best_low_complexity,\n",
    ")\n",
    "X, y = load_digits(return_X_y=True)\n",
    "grid.fit(X, y);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot the trade-off and report the selected model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "n_components = grid.cv_results_[\"param_reduce_dim__n_components\"]\n",
    "test_scores = grid.cv_results_[\"mean_test_score\"]\n",
    "lower = lower_bound(grid.cv_results_)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.bar(n_components, test_scores, width=1.3, color=\"b\")\n",
    "ax.axhline(np.nanmax(test_scores), linestyle=\"--\", color=\"y\", label=\"Best score\")\n",
    "ax.axhline(lower, linestyle=\"--\", color=\".5\", label=\"Best score - 1 std\")\n",
    "\n",
    "ax.set_title(\"Balance model complexity and cross-validated score\")\n",
    "ax.set_xlabel(\"Number of PCA components used\")\n",
    "ax.set_ylabel(\"Digit classification accuracy\")\n",
    "ax.set_xticks(n_components.tolist())\n",
    "ax.set_ylim((0, 1.0))\n",
    "ax.legend(loc=\"upper left\")\n",
    "\n",
    "best_index_ = grid.best_index_\n",
    "\n",
    "print(\"The best_index_ is %d\" % best_index_)\n",
    "print(\"The n_components selected is %d\" % n_components[best_index_])\n",
    "print(\n",
    "    \"The corresponding accuracy score is %.2f\"\n",
    "    % grid.cv_results_[\"mean_test_score\"][best_index_]\n",
    ")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}