{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import numpy as np\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Parameter selection, Validation, and Testing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Most models have parameters that influence how complex a model they can learn. Remember using `KNeighborsRegressor`.\n", "If we change the number of neighbors we consider, we get a smoother and smoother prediction:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the above figure, we see fits for three different values of ``n_neighbors``.\n", "For ``n_neighbors=2``, the data is overfit, the model is too flexible and can adjust too much to the noise in the training data. For ``n_neighbors=20``, the model is not flexible enough, and can not model the variation in the data appropriately.\n", "\n", "In the middle, for ``n_neighbors = 5``, we have found a good mid-point. It fits\n", "the data fairly well, and does not suffer from the overfit or underfit\n", "problems seen in the figures on either side. What we would like is a\n", "way to quantitatively identify overfit and underfit, and optimize the\n", "hyperparameters (in this case, the polynomial degree d) in order to\n", "determine the best algorithm.\n", "\n", "We trade off remembering too much about the particularities and noise of the training data vs. not modeling enough of the variability. This is a trade-off that needs to be made in basically every machine learning application and is a central concept, called bias-variance-tradeoff or \"overfitting vs underfitting\"." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hyperparameters, Over-fitting, and Under-fitting\n", "\n", "Unfortunately, there is no general rule how to find the sweet spot, and so machine learning practitioners have to find the best trade-off of model-complexity and generalization by trying several hyperparameter settings. Hyperparameters are the internal knobs or tuning parameters of a machine learning algorithm (in contrast to model parameters that the algorithm learns from the training data -- for example, the weight coefficients of a linear regression model); the number of *k* in K-nearest neighbors is such a hyperparameter.\n", "\n", "Most commonly this \"hyperparameter tuning\" is done using a brute force search, for example over multiple values of ``n_neighbors``:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import cross_val_score, KFold\n", "from sklearn.neighbors import KNeighborsRegressor\n", "# generate toy dataset:\n", "x = np.linspace(-3, 3, 100)\n", "rng = np.random.RandomState(42)\n", "y = np.sin(4 * x) + x + rng.normal(size=len(x))\n", "X = x[:, np.newaxis]\n", "\n", "cv = KFold(shuffle=True)\n", "\n", "# for each parameter setting do cross-validation:\n", "for n_neighbors in [1, 3, 5, 10, 20]:\n", " scores = cross_val_score(KNeighborsRegressor(n_neighbors=n_neighbors), X, y, cv=cv)\n", " print(\"n_neighbors: %d, average score: %f\" % (n_neighbors, np.mean(scores)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There is a function in scikit-learn, called ``validation_plot`` to reproduce the cartoon figure above. 
It plots one parameter, such as the number of neighbors, against training and validation scores (using cross-validation):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import validation_curve\n", "\n", "n_neighbors = [1, 3, 5, 10, 20, 50]\n", "train_scores, test_scores = validation_curve(KNeighborsRegressor(), X, y, param_name=\"n_neighbors\",\n", "                                             param_range=n_neighbors, cv=cv)\n", "plt.plot(n_neighbors, train_scores.mean(axis=1), label=\"train score\")\n", "plt.plot(n_neighbors, test_scores.mean(axis=1), label=\"validation score\")\n", "plt.ylabel('$R^2$ score')\n", "plt.xlabel('Number of neighbors')\n", "# invert the x-axis so that model complexity increases to the right:\n", "plt.xlim([50, 0])\n", "plt.legend(loc=\"best\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "