{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Efficiently searching for optimal tuning parameters ([video #8](https://www.youtube.com/watch?v=Gol_qOgRqfA&list=PL5-da3qGB5ICeMbQuqbbCOQWcS6OYBr5A&index=8))\n", "\n", "Created by [Data School](https://www.dataschool.io). Watch all 10 videos on [YouTube](https://www.youtube.com/playlist?list=PL5-da3qGB5ICeMbQuqbbCOQWcS6OYBr5A). Download the notebooks from [GitHub](https://github.com/justmarkham/scikit-learn-videos).\n", "\n", "**Note:** This notebook uses Python 3.9.1 and scikit-learn 0.23.2. The original notebook (shown in the video) used Python 2.7 and scikit-learn 0.16." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Agenda\n", "\n", "- How can K-fold cross-validation be used to search for an **optimal tuning parameter**?\n", "- How can this process be made **more efficient**?\n", "- How do you search for **multiple tuning parameters** at once?\n", "- What do you do with those tuning parameters before making **real predictions**?\n", "- How can the **computational expense** of this process be reduced?" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Review of K-fold cross-validation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Steps for cross-validation:\n", "\n", "- Dataset is split into K \"folds\" of **equal size**\n", "- Each fold acts as the **testing set** 1 time, and acts as the **training set** K-1 times\n", "- **Average testing performance** is used as the estimate of out-of-sample performance\n", "\n", "Benefits of cross-validation:\n", "\n", "- More **reliable** estimate of out-of-sample performance than train/test split\n", "- Can be used for selecting **tuning parameters**, choosing between **models**, and selecting **features**\n", "\n", "Drawbacks of cross-validation:\n", "\n", "- Can be computationally **expensive**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Review of parameter tuning using `cross_val_score`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Goal:** Select the best tuning parameters (aka \"hyperparameters\") for KNN on the iris dataset" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# added empty cell so that the cell numbering matches the video" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import load_iris\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.model_selection import cross_val_score\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# read in the iris data\n", "iris = load_iris()\n", "\n", "# create X (features) and y (response)\n", "X = iris.data\n", "y = iris.target" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1. 0.93333333 1. 1. 0.86666667 0.93333333\n", " 0.93333333 1. 1. 1. 
]\n" ] } ], "source": [ "# 10-fold cross-validation with K=5 for KNN (the n_neighbors parameter)\n", "knn = KNeighborsClassifier(n_neighbors=5)\n", "scores = cross_val_score(knn, X, y, cv=10, scoring='accuracy')\n", "print(scores)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9666666666666668\n" ] } ], "source": [ "# use average accuracy as an estimate of out-of-sample accuracy\n", "print(scores.mean())" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.96, 0.9533333333333334, 0.9666666666666666, 0.9666666666666666, 0.9666666666666668, 0.9666666666666668, 0.9666666666666668, 0.9666666666666668, 0.9733333333333334, 0.9666666666666668, 0.9666666666666668, 0.9733333333333334, 0.9800000000000001, 0.9733333333333334, 0.9733333333333334, 0.9733333333333334, 0.9733333333333334, 0.9800000000000001, 0.9733333333333334, 0.9800000000000001, 0.9666666666666666, 0.9666666666666666, 0.9733333333333334, 0.96, 0.9666666666666666, 0.96, 0.9666666666666666, 0.9533333333333334, 0.9533333333333334, 0.9533333333333334]\n" ] } ], "source": [ "# search for an optimal value of K for KNN\n", "k_range = list(range(1, 31))\n", "k_scores = []\n", "for k in k_range:\n", " knn = KNeighborsClassifier(n_neighbors=k)\n", " scores = cross_val_score(knn, X, y, cv=10, scoring='accuracy')\n", " k_scores.append(scores.mean())\n", "print(k_scores)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'Cross-Validated Accuracy')" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "<Figure: Cross-Validated Accuracy (y-axis) versus Value of K for KNN (x-axis)>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# plot the value of K for KNN (x-axis) versus the cross-validated accuracy (y-axis)\n", "plt.plot(k_range, k_scores)\n", "plt.xlabel('Value of K for KNN')\n", "plt.ylabel('Cross-Validated Accuracy')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## More efficient parameter tuning using `GridSearchCV`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`GridSearchCV` allows you to define a **grid of parameters** that will be **searched** exhaustively: every combination in the grid is evaluated using K-fold cross-validation" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import GridSearchCV" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]\n" ] } ], "source": [ "# define the parameter values that should be searched\n", "k_range = list(range(1, 31))\n", "print(k_range)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]}\n" ] } ], "source": [ "# create a parameter grid: map the parameter names to the values that should be searched\n", "param_grid = dict(n_neighbors=k_range)\n", "print(param_grid)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# instantiate the grid (knn still holds the last model from the loop, n_neighbors=30; the grid overrides n_neighbors for each candidate)\n", "grid = GridSearchCV(knn, param_grid, cv=10, scoring='accuracy')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- You can set **`n_jobs=-1`** to run the cross-validation fits in parallel on all available CPU cores (if supported by your computer and OS)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": {
"text/plain": [ "GridSearchCV(cv=10, estimator=KNeighborsClassifier(n_neighbors=30),\n", " param_grid={'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,\n", " 23, 24, 25, 26, 27, 28, 29, 30]},\n", " scoring='accuracy')" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# fit the grid with data\n", "grid.fit(X, y)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
mean_test_scorestd_test_scoreparams
00.9600000.053333{'n_neighbors': 1}
10.9533330.052068{'n_neighbors': 2}
20.9666670.044721{'n_neighbors': 3}
30.9666670.044721{'n_neighbors': 4}
40.9666670.044721{'n_neighbors': 5}
50.9666670.044721{'n_neighbors': 6}
60.9666670.044721{'n_neighbors': 7}
70.9666670.044721{'n_neighbors': 8}
80.9733330.032660{'n_neighbors': 9}
90.9666670.044721{'n_neighbors': 10}
100.9666670.044721{'n_neighbors': 11}
110.9733330.032660{'n_neighbors': 12}
120.9800000.030551{'n_neighbors': 13}
130.9733330.044222{'n_neighbors': 14}
140.9733330.032660{'n_neighbors': 15}
150.9733330.032660{'n_neighbors': 16}
160.9733330.032660{'n_neighbors': 17}
170.9800000.030551{'n_neighbors': 18}
180.9733330.032660{'n_neighbors': 19}
190.9800000.030551{'n_neighbors': 20}
200.9666670.033333{'n_neighbors': 21}
210.9666670.033333{'n_neighbors': 22}
220.9733330.032660{'n_neighbors': 23}
230.9600000.044222{'n_neighbors': 24}
240.9666670.033333{'n_neighbors': 25}
250.9600000.044222{'n_neighbors': 26}
260.9666670.044721{'n_neighbors': 27}
270.9533330.042687{'n_neighbors': 28}
280.9533330.042687{'n_neighbors': 29}
290.9533330.042687{'n_neighbors': 30}
\n", "
" ], "text/plain": [ " mean_test_score std_test_score params\n", "0 0.960000 0.053333 {'n_neighbors': 1}\n", "1 0.953333 0.052068 {'n_neighbors': 2}\n", "2 0.966667 0.044721 {'n_neighbors': 3}\n", "3 0.966667 0.044721 {'n_neighbors': 4}\n", "4 0.966667 0.044721 {'n_neighbors': 5}\n", "5 0.966667 0.044721 {'n_neighbors': 6}\n", "6 0.966667 0.044721 {'n_neighbors': 7}\n", "7 0.966667 0.044721 {'n_neighbors': 8}\n", "8 0.973333 0.032660 {'n_neighbors': 9}\n", "9 0.966667 0.044721 {'n_neighbors': 10}\n", "10 0.966667 0.044721 {'n_neighbors': 11}\n", "11 0.973333 0.032660 {'n_neighbors': 12}\n", "12 0.980000 0.030551 {'n_neighbors': 13}\n", "13 0.973333 0.044222 {'n_neighbors': 14}\n", "14 0.973333 0.032660 {'n_neighbors': 15}\n", "15 0.973333 0.032660 {'n_neighbors': 16}\n", "16 0.973333 0.032660 {'n_neighbors': 17}\n", "17 0.980000 0.030551 {'n_neighbors': 18}\n", "18 0.973333 0.032660 {'n_neighbors': 19}\n", "19 0.980000 0.030551 {'n_neighbors': 20}\n", "20 0.966667 0.033333 {'n_neighbors': 21}\n", "21 0.966667 0.033333 {'n_neighbors': 22}\n", "22 0.973333 0.032660 {'n_neighbors': 23}\n", "23 0.960000 0.044222 {'n_neighbors': 24}\n", "24 0.966667 0.033333 {'n_neighbors': 25}\n", "25 0.960000 0.044222 {'n_neighbors': 26}\n", "26 0.966667 0.044721 {'n_neighbors': 27}\n", "27 0.953333 0.042687 {'n_neighbors': 28}\n", "28 0.953333 0.042687 {'n_neighbors': 29}\n", "29 0.953333 0.042687 {'n_neighbors': 30}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# view the results as a pandas DataFrame\n", "import pandas as pd\n", "pd.DataFrame(grid.cv_results_)[['mean_test_score', 'std_test_score', 'params']]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'n_neighbors': 1}\n", "0.96\n" ] } ], "source": [ "# examine the first result\n", "print(grid.cv_results_['params'][0])\n", "print(grid.cv_results_['mean_test_score'][0])" ] }, { "cell_type": 
"code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.96 0.95333333 0.96666667 0.96666667 0.96666667 0.96666667\n", " 0.96666667 0.96666667 0.97333333 0.96666667 0.96666667 0.97333333\n", " 0.98 0.97333333 0.97333333 0.97333333 0.97333333 0.98\n", " 0.97333333 0.98 0.96666667 0.96666667 0.97333333 0.96\n", " 0.96666667 0.96 0.96666667 0.95333333 0.95333333 0.95333333]\n" ] } ], "source": [ "# print the array of mean scores only\n", "grid_mean_scores = grid.cv_results_['mean_test_score']\n", "print(grid_mean_scores)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'Cross-Validated Accuracy')" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# plot the results\n", "plt.plot(k_range, grid_mean_scores)\n", "plt.xlabel('Value of K for KNN')\n", "plt.ylabel('Cross-Validated Accuracy')" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9800000000000001\n", "{'n_neighbors': 13}\n", "KNeighborsClassifier(n_neighbors=13)\n" ] } ], "source": [ "# examine the best model\n", "print(grid.best_score_)\n", "print(grid.best_params_)\n", "print(grid.best_estimator_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Searching multiple parameters simultaneously" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **Example:** tuning `max_depth` and `min_samples_leaf` for a `DecisionTreeClassifier`\n", "- Could tune parameters **independently**: change `max_depth` while leaving `min_samples_leaf` at its default value, and vice versa\n", "- But, best performance might be achieved when **neither parameter** is at its default value" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# define the parameter values that should be searched\n", "k_range = list(range(1, 31))\n", "weight_options = ['uniform', 'distance']" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30], 'weights': ['uniform', 'distance']}\n" ] } ], "source": [ "# create a parameter grid: map the parameter names to the values that should be searched\n", "param_grid = dict(n_neighbors=k_range, weights=weight_options)\n", "print(param_grid)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GridSearchCV(cv=10, 
estimator=KNeighborsClassifier(n_neighbors=30),\n", " param_grid={'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,\n", " 23, 24, 25, 26, 27, 28, 29, 30],\n", " 'weights': ['uniform', 'distance']},\n", " scoring='accuracy')" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# instantiate and fit the grid\n", "grid = GridSearchCV(knn, param_grid, cv=10, scoring='accuracy')\n", "grid.fit(X, y)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "scrolled": false }, "outputs": [ { "data": {
"text/plain": [ " mean_test_score std_test_score \\\n", "0 0.960000 0.053333 \n", "1 0.960000 0.053333 \n", "2 0.953333 0.052068 \n", "3 0.960000 0.053333 \n", "4 0.966667 0.044721 \n", "5 0.966667 0.044721 \n", "6 0.966667 0.044721 \n", "7 0.966667 0.044721 \n", "8 0.966667 0.044721 \n", "9 0.966667 0.044721 \n", "10 0.966667 0.044721 \n", "11 0.966667 0.044721 \n", "12 0.966667 0.044721 \n", "13 0.966667 0.044721 \n", "14 0.966667 0.044721 \n", "15 0.966667 0.044721 \n", "16 0.973333 0.032660 \n", "17 0.973333 0.032660 \n", "18 0.966667 0.044721 \n", "19 0.973333 0.032660 \n", "20 0.966667 0.044721 \n", "21 0.973333 0.032660 \n", "22 0.973333 0.032660 \n", "23 0.973333 0.044222 \n", "24 0.980000 0.030551 \n", "25 0.973333 0.032660 \n", "26 0.973333 0.044222 \n", "27 0.973333 0.032660 \n", "28 0.973333 0.032660 \n", "29 0.980000 0.030551 \n", "30 0.973333 0.032660 \n", "31 0.973333 0.032660 \n", "32 0.973333 0.032660 \n", "33 0.980000 0.030551 \n", "34 0.980000 0.030551 \n", "35 0.973333 0.032660 \n", "36 0.973333 0.032660 \n", "37 0.980000 0.030551 \n", "38 0.980000 0.030551 \n", "39 0.966667 0.044721 \n", "40 0.966667 0.033333 \n", "41 0.966667 0.044721 \n", "42 0.966667 0.033333 \n", "43 0.966667 0.044721 \n", "44 0.973333 0.032660 \n", "45 0.973333 0.032660 \n", "46 0.960000 0.044222 \n", "47 0.973333 0.032660 \n", "48 0.966667 0.033333 \n", "49 0.973333 0.032660 \n", "50 0.960000 0.044222 \n", "51 0.966667 0.044721 \n", "52 0.966667 0.044721 \n", "53 0.980000 0.030551 \n", "54 0.953333 0.042687 \n", "55 0.973333 0.032660 \n", "56 0.953333 0.042687 \n", "57 0.973333 0.032660 \n", "58 0.953333 0.042687 \n", "59 0.966667 0.033333 \n", "\n", " params \n", "0 {'n_neighbors': 1, 'weights': 'uniform'} \n", "1 {'n_neighbors': 1, 'weights': 'distance'} \n", "2 {'n_neighbors': 2, 'weights': 'uniform'} \n", "3 {'n_neighbors': 2, 'weights': 'distance'} \n", "4 {'n_neighbors': 3, 'weights': 'uniform'} \n", "5 {'n_neighbors': 3, 'weights': 'distance'} \n", "6 
{'n_neighbors': 4, 'weights': 'uniform'} \n", "7 {'n_neighbors': 4, 'weights': 'distance'} \n", "8 {'n_neighbors': 5, 'weights': 'uniform'} \n", "9 {'n_neighbors': 5, 'weights': 'distance'} \n", "10 {'n_neighbors': 6, 'weights': 'uniform'} \n", "11 {'n_neighbors': 6, 'weights': 'distance'} \n", "12 {'n_neighbors': 7, 'weights': 'uniform'} \n", "13 {'n_neighbors': 7, 'weights': 'distance'} \n", "14 {'n_neighbors': 8, 'weights': 'uniform'} \n", "15 {'n_neighbors': 8, 'weights': 'distance'} \n", "16 {'n_neighbors': 9, 'weights': 'uniform'} \n", "17 {'n_neighbors': 9, 'weights': 'distance'} \n", "18 {'n_neighbors': 10, 'weights': 'uniform'} \n", "19 {'n_neighbors': 10, 'weights': 'distance'} \n", "20 {'n_neighbors': 11, 'weights': 'uniform'} \n", "21 {'n_neighbors': 11, 'weights': 'distance'} \n", "22 {'n_neighbors': 12, 'weights': 'uniform'} \n", "23 {'n_neighbors': 12, 'weights': 'distance'} \n", "24 {'n_neighbors': 13, 'weights': 'uniform'} \n", "25 {'n_neighbors': 13, 'weights': 'distance'} \n", "26 {'n_neighbors': 14, 'weights': 'uniform'} \n", "27 {'n_neighbors': 14, 'weights': 'distance'} \n", "28 {'n_neighbors': 15, 'weights': 'uniform'} \n", "29 {'n_neighbors': 15, 'weights': 'distance'} \n", "30 {'n_neighbors': 16, 'weights': 'uniform'} \n", "31 {'n_neighbors': 16, 'weights': 'distance'} \n", "32 {'n_neighbors': 17, 'weights': 'uniform'} \n", "33 {'n_neighbors': 17, 'weights': 'distance'} \n", "34 {'n_neighbors': 18, 'weights': 'uniform'} \n", "35 {'n_neighbors': 18, 'weights': 'distance'} \n", "36 {'n_neighbors': 19, 'weights': 'uniform'} \n", "37 {'n_neighbors': 19, 'weights': 'distance'} \n", "38 {'n_neighbors': 20, 'weights': 'uniform'} \n", "39 {'n_neighbors': 20, 'weights': 'distance'} \n", "40 {'n_neighbors': 21, 'weights': 'uniform'} \n", "41 {'n_neighbors': 21, 'weights': 'distance'} \n", "42 {'n_neighbors': 22, 'weights': 'uniform'} \n", "43 {'n_neighbors': 22, 'weights': 'distance'} \n", "44 {'n_neighbors': 23, 'weights': 'uniform'} \n", "45 
{'n_neighbors': 23, 'weights': 'distance'} \n", "46 {'n_neighbors': 24, 'weights': 'uniform'} \n", "47 {'n_neighbors': 24, 'weights': 'distance'} \n", "48 {'n_neighbors': 25, 'weights': 'uniform'} \n", "49 {'n_neighbors': 25, 'weights': 'distance'} \n", "50 {'n_neighbors': 26, 'weights': 'uniform'} \n", "51 {'n_neighbors': 26, 'weights': 'distance'} \n", "52 {'n_neighbors': 27, 'weights': 'uniform'} \n", "53 {'n_neighbors': 27, 'weights': 'distance'} \n", "54 {'n_neighbors': 28, 'weights': 'uniform'} \n", "55 {'n_neighbors': 28, 'weights': 'distance'} \n", "56 {'n_neighbors': 29, 'weights': 'uniform'} \n", "57 {'n_neighbors': 29, 'weights': 'distance'} \n", "58 {'n_neighbors': 30, 'weights': 'uniform'} \n", "59 {'n_neighbors': 30, 'weights': 'distance'} " ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# view the results\n", "pd.DataFrame(grid.cv_results_)[['mean_test_score', 'std_test_score', 'params']]" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9800000000000001\n", "{'n_neighbors': 13, 'weights': 'uniform'}\n" ] } ], "source": [ "# examine the best model\n", "print(grid.best_score_)\n", "print(grid.best_params_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using the best parameters to make predictions" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# train your model using all data and the best known parameters\n", "knn = KNeighborsClassifier(n_neighbors=13, weights='uniform')\n", "knn.fit(X, y)\n", "\n", "# make a prediction on out-of-sample data\n", "knn.predict([[3, 5, 4, 2]])" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1])" ] }, "execution_count": 24, "metadata": 
{}, "output_type": "execute_result" } ], "source": [ "# shortcut: GridSearchCV automatically refits the best model using all of the data\n", "grid.predict([[3, 5, 4, 2]])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reducing computational expense using `RandomizedSearchCV`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Searching many different parameters at once may be computationally infeasible, because the number of combinations multiplies with every parameter added (the grid above, with 30 values of `n_neighbors` and 2 values of `weights`, already required 60 candidates × 10 folds = 600 cross-validation fits)\n", "- `RandomizedSearchCV` evaluates only a random sample of the parameter combinations, and you control the computational \"budget\" with `n_iter`, the number of combinations sampled" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import RandomizedSearchCV" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "# specify \"parameter distributions\" rather than a \"parameter grid\"\n", "param_dist = dict(n_neighbors=k_range, weights=weight_options)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **Important:** Specify a continuous distribution (rather than a list of values) for any continuous parameters" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": {
"text/plain": [ " mean_test_score std_test_score params\n", "0 0.973333 0.032660 {'weights': 'distance', 'n_neighbors': 16}\n", "1 0.966667 0.033333 {'weights': 'uniform', 'n_neighbors': 22}\n", "2 0.980000 0.030551 {'weights': 'uniform', 'n_neighbors': 18}\n", "3 0.966667 0.044721 {'weights': 'uniform', 'n_neighbors': 27}\n", "4 0.953333 0.042687 {'weights': 'uniform', 'n_neighbors': 29}\n", "5 0.973333 0.032660 {'weights': 'distance', 'n_neighbors': 10}\n", "6 0.966667 0.044721 {'weights': 'distance', 'n_neighbors': 22}\n", "7 0.973333 0.044222 {'weights': 'uniform', 'n_neighbors': 14}\n", "8 0.973333 0.044222 {'weights': 'distance', 'n_neighbors': 12}\n", "9 0.973333 0.032660 {'weights': 'uniform', 'n_neighbors': 15}" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# n_iter controls the number of parameter combinations that are sampled and evaluated\n", "rand = RandomizedSearchCV(knn, param_dist, cv=10, scoring='accuracy', n_iter=10, random_state=5)\n", "rand.fit(X, y)\n", "pd.DataFrame(rand.cv_results_)[['mean_test_score', 'std_test_score', 'params']]" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9800000000000001\n", "{'weights': 'uniform', 'n_neighbors': 18}\n" ] } ], "source": [ "# examine the best model\n", "print(rand.best_score_)\n", "print(rand.best_params_)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.98, 0.98, 0.98, 0.98, 0.973, 0.98, 0.973, 0.98, 0.98, 0.98, 0.973, 0.98, 0.98, 0.973, 0.973, 0.98, 0.98, 0.973, 0.973, 0.98]\n" ] } ], "source": [ "# run RandomizedSearchCV 20 times (with n_iter=10) and record the best score\n", "best_scores = []\n", "for _ in range(20):\n", " rand = RandomizedSearchCV(knn, param_dist, cv=10, scoring='accuracy', n_iter=10)\n", " rand.fit(X, y)\n", " best_scores.append(round(rand.best_score_, 3))\n", "print(best_scores)" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "## Resources\n", "\n", "- scikit-learn documentation: [Grid search](https://scikit-learn.org/stable/modules/grid_search.html), [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html), [RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)\n", "- Timed example: [Comparing randomized search and grid search](https://scikit-learn.org/stable/auto_examples/model_selection/plot_randomized_search.html)\n", "- scikit-learn workshop by Andreas Mueller: [Video segment on randomized search](https://youtu.be/0wUF_Ov8b0A?t=17m38s) (3 minutes), [related notebook](https://github.com/amueller/pydata-nyc-advanced-sklearn/blob/master/Chapter%203%20-%20Randomized%20Hyper%20Parameter%20Search.ipynb)\n", "- Paper by Yoshua Bengio: [Random Search for Hyper-Parameter Optimization](http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Comments or Questions?\n", "\n", "- Email: \n", "- Website: https://www.dataschool.io\n", "- Twitter: [@justmarkham](https://twitter.com/justmarkham)\n", "\n", "© 2021 [Data School](https://www.dataschool.io). All rights reserved." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.4" } }, "nbformat": 4, "nbformat_minor": 1 }