.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "auto_examples/model_selection/plot_successive_halving_heatmap.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. or to run this example in your browser via JupyterLite or Binder .. rst-class:: sphx-glr-example-title .. _sphx_glr_auto_examples_model_selection_plot_successive_halving_heatmap.py: Comparison between grid search and successive halving ===================================================== This example compares the parameter search performed by :class:`~sklearn.model_selection.HalvingGridSearchCV` and :class:`~sklearn.model_selection.GridSearchCV`. .. GENERATED FROM PYTHON SOURCE LINES 10-25 .. code-block:: Python # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause from time import time import matplotlib.pyplot as plt import numpy as np import pandas as pd from sklearn import datasets from sklearn.experimental import enable_halving_search_cv # noqa from sklearn.model_selection import GridSearchCV, HalvingGridSearchCV from sklearn.svm import SVC .. GENERATED FROM PYTHON SOURCE LINES 26-30 We first define the parameter space for an :class:`~sklearn.svm.SVC` estimator, and compute the time required to train a :class:`~sklearn.model_selection.HalvingGridSearchCV` instance, as well as a :class:`~sklearn.model_selection.GridSearchCV` instance. .. GENERATED FROM PYTHON SOURCE LINES 30-52 .. code-block:: Python rng = np.random.RandomState(0) X, y = datasets.make_classification(n_samples=1000, random_state=rng) gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7] Cs = [1, 10, 100, 1e3, 1e4, 1e5] param_grid = {"gamma": gammas, "C": Cs} clf = SVC(random_state=rng) tic = time() gsh = HalvingGridSearchCV( estimator=clf, param_grid=param_grid, factor=2, random_state=rng ) gsh.fit(X, y) gsh_time = time() - tic tic = time() gs = GridSearchCV(estimator=clf, param_grid=param_grid) gs.fit(X, y) gs_time = time() - tic .. GENERATED FROM PYTHON SOURCE LINES 53-54 We now plot heatmaps for both search estimators. .. GENERATED FROM PYTHON SOURCE LINES 54-122 .. code-block:: Python def make_heatmap(ax, gs, is_sh=False, make_cbar=False): """Helper to make a heatmap.""" results = pd.DataFrame(gs.cv_results_) results[["param_C", "param_gamma"]] = results[["param_C", "param_gamma"]].astype( np.float64 ) if is_sh: # SH dataframe: get mean_test_score values for the highest iter scores_matrix = results.sort_values("iter").pivot_table( index="param_gamma", columns="param_C", values="mean_test_score", aggfunc="last", ) else: scores_matrix = results.pivot( index="param_gamma", columns="param_C", values="mean_test_score" ) im = ax.imshow(scores_matrix) ax.set_xticks(np.arange(len(Cs))) ax.set_xticklabels(["{:.0E}".format(x) for x in Cs]) ax.set_xlabel("C", fontsize=15) ax.set_yticks(np.arange(len(gammas))) ax.set_yticklabels(["{:.0E}".format(x) for x in gammas]) ax.set_ylabel("gamma", fontsize=15) # Rotate the tick labels and set their alignment. plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") if is_sh: iterations = results.pivot_table( index="param_gamma", columns="param_C", values="iter", aggfunc="max" ).values for i in range(len(gammas)): for j in range(len(Cs)): ax.text( j, i, iterations[i, j], ha="center", va="center", color="w", fontsize=20, ) if make_cbar: fig.subplots_adjust(right=0.8) cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7]) fig.colorbar(im, cax=cbar_ax) cbar_ax.set_ylabel("mean_test_score", rotation=-90, va="bottom", fontsize=15) fig, axes = plt.subplots(ncols=2, sharey=True) ax1, ax2 = axes make_heatmap(ax1, gsh, is_sh=True) make_heatmap(ax2, gs, make_cbar=True) ax1.set_title("Successive Halving\ntime = {:.3f}s".format(gsh_time), fontsize=15) ax2.set_title("GridSearch\ntime = {:.3f}s".format(gs_time), fontsize=15) plt.show() .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_successive_halving_heatmap_001.png :alt: Successive Halving time = 1.576s, GridSearch time = 5.791s :srcset: /auto_examples/model_selection/images/sphx_glr_plot_successive_halving_heatmap_001.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 123-133 The heatmaps show the mean test score of the parameter combinations for an :class:`~sklearn.svm.SVC` instance. The :class:`~sklearn.model_selection.HalvingGridSearchCV` also shows the iteration at which the combinations where last used. The combinations marked as ``0`` were only evaluated at the first iteration, while the ones with ``5`` are the parameter combinations that are considered the best ones. We can see that the :class:`~sklearn.model_selection.HalvingGridSearchCV` class is able to find parameter combinations that are just as accurate as :class:`~sklearn.model_selection.GridSearchCV`, in much less time. .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 7.533 seconds) .. _sphx_glr_download_auto_examples_model_selection_plot_successive_halving_heatmap.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: binder-badge .. image:: images/binder_badge_logo.svg :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/model_selection/plot_successive_halving_heatmap.ipynb :alt: Launch binder :width: 150 px .. container:: lite-badge .. image:: images/jupyterlite_badge_logo.svg :target: ../../lite/lab/index.html?path=auto_examples/model_selection/plot_successive_halving_heatmap.ipynb :alt: Launch JupyterLite :width: 150 px .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: plot_successive_halving_heatmap.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot_successive_halving_heatmap.py ` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: plot_successive_halving_heatmap.zip ` .. include:: plot_successive_halving_heatmap.recommendations .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_