{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Comparison between grid search and successive halving\n\nThis example compares the parameter search performed by\n:class:`~sklearn.model_selection.HalvingGridSearchCV` and\n:class:`~sklearn.model_selection.GridSearchCV`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nfrom time import time\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\n\nfrom sklearn import datasets\nfrom sklearn.experimental import enable_halving_search_cv # noqa\nfrom sklearn.model_selection import GridSearchCV, HalvingGridSearchCV\nfrom sklearn.svm import SVC" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We first define the parameter space for an :class:`~sklearn.svm.SVC`\nestimator, and compute the time required to train a\n:class:`~sklearn.model_selection.HalvingGridSearchCV` instance, as well as a\n:class:`~sklearn.model_selection.GridSearchCV` instance.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "rng = np.random.RandomState(0)\nX, y = datasets.make_classification(n_samples=1000, random_state=rng)\n\ngammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]\nCs = [1, 10, 100, 1e3, 1e4, 1e5]\nparam_grid = {\"gamma\": gammas, \"C\": Cs}\n\nclf = SVC(random_state=rng)\n\ntic = time()\ngsh = HalvingGridSearchCV(\n estimator=clf, param_grid=param_grid, factor=2, random_state=rng\n)\ngsh.fit(X, y)\ngsh_time = time() - tic\n\ntic = time()\ngs = GridSearchCV(estimator=clf, param_grid=param_grid)\ngs.fit(X, y)\ngs_time = time() - tic" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now plot heatmaps for both search estimators.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ 
"def make_heatmap(ax, gs, is_sh=False, make_cbar=False):\n    \"\"\"Helper to make a heatmap.\"\"\"\n    results = pd.DataFrame(gs.cv_results_)\n    results[[\"param_C\", \"param_gamma\"]] = results[[\"param_C\", \"param_gamma\"]].astype(\n        np.float64\n    )\n    if is_sh:\n        # SH dataframe: get mean_test_score values for the highest iter\n        scores_matrix = results.sort_values(\"iter\").pivot_table(\n            index=\"param_gamma\",\n            columns=\"param_C\",\n            values=\"mean_test_score\",\n            aggfunc=\"last\",\n        )\n    else:\n        scores_matrix = results.pivot(\n            index=\"param_gamma\", columns=\"param_C\", values=\"mean_test_score\"\n        )\n\n    im = ax.imshow(scores_matrix)\n\n    ax.set_xticks(np.arange(len(Cs)))\n    ax.set_xticklabels([\"{:.0E}\".format(x) for x in Cs])\n    ax.set_xlabel(\"C\", fontsize=15)\n\n    ax.set_yticks(np.arange(len(gammas)))\n    ax.set_yticklabels([\"{:.0E}\".format(x) for x in gammas])\n    ax.set_ylabel(\"gamma\", fontsize=15)\n\n    # Rotate the tick labels and set their alignment.\n    plt.setp(ax.get_xticklabels(), rotation=45, ha=\"right\", rotation_mode=\"anchor\")\n\n    if is_sh:\n        iterations = results.pivot_table(\n            index=\"param_gamma\", columns=\"param_C\", values=\"iter\", aggfunc=\"max\"\n        ).values\n        for i in range(len(gammas)):\n            for j in range(len(Cs)):\n                ax.text(\n                    j,\n                    i,\n                    iterations[i, j],\n                    ha=\"center\",\n                    va=\"center\",\n                    color=\"w\",\n                    fontsize=20,\n                )\n\n    if make_cbar:\n        fig.subplots_adjust(right=0.8)\n        cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])\n        fig.colorbar(im, cax=cbar_ax)\n        cbar_ax.set_ylabel(\"mean_test_score\", rotation=-90, va=\"bottom\", fontsize=15)\n\n\nfig, axes = plt.subplots(ncols=2, sharey=True)\nax1, ax2 = axes\n\nmake_heatmap(ax1, gsh, is_sh=True)\nmake_heatmap(ax2, gs, make_cbar=True)\n\nax1.set_title(\"Successive Halving\\ntime = {:.3f}s\".format(gsh_time), fontsize=15)\nax2.set_title(\"GridSearch\\ntime = {:.3f}s\".format(gs_time), fontsize=15)\n\nplt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The heatmaps show the mean test 
score of the parameter combinations for an\n:class:`~sklearn.svm.SVC` instance. The heatmap for\n:class:`~sklearn.model_selection.HalvingGridSearchCV` also shows the\niteration at which each combination was last used. The combinations marked\nas ``0`` were only evaluated at the first iteration, while the ones marked\n``5`` are the parameter combinations that are considered the best ones.\n\nWe can see that the :class:`~sklearn.model_selection.HalvingGridSearchCV`\nclass is able to find parameter combinations that are just as accurate as\nthose found by :class:`~sklearn.model_selection.GridSearchCV`, in much less\ntime (compare the fit times shown in the plot titles).\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }