{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Lesson 9 - Overfitting\n", "\n", "> What is overfitting and how can it be avoided?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/lvwerra/dslectures/master?urlpath=lab/tree/notebooks%2Flesson09_overfitting.ipynb)[![slides](https://img.shields.io/static/v1?label=slides&message=2021-lesson09.pdf&color=blue&logo=Google-drive)](https://drive.google.com/open?id=1KnV9j6Gnh0aJdhXnXJnMYH8Ppyn-H29U)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learning Objectives\n", "Overfitting is a phenomena that can always occur when a model is fitted to data. Therefore, it is important to understand what it entails and how it can be avoided. In this notebook we will address these three questions related to overfitting:\n", "1. What is overfitting?\n", "2. How can we measure overfitting?\n", "3. How can overfitting be avoided?\n", "\n", "## References\n", "* Chapter 5: Overfitting and its avoidance of _Data Science for Business_ by F. Provost and P. Fawcett\n", "\n", "\n", "## Homework\n", "* Work through part 2 of the notebook concerning the housing dataset.\n", "* Solve exercises in the notebook. In particular, tune a random forest for the churn dataset in part 3." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What is overfitting?\n", "\n", "Already John von Neumann, one of the founding fathers of computing, knew that fitting complex models to data is a tricky business:\n", ">With four parameters I can fit an elephant, and with five I can make him wiggle his trunk.\n", ">\n", "> \\- John von Neumann\n", "\n", "
\n", "\n", "

Figure reference:Irrelevant image.

\n", "
\n", "\n", "\n", "When we fit a model to data we always have to be careful not to overfit. If we overfit the model this means that the model learned specific aspects of the training data and does not *generalise* to new, unseen data. Instead of learning useful relations between the input feature and the target the model has memorised the training samples. If this happens the model we perform very poorly on new data and therefore we want to make sure this does not happen.\n", "\n", "Fortunately, there are tools that can help detect and avoid overfitting. One tool we already used: splitting the data into two sets. Measuring the performance difference between the training and validation set already helps identifying when we are overfitting. In this lecture we will see an even more systematic way of splitting the data namely *cross-validation*." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# reload modules before executing user code\n", "%load_ext autoreload\n", "# reload all modules every time before executing Python code\n", "%autoreload 2\n", "# render plots in notebook\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import seaborn as sns\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from pathlib import Path\n", "from tqdm import tqdm\n", "import time\n", "\n", "from sklearn.model_selection import train_test_split, cross_validate\n", "from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier\n", "\n", "from dslectures.core import rmse, make_polynomial_data, PolynomialRegressor, get_dataset\n", "from dslectures.structured import proc_df\n", "\n", "np.warnings.filterwarnings('ignore')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 0: tqdm\n", "Before starting with overfitting we 
introduce the tqdm library. In this notebook we will make extensive use of for-loops, and tqdm is a very handy addition to them. With just one expression you can add a progress bar to your for-loop. Let's say you have a for-loop that does some computation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for i in range(1000):\n", "    # do something\n", "    a = i**2\n", "    time.sleep(0.01)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1000/1000 [00:10<00:00, 92.63it/s]\n" ] } ], "source": [ "for i in tqdm(range(1000)):\n", "    # do something\n", "    a = i**2\n", "    time.sleep(0.01)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 5/5 [00:05<00:00, 1.00s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "sum: 15\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "sum_i = 0\n", "for i in tqdm([1,2,3,4,5]):\n", "    sum_i += i\n", "    time.sleep(1)\n", "print('sum:', sum_i)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Installation\n", "If you run tqdm in Jupyter notebooks locally, you might need to run the following in the terminal:\n", "\n", "```\n", "jupyter nbextension enable --py --sys-prefix widgetsnbextension\n", "```\n", "\n", "For JupyterLab, additionally run this command:\n", "\n", "```\n", "jupyter labextension install @jupyter-widgets/jupyterlab-manager\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 1: Overfitting Polynomials\n", "\n", "### Polynomials\n", "To study the nature of overfitting, we start by looking at the toy example of polynomials. 
Later we will see that our findings are not specific to polynomials and can be extended to other *supervised* machine learning methods such as linear regressors, tree classifiers or random forests.\n", "\n", "A polynomial of degree $n$ has the form:\n", "$$f(x) = w_0 + w_1\\cdot x + w_2\\cdot x^2 +\\ldots+w_n\\cdot x^n$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generate data\n", "\n", "In this block we generate noisy data from a polynomial of degree 3 with the helper function `make_polynomial_data`. The polynomial has the form:\n", "$$f(x) = 10 \\cdot x^3 -5 \\cdot x $$\n", "If you are interested in creating or fitting polynomial data, check out the functions `numpy.polyval` and `numpy.polyfit`. In this lesson we will use wrapper functions around them. The function `make_polynomial_data(w, n_samples=100)` evaluates the polynomial defined by `w` on `n_samples` random points and adds some noise to it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# coefficients from highest to lowest degree: 10*x^3 - 5*x\n", "weights = np.array([10, 0, -5, 0])\n", "X, y = make_polynomial_data(weights, n_samples=100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can plot the polynomial data in a scatter plot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def scatter_plot_polynomial(X, y, label='', title='Polynomial data'):\n", "    # scatter plot of the samples with title, axis labels, grid and legend\n", "    plt.title(title)\n", "    plt.scatter(X, y, label=label)\n", "    plt.xlabel('x')\n", "    plt.ylabel('y')\n", "    plt.grid(True)\n", "    plt.legend(loc='best')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "scatter_plot_polynomial(X, y, label='data')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train and validation sets\n", "We have already discussed in previous lessons that a validation set helps us investigate overfitting. The model is trained with the training set and its performance is measured using the validation set. We will do the same here." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2,)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "scatter_plot_polynomial(X_train, y_train, label='training set')\n", "scatter_plot_polynomial(X_valid, y_valid, label='validation set')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now lets fit a polynomial to the generated data. We can use `PolyFit` class to fit a polynomial to the data. We can pass the the degree as an argument and then use the same functions `fit`, `predict` and `evaluate` functions known from `scikit-learn`. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pr = PolynomialRegressor(degree=3)\n", "pr.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we want to see how that polynomial looks like on a range from [-1, 1]:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_lin = np.linspace(-1, 1, 1000)\n", "y_fit = pr.predict(X_lin)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYAAAAEWCAYAAABv+EDhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO3de3yO9f/A8ddns7EZm0M5jHLIKXOYQymSUayopEzn6EAokq8i/eggFEU6oQNCMcd8ky8q+4bwdRqSMxVzGtlsbO30+f1x31v3tvve7uOue7vfz8djD7uv+zq8d923631dn6PSWiOEEML3+BkdgBBCCGNIAhBCCB8lCUAIIXyUJAAhhPBRkgCEEMJHSQIQQggfJQlAlHpKqdeVUguMjsOSUupRpdQ6O9d1KH6llFZK3eB8dEKYSAIQXkMp9btSKk0plaqUOqeUmquUCjE6LmdorRdqrbsbGYNSqp45WZQzMg7hvSQBCG9zj9Y6BGgDtANeMzgeIcosSQDCK2mtE4A1QASAUqq2UmqVUuovpdRRpdSz1rZTSq1WSr1QYNlepdT95t+1Uuo5pdQRpVSSUupjpZQyv+enlHpNKfWHUuq8UuorpVSo+b3cu+kBSqmTSqlL5v20N+8/SSn1kcUx+yulNlm8/sC83WWl1E6l1G32ngul1Cil1Bml1Gml1FMF3uuplNpt3u9JpdTrFm//bP43yfxUdYtSqqFS6iel1EWl1AWl1EKlVJi9sYiyRRKA8EpKqbrA3cBu86JFwCmgNvAgMFEp1dXKpvOAxyz20woIB1ZbrNMLaA+0BGKAHubl/c0/UUADIAT4iPxuBhoB/YDpwFjgDqA5EKOUut3Gn7QdaA1UBb4GliilKthYN49SKhr4F3Cn+bh3FFjlCvAEEAb0BAYrpXqb3+ts/jdMax2itd4CKGASpvPYDKgLvF5cHKJskgQgvM1KpVQSsAn4L6YLfV2gI/CK1jpdax0PfI7pwlfQKqCxUqqR+fXjwGKtdYbFOpO11kla6z+BDZguzACPAu9rrY9rrVOBMcBDBcrQ3zLHsA7TxfcbrfV58xPLRiDS2h+ltV6gtb6otc7SWr8HlAea2HE+YoA5WutftdZXKHCx1lrHaa33aa1ztNZ7gW8AW0kIrfVRrfV6rfXfWutE4P2i1hdlmyQA4W16a63DtNbXa62HaK3TMN2t/qW1TrFY7w9Md/b5aK3TgcXAY0opP+BhYH6B1c5a/H4V050+5uP8UeAY5YAaFsvOWfyeZuW11UprpdS/lFIHlFLJ5gQXClS3tm4BtYGTBWKy3O/NSqkNSqlEpVQy8FxR+1VK1VBKLVJKJSilLgML7IxDlEGSAERpcBqoqpSqZLHsOiDBxvrzMN3NdwOumos+7D3O9QWOkUX+i7zDzOX9L2O6m6+itQ4DkjEVxxTnDKZiGsuYLH2N6amnrtY6FJhpsV9rQ/1ONC9vobWujKm4zJ44RBkkCUB4Pa31SeAXYJJSqoJSqiXwNKa7V2vrbwFygPcofPdflG+AEUqp+ubmpxMxFR9lufQHQCVMiSQRKKeUGgdUtnPbWKC/UupGpVQwMN7Kvv/SWqcrpW4CHrF4LxHTeWhQYP1UIFkpFQ6McvivEWWGJABRWjwM1MN0l74CGK+1/qGI9b8CWmAjSdjwJaaE8TNwAkgHXihyC/usBf4DHMZUhJNO/mIdm7TWazBVNv8EHDX/a2kI8KZSKgUYhylh5G57FXgb2GxupdQBeANTE9tkTBXjy53/s0Rpp2RCGFEWKaWeAAZqrTsZHYsQ3kqeAESZYy4qGQLMNjoWIbyZJABRpiilemAq+z6HqYJUCGGDFAEJIYSPMvQJQCkVppRaqpQ6aG4jfYuR8QghhC8xepTAD4D/aK0fVEoFAsFFrVy9enVdr149pw505coVKlas6NS2niRxOUbicozE5RhvjQtci23nzp0XtNbXFHpDa23ID6aekCcwF0P
Z89O2bVvtrA0bNji9rSdJXI6RuBwjcTnGW+PS2rXYgB3ayjXVsDoApVRrTK00fgNaATuB4do03onlegOBgQA1atRou2jRIqeOl5qaSkiI9w0tL3E5RuJyjMTlGG+NC1yLLSoqaqfWul2hN6xlhZL4wTTWexZws/n1B5gG2pInAC8gcTlG4nKMxOU4TzwBGFkJfAo4pbXeZn69FFMPRSGEECXAsEpgrfVZ8wQWTbTWhzAN3PWbo/vJzMzk1KlTpKenF7leaGgoBw4ccDJazynLcVWoUIE6deoQEBDgpqiEEO5kdCugF4CF5hZAx4EBju7g1KlTVKpUiXr16mGe2MmqlJQUKlWqZPN9o5TVuLTWXLx4kVOnTlG/fn03RiaEcBdDE4A2TexRuGLCAenp6cVe/EXJU0pRrVo1EhMTjQ5FiFJp5e4Epqw9xOmkNGqHBTGqVbbbj1EmhoKQi793ks9FCOes3J3AmOX7SEhKQwMJSWkkXEpj5W5bU2A4p0wkACGEKEumrD1EWmb+O/4crZmy9pBbjyMJwEVJSUl88sknTm179913k5SUVOQ648aN44cfihr23jO+++47fvvN4Tp5IYQbnE5Kc2i5syQBuKioBJCVVfREUt9//z1hYWFFrvPmm29yxx13OB2fsyQBCGGc2mFBDi13ls8lgJW7E+g4+Sfqj15Nx8k/uVymNnr0aI4dO0br1q0ZNWoUcXFx3Hbbbdx7773ceOONAPTu3Zu2bdvSvHlzZs/+Z4j6evXqcfHiRX7//XeaNWvGs88+S/PmzenevTtpaaZM379/f5YuXZq3/vjx42nTpg0tWrTg4MGDACQmJnLnnXfSvHlznnnmGa6//nouXLiQL87s7Gz69+9PREQELVq0YNq0aQAcO3aM6Oho2rZty2233cbBgwf55Zdf+P777xk1ahStW7fm2LFjLp0jIYRjRvVoQlCAf75lfkoxqkcTtx7HpxKAtYqVMcv3uZQEJk+eTMOGDYmPj2fKlCkA7Nq1iw8++IDDhw8D8OWXX7Jz50527NjBjBkzuHjxYqH9HDlyhKFDh7J//37CwsJYtmyZ1eNVr16dXbt2MXjwYKZOnQrAG2+8QdeuXdm/fz8PPvggf/75Z6Ht4uPjSUhI4Ndff2Xfvn0MGGBqcTtw4EA+/PBDdu7cydSpUxkyZAi33nord999N1OmTCE+Pp6GDRs6fX6EEI7rHRnOpD4tCA8LQgHhYUGEVwmid2S4W49jdD+AEmWtYiUtM5spaw+59cTedNNN+dq+z5gxgxUrVgBw8uRJjhw5QrVq1fJtU79+fVq3bg1A27Zt+f33363uu0+fPnnrLF9ums5106ZNefuPjo6mSpUqhbZr0KABx48f54UXXqBnz550796d1NRUfvnlF/r27Zu33t9//+3kXy2EcKfekeH5rktxcXFuP4ZPJYCSqlixHLI1Li6OH374gS1bthAcHEyXLl2s9louX7583u/+/v55RUC21vP39y+2jsFSlSpV2LNnD2vXrmXmzJnExsYyffp0wsLCiI+Pt3s/QoiSdfXqVWbMmOGRJ3GfKgLyRMVKpUqVSElJsfl+cnIyVapUITg4mIMHD7J161anj2VLx44diY2NBWDdunVcunSp0DoXLlwgJyeHBx54gAkTJrBr1y4qV65M/fr1WbJkCWDqvbtnzx4AQkJCivy7hBAlY+3atYwZM8Zq0a6rfCoBWKtYCQrwd6lipVq1anTs2JGIiAhGjRpV6P3o6GiysrJo1qwZo0ePpkOHDk4fy5bx48ezbt06IiIiWLJkCTVr1iw0jENCQgJdunShdevWPPbYY0yaNAmAhQsX8sUXX9CqVSuaN2/Ot99+C8CDDz7IlClTiIyMlEpgIQy0dOlSqlatSqtWrdy/c2tDhHrrj7XhoH/77Te7hkO9fPmy1lrrFbtO6Vsn/ajrvfKdvnXSj3rFrlN2be8puXG5Ij09XWdmZmqttf7ll190q1atXN6nO+LS2v7
Px17eOlyvxOUYics+aWlpulKlSvrpp5/2yHDQPlUHAIUrVsqCP//8k5iYGHJycggMDOSzzz4zOiQhhBusXbuWlJQUYmJiPLJ/n0sAZVGjRo3YvXu30WEIIdwsNjaWqlWrEhUVxebNm92+f5+qAxBCiNIiLS2NVatW0adPH4/NqSEJQAghvNDatWtJTU3N10/H3SQBCCGEF1qyZAnVqlUjKirKY8eQBCCEEF4mt/jn/vvv9+iUqpIADBASEgLA6dOnefzxx62u06VLF3bs2FHkfqZPn87Vq1fzXtszvLS7/f7773z99dclekwhyrrc4h9Ptf7JJQnAQLVr12b+/PlOb18wAdgzvLS7SQIQwv1KovgHfDEB7I2FaRHwepjp372xLu1u9OjRfPzxx3mvX3/9daZOnUpqairdunXLG7o5t4etpd9//52bb74ZMD3yPfTQQzRr1oz7778/31hAgwcPpl27djRv3pzx48cDpgHmTp8+TVRUVN6XpF69ennDQL///vtEREQQERHB9OnT845na9hpS0uWLOHmm2+mVatWdO7cGTANJz1q1Cjat29Py5YtmTVrVt7fv3HjRlq3bp03xLQQwnnp6el5rX/KlfNwS31rvcO89cflnsB7Fms9oYbW4yv/8zOhhmm5k3bt2qU7d+6c97pZs2b6zz//1JmZmTo5OVlrrXViYqJu2LChzsnJ0VprXbFiRa211idOnNDNmjXTWmv93nvv6QEDBmittd6zZ4/29/fX27dv11prffHiRa211llZWfr222/Xe/bs0Vprff311+vExMS8Y+e+3rFjh46IiNCpqak6JSVF33jjjXrXrl36xIkT2t/fX+/evVtrrXXfvn31/PnzC/1NERER+uDBg1prrS9duqS11nrWrFn6rbfe0lqbeh63bdtWHz9+XG/YsEH37NnT5vmRnsDGkrgc4w1xLV26VAN6/fr1+ZZ7oiewbz0B/PgmZBa4481MMy13UmRkJOfPn+f06dPs2bOHKlWqULduXbTWvPrqq7Rs2ZI77riDhIQEzp07Z3M/P//8M4899hgALVu2pGXLlnnvxcbG0qZNGyIjI9m/f3+xM3Vt2rSJ+++/n4oVKxISEkKfPn3YuHEjYN+w0x07dmTw4MF89tlnZGebhs9et24dX331Fa1bt+bmm2/m4sWLHDlyxKFzJYQo3sKFC6lZs6bHi3/A13oCJ59ybLmd+vbty9KlSzl79iz9+vUDTB9iYmIiO3fuJCAggHr16lkdBro4J06cYOrUqWzfvp0qVarQv39/p/aTy55hp2fOnMlPP/1EXFwcbdu2ZefOnWit+fDDD+nRo0e+dT0xRrkQPmdvLPz4JknnTrJ6VQpDHr4Lf3//4rdzkW89AYTWcWy5nfr168eiRYtYunRpXqeN5ORkrr32WgICAtiwYQN//PFHkfvo3LlzXmXqr7/+yt69ewG4fPkyFStWJDQ0lHPnzrFmzZq8bWwNRX3bbbexcuVKrl69ypUrV1ixYgW33Xab3X/PsWPHaN++PW+++SbXXHMNJ0+epEePHnz66adkZmYCcPjwYa5cuVLscNhCiGLsjYV/D4Pkkyz7LYOMbM0jIVtdrp+0h289AXQbZzrRlsVAAUGm5S5o3rw5KSkphIeHU6tWLQAeffRR7rnnHlq0aEG7du1o2rRpkfsYPHgwAwYMoFmzZjRr1oy2bdsC0KpVKyIjI2natCl169alY8eOedsMHDiQ6OhoateuzYYNG/KWt2nThv79+3PTTTcB8MwzzxAZGWlzlrGCRo0axaFDh1BK0a1bN1q1akXLli35/fffadOmDVprrrnmGlauXEnLli3x9/enVatW9O/fnxEjRjhy6oQQFkXTC/dl0qiqH+2uzTQtb+nZZqCGV+w68uOO4aD1nsVav99c6/Ghpn9dqAB2B3cNu+xuMhy0YyQux5TluBwdcj5nfKjW4yvrUyNCtAL9+u3lzY1UQt0WGzIctFnLGM9nVSGET1q
5O4Exy/flzT2ekJTGmOX7AKwOQ79ydwLtdTXC1QUW/ZqJBh5pYb4su1g0bQ/D6wCUUv5Kqd1Kqe+MjkUIIVwxZe2hvIt/rrTMbKasPWRz/XcyY7iqA1m4L5P2tf1oVM2fNMq7XDRtD8MTADAcOODKDkxPOMLbyOcifM3ppMKt6opbviqnEwPP9Gb32RweaRHIqZzqjM54ukRKKgxNAEqpOkBP4HNn91GhQgUuXrwoFxsvo7Xm4sWLVKhQwehQhCgxtcOCnFr+719TQPnx3g2f0SljBjsq3+mxGC0pIy+cSqmlwCSgEvAvrXUvK+sMBAYC1KhRo+2iRYsKvk/FihWLbTOrtUYp5a7Q3aYsx5Wdnc2VK1fcmpxTU1PzBtPzJhKXY8pqXElpmSRcSiPH4jvvpxThVYIICyo8qmdSWiYnL17hjZeeo3qNmgwd/YbN9V2JLSoqaqfWul2hN6zVDJfED9AL+MT8exfgu+K2sdYKyF5ludWBJ0hcjpG4HFOW43K0FdBbs2M1oKv3Glnk+mWtFVBH4F6l1N1ABaCyUmqB1voxA2MSQgiX9I4Mt9rix5ajm1dTuXJl/lj8JsHBwR6MrDDD6gC01mO01nW01vWAh4Cf5OIvhPAlKSkpLFmyhH79+pX4xR+8oxWQEEL4pKVLl3L16lX69+9vyPG9oiOY1joOiDM4DCGEKFFz5syhcePG3HLLLYYcX54AhBDCAEePHmXjxo3079/fsJaAkgCEEMIA8+bNw8/PjyeeeMKwGCQBCCFECcvJyWHevHnceeedhIfb32LI3SQBiLLBzXM9C+FJb82O5eTJk+wq34qOk39i5e4EQ+LwikpgIVySO6FG7jwPySdNr0FGfhVeZ+XuBN778FP8ylckqFGHYkcM9SR5AhAetXJ3Ah0n/0T90as9d6fjgbmehfCUt5dtJeXgL1SM6IYqFwgUPWKoJ8kTgPAYR8dGd5qH5noWwhOObV4NOVmEtI7Ot9zWiKGeJE8AwmMcHRvdaR6a61kId8vJySFt3zrK17mRwOrX5XvP1oihniQJQHiMo2OjO63bONPczpbcMNezEO62YcMG0i8mULVtz3zLgwL8GdWjSYnHIwlAeIyjY6M7rWUM3DMDQusCyvTvPTOkAlh4nVmzZlG1alWmjx5EeFgQCggPC2JSnxYlXgEMUgcgPGhUjyb56gDAg3c6Mtez8HLnzp1jxYoVvPDCC8R0aEhMh4ZGhyQJQHhO7h3NlLWHOJ2URu2wIEb1aGLInY4QRps7dy5ZWVkMHDjQ6FDySAIQHuXo2OhClEU5OTnMnj2bzp0707RpU6PDySN1AEII4WHr1q3j+PHjDBo0yOhQ8pEnACGEcNHK3QlFFnXOmDGDmjVr8uCDDxoYZWHyBCCEEC7I7fCYkJSG5p8Oj7m93o8cOcKaNWt47rnnCAwMNDbYAiQBCCGEC4rr8PjRRx8REBDgdcU/IAlACCFcUlSHx8uXLzNnzhz69etHzZo1Sziy4kkCEEIIFxTV4XHevHmkpKQwbNiwEo7KPpIAhBDCBaN6NCEowD/fsqAAf0be2YgPP/yQDh060L59e4OiK5q0AhJCCAcVbPXzQNtwNhxMzNcKqPzZvRw5coQ33njD6HBtkgQghBAOsDbM+bKdCYXG84mOfppatWrxwAMPGBVqsSQBCCGEA6y1+rkz+790+PZ5+PYChNZhb90nWbt2LRMmTPC6pp+WJAEIIYQDCrb6uddvE5MDPieYDNOC5JNM/WocFYPKM3jwYAMitJ9UAgshhAMKtvp5uVwswSoj7/WfyTl8szedZ9tXpGrVqiUdnkMkAQghhAMKtvqprS7ke3/61gy0hhcjM/MtL5H5sR0kRUBCCGHL3lj48U3T/NKhdaDbOHpHmuadyG0FdF5dQ00SAbiUpvlsVwYPRQRw/fX/TPlYYvNjO8iwBKCUqgt8BdQANDBba/2BUfEIIUQ+e2Ph38Mg01zmn3zS9BroHRnzz4V775W89Wb
uyCA1A164NYThifewY/JPjOrRhNdX7bc5XISRCcDIIqAsYKTW+kagAzBUKXWjgfEIIcQ/fnzzn4t/rsw003JL5ilJ04PD+WBbBrc3DGJh1UF8m9OJhKQ0Ri3ZQ1Ja/uKgXG6fH9tBhj0BaK3PAGfMv6copQ4A4cBvRsUkhEdYKUaQ6StLgeRT9i9vGcNXW5M4d2UQut1Yfs9pnfdWZo62eQi3z4/tIKW17eBKLAil6gE/AxFa68sF3hsIDASoUaNG20WLFjl1jNTUVEJCQlwL1AMkLseUurjSLpmKDnTOP8uUn2ni+qAqxsVlsFIR1/nfIDuj8Er+gXBt/sKKrKwsHn/8cQKDKzHyjXdRStl1vLpVgwkLCnA8NgdFRUXt1Fq3K7jc8EpgpVQIsAx4seDFH0BrPRuYDdCuXTvdpUsXp44TFxeHs9t6ksTlmFIX17QIUwIoKLQujPjVuLgMViri2ns+fx0AQEAQ3DMDWnbJt92cOXM4e/YsTZ94nvd/te+CXiU4gN2P3ulcbG5iaDNQpVQApov/Qq31ciNjEc7zxuZtXsORYgThXcxl+4TWBZTp33tmFCq+y8rK4u2336ZNmzZMHP5koYHhrAkK8Gf8Pc09FLj9jGwFpIAvgANa6/eNikO4xlubt3mN0Do2ngDqlHwsvsqVOpiWMcWu+80333Ds2DFWrlzJfW3qoJTijX/v59LV/BW/ClNzx3ArU0YaxcgngI7A40BXpVS8+eduA+MplYy++y5uNiSf122cqdjAUkCQabnwvNymnMknAf1PU869sW7ZfXZ2NhMmTKBVq1bce++9gOnGJziw8L117sV/8+iuXnHxB2NbAW3ClBSFk7zh7ruo2ZBKm+Im9nZK7t2jtAIyRlFNOd3wGSxevJjDhw+zdOnSfBW/peX/heGVwMJ5Rd19l1QCqB0WRIKVL7XRzdscZW8ydSpJ2FGMIDzEg3UwWVlZvPXWW0RERHD//ffne6+0/L+QsYBKMW+4y7A1G9KoHk2K3dbo4itL9hRl5SaJhKQ0NP8kCan09mK26lqKqIOx93s5f/58Dh48yJtvvomfX/5LqSv/L0qSPAGUYs7eZbizqCN3O0f3V9wdt0eKY4pgTzItKkm83UHupbxSt3HWm3LaqINJSstkzI/FPwmmp6czfvx4brrpJnr37l1oP87+vyhpkgBKsVE9muS7iELxdxnWLrwjFsez44+/mNC7hVNx9I4Md/iLXdwdt63kEOZUhMWzJ5kWnSQqeigy4RIH62DOJaeTlpk/mVsrVp05cyYnT55kzpw5Njt9OfP/oqTJbUsp1jsynEl9WhAeFoTC1MKg4LR0BVm78Gpg4dY/S7Qoo6iLqREti+x5ZLf1ZOVt5bqigJYxpk53ryeZ/i2iPiYjO8fqcsvva0pKCm+//TZ33HEH3bp1c3u4JUmeAEo5R+8ybF14NXhN5bERd9r2PLIX+cSVfMQjcYmSFehv/Z7YMsm///77XLhwgYkTJ5ZUWB4jCcDH2LrwguOVx6mpqSQkJPDt/47w1c+HSExKoWpQOfreVI+o5uFUqFCBa665hmuvvZZKlSrle1Qu6mI6Ze0hp1tQuFJ3UFwyLSpJxMVJAij19sbSiESOlx/LaV2Nd7NiWJXTKd+TYGJiIlOnTqVPnz60b9/e4IBdJwnAx4zq0YQRi+OxNgSgrQvshQsX2LJlC3v37mXPnj0cOHCAkydPkpycXGjd88Bb8+GtAssrVKhA/fr1ady4MY0bN6ZJkyYMjLiOJUf9OZOSUehi7cyddkn0iygN5brCCeYOY34NRuOnNHXUBSYHfE7VgEBa9xyY95mPGzeOtLQ03n77bYMDdg9JAD6md2Q4S3b8yeZjfxV6L6rpNQBkZGTwv//9j+XLlxMXF8e+ffvy1mnQoAHNmzenS5cu1K1bl893JpGUE4hfQHlUufLg5w/ZmVQL8mfK/U1JTEzk3LlznD17lhMnTnD48GHWrFlDRoZplMWgoCDatGn
DrbfeSnBiFlevVnH6Ttsb+kWIUspKh7FglcHrFZdB5BsA7Nu3j9mzZzN06FCaNm1qRJRuJwnAB/1+sXDxitY5LF/1PX8un8KqVatITk4mODiYjh078vDDD3PbbbfRqlUrKlWqlG+7T0evJtjKMa4Ad911l9XjZ2dnc/z4cbZv38727dvZtm0b06dPZ8qUKZQvX56OHTty1113MT+mDw0aNMjbbuXuBM6dTWHA6NVWi3c81S/itZX7+GbbSbK1xl8pHr65rtMtpoSXKqbDmNaakSNHEhoayvjx40swMM+SBOCDLC+I2Vcukbp3PSl71pKdfI4jQZXofMdd3HFLa0aMGEGFChXy1jWVr2/Pd1fuTF8Ef39/GjVqRKNGjXjkkUcAuHLlChs3buSHH35g/fr1jBo1ilGjRhEZGckDDzxAWPPOfLTzCkOa5qDxyyve2fHHX2w4mGizXqO4WIrz2sp9LNj6Z97rbK3zXksSKEOKGbTv+++/Z/369UyfPp1q1aqVcHCeIwmgFHFX56jaYUH88ecfXN62nNS969BZGVS4viVVbn+S4Ea3cDqoAs1a++e7+L+2ch8Lt/6ZV3eQewF+oG04y3YmONQXwZqKFSsSHR1NdHQ0ACdOnGD58uUsW7aM1157DYDA2k3YHN2VnMpd8CtfkbTM7HwxWeNq78tvtlm5KJiXSwIoQ3I7jFkydxjLzMxk5MiRNG7cmCFDhhgTn4dIAigl3FXBee7cOUJ2zCFhxTcAVGzeldCbHyCg2j9d49MyszmX/M9Qtit3J1i90KZlZrPhYCKT+rRwe4/H+vXrM3LkSEaOHElCQgItHh9H6r4fWPzlp6hyXxLc+BZCIntSPrypzY44/koV2y+iONk2ZsyztVyUUrl9Aw5ewDT2/z8dxj6ePp1Dhw6xatUqAgLsm+yltJAEUEq4WsGZlpbGtGnTmDRpEunp6dz14GOcr9+DC1S2ur5lh5gpaw/ZvMs+nZTm8ZYx4eHhNOv+KKfa30/figeZtXIDV377L1d+iyOw5g1UansPFZt2RpXL/58zR2uX4/JXyurF3t/OKf9EKdIyBv6Kg5ikvEWnT59m3Lhx9OjRg169ehkXm4dIT+BSwpUKzn+SPngAACAASURBVB9//JGIiAjGjh3LHXfcwf79+/l+8Vx2TH6YcBvl45YdYoo6Rkn1gh3VownBgeW4vmEjqnUfQp0hc6nafQg6828urp7GqU/7k7RxIdlX/2ma6o7YHr65rkPLRdny0ksvkZGRwUcffWT3PL+liSSAUsKZYQguXbrEU089xR133IG/vz8//vgjK1asoHHjxnnr2BoCoUboP+X/to6hzNuXhNxhLwL9/VBA3WurMnjwczR4bhbX9ptA+dpNSP7lGxJmPsVfP35GdsqFvGatrpjQuwWPdbgu747fXyke63CdlP/7gPXr17N48WLGjBnDDTfcYHQ4HiFFQKWEowO//fe//+XRRx/l7NmzjB49mnHjxhEUVPhCbqvNfZhFhytrx1bAox2uK9H29b0jw4lLPsKJyV3ylrW7vipT1pYnoV5rMi78yeVtS0nZ+W9Sdq3moy13EK7GMvi+Ti4dd0LvFnLB9zHp6ekMGTKEG264gVdeecXocDxGEkApYe/wsrmTVEyYMIGGDRuybds22rZtW2QLImtl+JYdrmwdG6Dj5J8MHe42N/aOk38igeuo3vMlwjo9SvK25STvXceQPuvYOWAA48aN47rrrvtnQ1fmiRWlXzGf/7vvvsvRo0dZt25dvtZwZY0kgFKk4IU4d3TM3OUXLlygb9++xMXF8cQTT/DRRx9RqVIlt7QgKpgkvGE6SkuW9RTlQmtQrftgwm59iMvbljJ//nzmz5/PoEGDePXVV6l5/uf8Y8TnzhMLTieBpLRMw5OhsFPuPMHWPn+u5bfffuPtt9+mX79+3HnnnYaFWRKkDqAUKWpGqr1799K+fXs2/7KFGx58mZ9rxRD98fa8O393D6/sbZPBW6u
n8A+pQsQDwzh69ChPPvkkn3zyCQ0bNmTMiKH8dflq/pVz54l1wsrdCSRcSpOZwkqLIuYJzs7OZsCAAVSqVIkZM2YYE18JkgRQiti66I6Z9iW33noryalp1H70HTIbds53IXLX6J/2bGvUpNdFjedft25dZs+ezYEDB+jduzfv/HSBhjNSmLblbzKyLZp4OjlP7JS1h8gp0FTUyGQoilHEsA+xsbH873//4+OPP+baa68t2bgMIAmgFLF2cU2J/w8HF7zOjTfeSMNnZ8C1+VsrpGVm22yz7kozSW+bHMWeyXEaNWrEwoULif9XQ24O9+eldX/T/JMrLD+Qida6yHlii+JtyVAUw8bnfCD9GubMmUOfPn2IiSlcFOhNc1i7iySAUsTy4qq1JvmXxfy19iPCGrcnLi6OCznWJ0vJ1trtE1R746TXvSPD2Ty6Kycm92Tz6K42y+BbPj6R/wyozppHgynvDw/EpnH7vHR21HzUqeN6WzIUxeg2zjTMg4Vs/wo8tUYRFBTEJ598UqjNf1HFr6WZJIBSJPeiq7Xm0k+fk7RxPpUjuvLZV4sIDg62ecHJvRt2ZOrI4jgzHaXXaBkD98wgum194p8LYVbfWhxKCab9I6/y2GOPcfKk9fF/bBnVowl+BS4YRidDZ5XFu9xCzJ8/oXUxDftQl0nnurB17xFeeOEFatSoUWgTb6vzchdpBVSK9I4MR2vN0BeGkbLjW2rd2oePZkynT1tTr9Si+gp4YriGUj05SssYaBlDOWAg8NDly7zzzju89957rFixgrFjxzJy5EjKly9f7K56R4az8uxvhIf5l+pWQN7WssujzJ8/wJYtW3h91G08/PDDNuf4LavFfPIEUFL2xsK0CHg9zPTv3liHd6G1ZuOC9znzywpGjBhBwqaleRd/KOV35QarXLkyb7/9NgcPHiQ6OpqxY8cSERHB6tWr7do+LCjAruInb1ZW73KLkpyczCOPPELdunX59NNPbQ73UFaL+eQJoCQU0+7YnmGetda88sorTJs2jeHDh/Pee+9Z/bKW6rtyL1CvXj2WLVvG+vXrGTZsGL169aJXr15MmzatzA4HkKus3uXaorXmueee4+TJk2zatInQ0FCb6zraE7+0MPQJQCkVrZQ6pJQ6qpQabWQsHlVEu+OktEy7KpemTp3KlClTGDJkCNOmTbNaSVXmy25L0J133smePXuYOnUqcXFxNG/enLFjx3LlyhWjQ/OYsnqXa8vcuXNZtGgRb7zxBh06dChy3bL6dF1sAlBKvaCUquLuAyul/IGPgbuAG4GHlVI3uvs4XqGIdsfnktOLfexesGABL7/8Mv369ePDDz/0mRYKRgsMDGTkyJEcPnyYfv36MXHiRJo2bUpsbKyp2WgZY0/LrrJyo7F7926GDBlCVFQUo0fbd+9pbyuz0sSeJ4AawHalVKz5jt1dY6LeBBzVWh/XWmcAi4D73LRv72KrfXlonXzj7lvKfexet24dAwYMICoqinnz5uHnV/gjK4my27LyH98ZtWrV4quvvmLTpk1Ur16dfv360a1bN/bv3290aG5V3F1uWbnR+Ouvv3jggQeoXr06ixYtwt/fv/iNyihlz52M+aLfHRgAtANigS+01secPrBSDwLRWutnzK8fB27WWj9fYL2BmBpqUKNGjbaLFi1y6nipqamEhIQ4G65r0i6Zyv21xcVe+UFoXS6k+3HmauFNAv398E89x7Bhw6hduzbTp0+3Gf++hGSrywFahNsu1yyK5flKSssk4VJavt6ufkoRXiWIsKCSnSHJ0M8R04T23333HV988QVXr16lT58+PPnkk2itDY3LFneer0NnU6zesAT6+9GkZiXD4iooKS2Tc8npZGTnEOjvR43QCnnf0+zsbF599VV2797N9OnTufHG/IUORn+/iuJKbFFRUTu11u0KLrcrAQAopVphSgDRwAagA7Bea/2yMwHZmwAstWvXTu/YscOZwxEXF0eXLl2c2tYtbIw
+uHLNesb8kl2ocml0VG3GP3Uv2dnZbNu2jdq1a9vcdcfJP1kd7iE8LIjNo7s6Fa7l+fLE/p1l+OdoduHCBV599VU+//xzatSowVNPPcWECRO8btIQd56v+qNXW50ZTgEnJvd0aF+e+hwLNmUF0/+n3CeZsWPHMnHiRGbOnMmgQYNKLC53cCU2pZTVBGBPHcBwpdRO4F1gM9BCaz0YaAs84FQ0JgmA5bRKdczLyqaWMTDiV3g9yfSvuQ1yWFBAocfut+5pypdvvMC5c+dYsWJFkRd/8HyvXF9rHWKP6tWrM3v2bLZt28Z1113HxIkT6dy5M3v27DE6NI8pDZXERRWHzps3j4kTJ/LMM88wcOBAgyL0LvY0A60K9NFa/2G5UGudo5RyZZLM7UAjpVR9TBf+h4BHXNhfqVWw6ebQoUP5+eefWbBgAe3aFUraVreH4ucKcFbtsCCrTwDe9B/fKO3bt2fLli288sorzJ07lzZt2jBkyBDeeustwsLCjA7PrbyhKWTBJtNRTa9hw8HEvNe2Bj48vvd/PDtuHN26dbM61IOvKvYJQGs9vuDF3+K9A84eWGudBTwPrAUOALFa67JVq+aEzz//nE8++YRRo0bx6KP2j03jSAsFRyt0vXHcH2/i5+dHz549OXToEIMHD+aTTz6hcePGfPnll+TkWK/kL42MbgpprRJ6wdY/aXt5PRsDh7Ex7X42BQ7jXr9N+bbL/CuBCysncsMNN7B06VICAoqot0q75HKHzdLE0I5gWuvvge+NjMGb7Nmzh+eff57u3bszadIkjxzDme7+nn7CKCuqVq3KRx99xNNPP83zzz/P008/zWeffcZHH31E27ZtjQ7PLYzsaGiteOdev01MDvicYJUBQB11gckBn0MmrMrpRFbKBRJjxxFcPoDvvvuu6KeyvbGQfNbUYAPcMlGQt5OhILxEamoqMTExVK1alfnz53usaZqzTUbLYhtoT4mMjGTTpk3MmzePEydO0L59e5577jkuXrxodGilmrU6p5fLxeZd/HMFqwxeDVxCztVk/lo6jnJZV/hx/VoaNGhQ9AF+fDN/Sz1waaKg0kASgBfQWjN48GCOHj3K119/7dGJKKRCt2QopXjiiSc4dOgQw4cP5/PPP6dx48bMmjWL7OzsIrf1xj4XDsfkhrGvCrJW51RbXbC6bvDfidTYMg19+TxrVn9nV11aUR02yypJAF5g3rx5LFiwgPHjx3u8CVppaMlRloSGhjJt2jTi4+Np0aIFzz33HDfffDNbt261ur43drZyOKbcsa+STwL6n6IUF5NAVNNrCi07rasXWnb5b03PxVnEx8ezdOlSbr/9dvsOUESHzbJKEoDBEhISGDp0KFFRUYwdO9bjx5MKXWNERESwYcMGvvnmG86cOcMtt9zC008/zfnz5/Ot540jcjocUxFjX7liw8HEQsvezYrhqg7Me52Urum+IJ0tf6bz9ddf07OnA/0Tuo0zddC0FBBkWl5GSQIwUHZ2NpMmTSIwMNCj5f6WjG7J4a2KK+JwR7GMUoqHHnqIgwcP8vLLL/PVV1/RuHFjZsyYQUaGqRzbk0V0zv4NDsfkoaIUa8dbldOJ0ZnPcCqnOhevarotzGTX2RyWLl1G3759HTtAyxjTJDEWE8Vwz4wyWwEMMhy0oaZOncr+/ftZsGAB289DzPyfSqSVjQwZnV9xLaPcPVFKpUqVeOeddxgwYADDhg1j+PDhfPjhh0yePJlaoWGcTk4vtI2rRXSu/A0O9wMJrfNPS5qCy11gK45VOZ34JaMV6d9P4GjiUVZ+u4q7777buYMEVTF11PQR8gRgkH379jFu3Dg6d+5McLPbva7c15cUV8ThqWKZpk2bsnbtWlavXk358uV58MEHubR4NPps/v26o4jOlb/B4WJDK3PuuqMoZVSPJgT4Fe7AlXPhOCe+fJGTJ0+yZs0a5y/+PkgSgAEyMjJ44oknCAsLY8SIEUxdd9jryn19SXFFHJ4sllF
KcffddxMfH89nn31G0rlT/DlvJCmr3yHrrwS3FdG58jc4XGxoZc5ddxSl9I4MZ0rfVvkGIPRPiCfxmzEElw9g8+bNREVFuXQMXyNFQA6wZ+Yue7zzzjvEx8ezcuVKQkNDOZ1kfZIRaZppm7s+Cyi+iKMkhsIoV64czzzzDA899BDvvfceU6ZM4fLBLdz77LO0u+ZVl/fv6t/gcLGhxZy77pQbh9aayZMn89qUcbRo0YLVq1cTHm49Pnd+V8oaeQKwk7ua5x0+fJi3336bfv36cd99pukPpGmmY9zdVLK4Io6SbDkVEhLC+PHjOXr0KM888wyff/45N9xwA8OHD+fMmTNO77cstf5KTk6mT58+vPrqq8TExLB58+YiL/5SvGqbJAA7uaMcWGvNoEGDCAoKYvr06XnLy9J/zpLg7jL54oo4jGg5VbNmTT799FMOHz7M448/zscff0yDBg0YOXIk586ds2sfK3cncOhsCvVHr2bK2kM80Da81Lf+2rp1K23btuXf//4306dP5+uvv6ZixYo217f1XXl91X6v62xnBCkCspM7yoHnzp1LXFwcs2bNombNmnnLZawdx3iiTL64Ig6jWk7Vq1ePzz77jNGjRzNhwgSmT5/OzJkzefbZZxkxYgTXX3+91e1y73yHNM1B40dCUhrLdiZ49KLvyaKWzMxM3nzzTSZOnEjdunWJi4ujU6dOxW5n6zuRlJZJUlomkL9FVNkav7V48gRgJ1eLaRITE/nXv/5Fx44deeaZZwq9L2Pt2K/MFpkVMXxCw4YNmTNnDgcOHKBv3758/PHHNGzYkEcffZT4+PhCu/JUyyVbfQk8WdSybds2brrpJiZMmMATTzzB3r177br4g/3fCV9tdCEJwE6uFtP861//IiUlhdmzZ1ud11fYr0wWmVkZPiFt+fMMf3VMvgtt48aNmTt3LsePH+fFF19k1apVREZG0r17d77//vu8cYY88ZRk7SK/acUnXH2nKfd+25z1ami+oZhdvaj+9ddfDBo0iFtuuYXz58+zfPly5syZQ+XKle3eh7Xvii2+2OhCrkR2cqUceMuWLXz11VeMHDmy0BykwnFlsjezleETgvibUeVird5N161bl6lTp3Ly5EkmTZrEvn376NmzJzfccAOTJ0+mernCncnAtaekgk8V9/pt4k01m+C0M/ihqeNnGorZMgk4c1G9evUqU6ZMoVGjRnzxxRe8+OKLHDhwgPvvv9/hfVn7rlQJtj4fQKl/gnSC1AE4wJly4JycnLyJ3UtirB9fUeZ6M9sYJqG2Mg0hnXs3XfBvDgsLY/To0bz00kusWLGCTz/9lDFjxlCuXABBjW/lt5zOaNqh/PxdfkoqeDG3NRTzy+ViWZVhKqJx5KJ69epV5s6dy4QJEzhz5gzR0dG88847tGzZsugNbcy3navgd8XWvMGjejSB5CN2x1sWyBOAh82dO5cdO3bwzjvvEBISYnQ4wlvZGCbhtK72z+9F3E0HBgbSr18/4uLi+O233xgyZDA5J+OZOeUtTn38JH///AUDGmVyX+ui55cuSsGLua2hmHOTlr0J5+zZs/zf//0f1113HUOHDqVBgwb897//Zc2aNfZd/B0cebRMPkE6SZ4APCg5OZkxY8Zwyy23ODS9oyh9ktIy6TjZgbGcCt61NuoOe77OVwx0VQfybtY/d7L23k03a9aMDz74gHfffZcpU6YQHx/Pd999xytPruDDsXXo1asX99xzD127dqVChQp2/40F5wQ+ratTx0oSOK2rEV7MOcjIyGDp0qV89dVXrFmzhuzsbO677z5eeuklOnXqZP+cvUWNPFpER7Qy9wTpJEkAHvTWW2+RmJhIzb6v02DM99K8s4xauTuBhEtpJCSZKhuLHWgt964198KVfNJ08W/1CBxZh04+xWldjXcyY1iVYypKcab4pnz58nTq1InXXnuNpKQkVq5cyapVq5g/fz4zZ84kODiY22+/nc6dO3PbbbfRrl07ypc
vb3N/BZsrfx74GK/pmZTLtqhvCAiizj2T2Nyya6Htz58/z5o1a1i9ejXff/89V65coXbt2owYMYJnn32WRo0aOfT3AT45iYs7SQLwkKNHjzL9gw+o3Ko7ySHXAa6PIim805S1h3iors63zFaZPWD7rvXIOhjxKwrYvjuBnWsPodzUpj4sLIz+/fvTv39/0tPTiYuL47vvvmPDhg2MGTMGgAoVKtC6dWtatmyZ91O/fn1q1aqVN1R5/jvnnrC3eaHyd92iL2dOn+bQoUMcOHCAbdu2sWXLFo4cMZWv16pVi9tvv53hw4fTrVs314ZB99DIo75CEoCHjB07Fq3KEdIxf9FPkRcGUSqdTkqDujaWW2PHXasniygqVKhAdHQ00dHRAFy4cIFNmzbx888/s3v3bpYsWcLs2bPz1i9Xrhzh4eHUqlWL0NBQKleunFeflZOTQ3Z2JJcvN+TChQtcmPl/nD79NKmpqXnbX3vttdxyyy089dRTdO/endatW/Pzzz+7Z/a7buPyP01BmZ/ExZ0kAXjA9u3biY2NJfTWhygXUrXQ+77Y3rgsM5XNp9hYboWX3bVWr16d3r1707t3b8A0ZMkXa3fw/uIfOX/mFBX+vkS94HTK/Z3MpUuX+OOPP0hNTUUphZ+fH35+flSqVInq1avTunVroqOjady4MU2aNKFx48bUrVvX/jJ9R+WW8xfRCkjYJgnAzbTWvPLKK1SvXp0Gdz7COSvNsX2xvXFZNqpHExIO7My3rNjx8r34rvXb+NO8t/kiaTVaEFKjBQDnAvy9t6WMh0Ye9QXSDNTN/vOf/7BhwwbGjRvHmPvalL0eq6KQ3pHhhFcJMny8fHfxxnmJhWfIE4AbZWdn88orr9CgQQMGDRpEYKBpsmoZ5K3sCwsKYPPoLvZv4MV3rZ6cAEd4F0kAbvTNN9+wb98+vvnmm7yLv7Q3FqVNSUyAI7yDFAG5SVZWFm+88QYtW7YkJsY77+yEsEeZHGxPWGXIE4BSagpwD5ABHAMGaK2TjIjFXb7++muOHj3K8uXLZbRPUarJ/BS+w6gioPXAGK11llLqHWAM8IpBsbgsKyuLt956i9atW+c1pROiNJOiS99gSALQWq+zeLkVeNCIONxlwYIFHD16lJUrV3quvbMQbmDvrF0ykbpvUFrr4tfyZABK/RtYrLVeYOP9gcBAgBo1arRdtGiRU8dJTU31yGicWVlZPPnkk1SsWJFZs2Y5nAA8FZerJC7HlIa4ktIySbiURo7F/3k/pQivEkRY0D9j5Nu7nrvi8ibeGhe4FltUVNROrXW7gss99gSglPoBqGnlrbFa62/N64wFsoCFtvajtZ4NzAZo166ddrb7eFxcnHu6nhfw5Zdfcvr0aVatWkVUVJTDd06eistVEpdjvDmupNBGTFl7iISkDKDwuDvhYf75mrB2nPxT3sB2Ra3nalzeer68MS7wTGweSwBa6zuKel8p1R/oBXTTRj+GOCkrK4sJEybQrl07evXqVWiiCRn8TRgtKS2TMT/uK9Sxy1LB9v3SD8B3GNUKKBp4Gbhda33ViBjcYfHixZw4cYLytz1FgzHf46cU2dqBUSGF8LBzyemkZRbdKq3QRC/e1A+gmNm+hGuMaq/4EVAJWK+UildKzTQojmKt3J1Ax8k/UX/06nyTc2utefX1twi85nqu1myFhkIX/1xy5ySMkpGdU+T71tr3e00/ACdm+xKOMaoV0A1GHNdRRRXpBJyO58+jh6jW8yWUcuwOS4iSEuhv+7tpa9Yur+kH4ORsX8J+MhREEYoaFMv/+3fwr3wNFZt1LnIf0oNSGKlGaAWCArILTYBe3MieXtEPQGb78jhJAEWwVXRz/NednN24kfq9hpLjX/gU+itFjtbSflqUGFutz8KCApjU50bD7+YLxdfKdqV0Hi+bN6EskgRQBFuVYX/vWkG1atWYOGY4b6w55vDdlRDuVFRRZRjuuZt3pWOYtfgSLmWzcndC0fvw8nkTygI
ZtKYI1irD/C6d5NKBLQwbNoyHbm3EpD4t7B8HXggP8PT4/bkX8ISkNDT/JJjcBhHOxJejdfHxefm8CWWBPAEUwVplWOD+HzkXHMzQoUPz1pELvjBS0e32K7q8/6ISjD3ffZf6FXjxvAllgSSAYlhe4M+ePct1475l0KBBVKtWzeDIhDDxdLt9VzuGeVW/ApGPFAE54NNPPyUrK4thw4YZHYoQeTzdbt/WhdreC7jVolSlpHWcF5AEYKf09HQ+/fRTevbsSaNGjYwOR4g8vSPDPVoX5WqCsRZfeJUgKTr1AlIEZKdFixaRmJjIiy++aHQoQhTiybood3QMKxhfXFycu8MUTpAEYAetNdOnTyciIoKuXbsaHY4QJc5WgpF5A0o3KQIqyt5YmBbBfwdUYs+ePQx/4FaZ8EUIM1ebhwrjSQKwxWIgqulb/6ZakOJRVspAVEKYebr/gfA8SQC2mAeiOn4ph1WHsniuXQBBpJuWCyFk3oAyQBKALeYBpz7dnoGfgsHtAvMtF8LXudo8VBhPEoAtoXVIz9LMic+kd9NyhFf2y1suhPCieQOE0yQB2NJtHEsPKS6maZ7LvfuXgaiEyOPp/gfC86QZqC0tY5h5dByNrkmna/1ypoGoZDo6IfKRsbBKN0kANuzbt4/N8YeYOnUqfiNHGh2OEEK4nSQAsDrx9MxZ/6V8+fL079/f6OiEEMIjJAHktvfPnXQi+SSpS59n/rxUYmJiZNRPIUSZJQnAysTTX+9OIeVKOoMHDzYoKCGE8DxpBVSgXb/Wmpk7MmhZw48OHToYFJQQQnieJIAC7fp3n81h99kcBnWqIeP+CCHKNEkA3caZ2vebfbk7gwrl4JEX3zYwKCGE8DxJABYTT6dnwcJfs+nTvSNhnQYYHZkQQniUJAAwJYERv7IyYhZJaTk89dIbRkckhBAeJ62A+GdSi12z3qV8WA2SwxobHZIQQnicoU8ASqmRSimtlKpuVAy5k1r88ccfpP8eT4Xm3Ri7cr9MaiGEKPMMSwBKqbpAd+BPo2KAfya1SP31R0AT0qKbTGohhPAJRj4BTANeBrSBMXA6KQ2tc0jd9wMVrm9FudAaecuFEKIsU1qX/PVXKXUf0FVrPVwp9TvQTmt9wca6A4GBADVq1Gi7aNEip46ZmppKSEhIoeWHzqbw6749fDTx/3hiyAja3Xo7AIH+fjSpWcmpY7kjLqNJXI6RuBwjcTnOldiioqJ2aq3bFXpDa+2RH+AH4FcrP/cB24BQ83q/A9Xt2Wfbtm21szZs2GB1+Ypdp3TliK5ala+o6760TF//yne66Wtr9Ipdp5w+ljviMprE5RiJyzESl+NciQ3Yoa1cUz3WCkhrfYe15UqpFkB9YI+5p20dYJdS6iat9VlPxZMrt8XP6aQ0aocF8ULnumQc28q1raLwDyhP7bAgRvVoImOcCyHKvBJvBqq13gdcm/u6uCIgd8pt8ZOWmQ1AQlIaL035nPS0q6yd8jKdO3f2dAhCCOE1fKojWG6LH0uX9v5IYNi1dOrUyaCohBDCGIYnAK11vZK4+4fCLXuyr1wi/cRugprejp+f4adCCCFKlE9d9WqHBeV7feXARtA51O8QbVBEQghhHJ9KAKN6NCEowD/v9ZXf4ihfsyHjn+hhYFRCCGEMn0oAvSPDmdSnBeFhQWT9lUDGmcM8/PAj0uJHCOGTfCoBgCkJbB7dladrJaCU4u1/PWd0SEIIYQifSwBg6vy2cOFCunXrRu3atY0ORwghDOGTCWDbtm0cO3aMxx57zOhQhBDCMD6ZABYuXEiFChW4//77jQ5FCCEM43MJIDs7myVLltCrVy8qV65sdDhCCGEYn0sAGzdu5Ny5c8TExBgdihBCGMrnEkBsbCzBwcHcfffdRocihBCG8qkEkJWVxbJly+jVqxcVK1Y0OhwhhDCUTyWAn3/+mfPnz0vxjxBC4GMJIDY2looVK3LXXXcZHYoQQhjOZxJAdnY2y5Yt45577iE
4ONjocIQQwnA+kwDi4+O5cOGCFP8IIYSZzySADRs2EBISQnS0DP0shBDgD/D+IgAAB7pJREFUIwkgMzOTjRs3cu+99xIUFFT8BkII4QNKfE5gI2zYsIHLly8XW/xTcMJ4mRxeCFGW+UQCyO381aOH7YlfrE0YP2b5PgBJAkKIMsknEsC4ceO44YYbqFChgs11rE0Yn5aZzZS1hyQBCCHKJJ+oA7juuuvo0KFDkesUnDC+uOVCCFHa+UQCsEfBCeOLWy6EEKWdJACzghPGAwQF+DOqRxODIhJCCM/yiToAe+SW80srICGEr5AEYKF3ZLhc8IUQPkOKgIQQwkdJAhBCCB8lCUAIIXyUJAAhhPBRkgCEEMJHKa210THYTSmVCPzh5ObVgQtuDMddJC7HSFyOkbgc461xgWuxXa+1vqbgwlKVAFyhlNqhtW5ndBwFSVyOkbgcI3E5xlvjAs/EJkVAQgjhoyQBCCGEj/KlBDDb6ABskLgcI3E5RuJyjLfGBR6IzWfqAIQQQuTnS08AQgghLEgCEEIIH1WmEoBSqq9Sar9SKkcpZbO5lFIqWil1SCl1VCk12mJ5faXUNvPyxUqpQDfFVVUptV4pdcT8bxUr60QppeItftKVUr3N781VSp2weK91ScVlXi/b4tirLJYbeb5aK6W2mD/vvUqpfhbvufV82fq+WLxf3vz3HzWfj3oW740xLz+klLI9KbVn4npJKfWb+fz8qJS63uI9q59pCcXVXymVaHH8Zyzee9L8uR9RSj1ZwnFNs4jpsFIqyeI9T56vL5VS55VSv9p4XymlZpjj3quUamPxnmvnS2tdZn6AZkATIA5oZ2Mdf+AY0AAIBPYAN5rfiwUeMv8+ExjsprjeBUabfx8NvFPM+lWBv4Bg8+u5wIMeOF92xQWk2lhu2PkCGgONzL/XBs4AYe4+X0V9XyzWGQLMNP/+ELDY/PuN5vXLA/XN+/EvwbiiLL5Dg3PjKuozLaG4+gMfWdm2KnDc/G8V8+9VSiquAuu/AHzp6fNl3ndnoA3wq4337wbWAAroAGxz1/kqU08AWusDWutDxax2E3BUa31ca50BLALuU0opoCuw1LzePKC3m0K7z7w/e/f7ILBGa33VTce3xdG48hh9vrTWh7XWR8y/nwbOA4V6OrqB1e9LEfEuBbqZz899wCKt9d9a6xPAUfP+SiQurfUGi+/QVqCOm47tUlxF6AGs11r/pbW+BKwHog2K62HgGzcdu0ha658x3fDZch/wlTbZCoQppWrhhvNVphKAncKBkxavT5mXVQOStNZZBZa7Qw2t9Rnz72eBGsWs/xCFv3xvmx//pimlypdwXBWUUjuUUltzi6XwovOllLoJ013dMYvF7jpftr4vVtcxn49kTOfHnm09GZelpzHdReay9pmWZFwPmD+fpUqpug5u68m4MBeV1Qd+sljsqfNlD1uxu3y+St2MYEqpH4CaVt4aq7X+tqTjyVVUXJYvtNZaKWWz7a05s7cA1losHoPpQhiIqS3wK8CbJRjX9VrrBKVUA+AnpdQ+TBc5p7n5fM0HntRa55gXO32+yiKl1GNAO+B2i8WFPlOt9THre3C7fwPfaK3/VkoNwvT01LWEjm2Ph4ClWutsi2VGni+PKXUJQGt9h4u7SADqWryuY152EdOjVTnzXVzucpfjUkqdU0rV0lqfMV+wzhexqxhghdY602LfuXfDfyul5gD/Ksm4tNYJ5n+PK6XigEhgGQafL6VUZWA1puS/1WLfTp8vK2x9X6ytc0opVQ4IxfR9smdbT8aFUuoOTEn1dq3137nLbXym7rigFRuX1vqixcvPMdX55G7bpcC2cW6Iya64LDwEDLVc4MHzZQ9bsbt8vnyxCGg70EiZWrAEYvqwV2lTrcoGTOXvAE8C7nqiWGXenz37LVT2aL4I5pa79wasthbwRFxKqSq5RShKqepAR+A3o8+X+bNbgalsdGmB99x5vqx+X4qI90HgJ/P5WQU8pEy
thOoDjYD/uRCLQ3EppSKBWcC9WuvzFsutfqYlGFcti5f3AgfMv68FupvjqwJ0J/+TsEfjMsfWFFOF6haLZZ48X/ZYBTxhbg3UAUg23+S4fr48VbNtxA9wP6ZysL+Bc8Ba8/LawPcW690NHMaUwcdaLG+A6T/oUWAJUN5NcVUDfgSOAD8AVc3L2wGfW6xXD1NW9yuw/U/APkwXsgVASEnFBdxqPvYe879Pe8P5Ah4DMoF4i5/Wnjhf1r4vmIqU7jX/XsH89x81n48GFtuONW93CLjLzd/34uL6wfz/IPf8rCruMy2huCYB+83H3wA0tdj2KfN5PAoMKMm4zK9fByYX2M7T5+sbTK3YMjFdv54GngOeM7+vgI/Nce/DooWjq+dLhoIQQggf5YtFQEIIIZAEIIQQPksSgBBC+ChJAEII4aMkAQghhI+SBCCEED5KEoAQQvgoSQBCuEAp1d48qFkFpVRFZZqfIMLouISwh3QEE8JFSqkJmHoDBwGntNaTDA5JCLtIAhDCReaxZbYD6cCtOv8okkJ4LSkCEsJ11YAQoBKmJwEhSgV5AhDCRco0R+wiTJOI1NJaP29wSELYpdTNByCEN1FKPQFkaq2/Vkr5A78opbpqrX8qblshjCZPAEII4aOkDkAIIXyUJAAhhPBRkgCEEMJHSQIQQggfJQlACCF8lCQAIYTwUZIAhBDCR/0/ZFpdNpgE5mkAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "scatter_plot_polynomial(X_train, y_train, label='training set')\n", "scatter_plot_polynomial(X_valid, y_valid, label='validation set')\n", "plt.plot(X_lin, y_fit, label='fit', c='black')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### From underfitting to overfitting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can investigate the shape of the fitted curve for different values of `degree`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for degree in [0, 1, 2, 3, 5, 10, 50]:\n", " pr = PolynomialRegressor(degree=degree)\n", " pr.fit(X_train, y_train)\n", " y_fit = pr.predict(X_lin)\n", " \n", " title = f'Polynomial fit degree={degree}'\n", " scatter_plot_polynomial(X_train, y_train, label='polynomial data', title=title)\n", " scatter_plot_polynomial(X_valid, y_valid, label='validation set', title=title)\n", " plt.plot(X_lin, y_fit, label='fit', c='black')\n", " \n", " plt.ylim([-6, 6])\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By just looking at the curves we can observe **3 regimes**:\n", "\n", "**Underfitting (degree<3):**\n", "The model is not able to fit the complexity data properly. The fit is bad for both the training and the validation set.\n", "\n", "**Fit is just right (degree=3):**\n", "The model is able to caputre the underlying data distribution. The fit is good for both the training and the validation set.\n", "\n", "**Overfitting (degree>3:**\n", "The model starts fitting the noise in the dataset. While the fit for the training data gets even better the fit for the validation set gets worse.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fitting graph\n", "In the next step we want to quantify the previous observation. To do this we calculate the training and and validation error for each degree and plot them in a single graph. The resulting graph is called the **fitting graph**." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 30/30 [00:00<00:00, 1358.22it/s]\n" ] } ], "source": [ "rmse_train = []\n", "rmse_valid = []\n", "degrees = list(range(0, 30))\n", "\n", "for degree in tqdm(degrees):\n", " pr = PolynomialRegressor(degree=int(degree))\n", " pr.fit(X_train, y_train)\n", " rmse_train.append(pr.evaluate(X_train, y_train))\n", " rmse_valid.append(pr.evaluate(X_valid, y_valid))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def plot_fitting_graph(x, metric_train, metric_valid, metric_name='metric', xlabel='x', yscale='linear'):\n", " plt.plot(x, metric_train, label='train')\n", " plt.plot(x, metric_valid, label='valid')\n", " plt.yscale(yscale)\n", " plt.title('Fitting graph')\n", " plt.ylabel(metric_name)\n", " plt.xlabel(xlabel)\n", " plt.legend(loc='best')\n", " plt.grid(True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_fitting_graph(degrees, rmse_train, rmse_valid, metric_name='RMSE', xlabel='degree')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can make a few observations:\n", "1. The training error keeps decreasing as we increase the degree.\n", "2. There is a region between degree 3 and 15 where the validation error is stable and low. \n", "\n", "Ideally, we would choose the model parameters such that we have the best model performance. However, we want to make sure that we really have the best validation performance. When we do `train_test_split` we randomly split the data into two parts. What could happen is that we got lucky and split the data such that it favours the validation error. This is especially dangerous if we are dealing with small datasets. One way to check if that's the case is to run the experiment several times for different, random splits. However, there is an even more systematic way of doing this: **cross-validation**.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Cross-validation\n", "\n", "The idea behind cross-validation is to split the data into k equally sized parts, called folds. Each fold gets to be the validation set once while the remaining folds together form the training set. That means we run k experiments and aggregate the training and validation metrics by averaging them. This is a more robust approach to monitoring overfitting and thanks to `scikit-learn` we only have to adjust one line by adding the `cross_validate` function!\n", "\n", "
\n", "\n", "

Figure reference: https://scikit-learn.org/stable/modules/cross_validation.html

\n", "
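To make the procedure concrete, here is a minimal sketch of what k-fold cross-validation does under the hood, written with plain NumPy. It is not how `cross_validate` is implemented: the helpers `rmse_of_fit` and `k_fold_rmse`, the toy data `X_demo` and `y_demo`, and the use of `np.polyfit` in place of `PolynomialRegressor` are assumptions made for illustration. The sketch also shuffles the samples before splitting, which is a choice of this example.

```python
import numpy as np


def rmse_of_fit(coefs, X, y):
    # Root mean squared error of a fitted polynomial on the given samples.
    residuals = np.polyval(coefs, X) - y
    return np.sqrt(np.mean(residuals ** 2))


def k_fold_rmse(X, y, degree, k=5, seed=0):
    # Shuffle the sample indices once, then cut them into k folds.
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(X)), k)

    train_scores, valid_scores = [], []
    for i in range(k):
        # Fold i is the validation set, all other folds form the training set.
        valid_idx = folds[i]
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])

        # Fit on the training folds only.
        coefs = np.polyfit(X[train_idx], y[train_idx], deg=degree)

        train_scores.append(rmse_of_fit(coefs, X[train_idx], y[train_idx]))
        valid_scores.append(rmse_of_fit(coefs, X[valid_idx], y[valid_idx]))

    # Aggregate the k experiments by averaging.
    return np.mean(train_scores), np.mean(valid_scores)


# Hypothetical toy data: a noisy cubic.
demo_rng = np.random.default_rng(42)
X_demo = demo_rng.uniform(-1, 1, size=100)
y_demo = X_demo ** 3 - 0.5 * X_demo + demo_rng.normal(scale=0.1, size=100)

for degree in [1, 3, 8]:
    train_rmse, valid_rmse = k_fold_rmse(X_demo, y_demo, degree)
    print(f'degree={degree}: train RMSE={train_rmse:.3f}, valid RMSE={valid_rmse:.3f}')
```

Every sample ends up in the validation fold exactly once, so the averaged validation RMSE depends much less on a single lucky split than the score from one `train_test_split`.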
\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `cross_validate` function takes care of all the steps; we just have to pass an initialized model, the full dataset and the number of folds with the keyword `cv`. Furthermore, we need to specify the metric to be evaluated with `scoring` and request the scores on the training sets with `return_train_score=True`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 16/16 [00:00<00:00, 235.88it/s]\n" ] } ], "source": [ "rmse_train = []\n", "rmse_valid = []\n", "degrees = list(range(0, 16))\n", "for degree in tqdm(degrees):\n", " pr = PolynomialRegressor(degree=degree)\n", " results = cross_validate(pr, X, y,\n", " cv=5,\n", " return_train_score=True,\n", " scoring='neg_root_mean_squared_error')\n", " \n", " # the scores are negative RMSE values: flip the sign, average over the folds and append\n", " rmse_train.append(-np.mean(results['train_score']))\n", " rmse_valid.append(-np.mean(results['test_score']))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_fitting_graph(degrees, rmse_train, rmse_valid, metric_name='RMSE', xlabel='degree', yscale='log')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use `np.argmin` to find the element with the minimum validation error. The function returns the index in the array with the minimum value." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.argmin(rmse_valid)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since we start counting from 0 in programming this means we are looking for the fourth element in the degrees list:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "degrees" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "degrees[np.argmin(rmse_valid)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Findings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In these experiments we made some key observations:\n", "\n", "1. We can fit polynomials (hurray)\n", "2. There are three regimes when fitting polynomials: **underfitting**, **good fit** and **overfitting**\n", "3. These regimes depend on the degree of the polynomials we are fitting.\n", "4. Increasing the degree of the polynomials always decreases the training error.\n", "5. 
The validation error decreases from the underfitting to the good fit regime and then increases in the overfitting regime.\n", "\n", "These observations are not specific to polynomials - they hold for fitting machine learning models in general. Let's translate these observations to general machine learning models:\n", "\n", "The challenge in fitting models in machine learning is to find the **good fit**. A model that is **too simple** will not be able to capture the complexity of the data and lead to **underfitting**. A model that is **too complex** has the capacity to \"memorize\" aspects of the data and cause **overfitting**. If we are overfitting, our model will not predict unseen data well - we say it does not **generalise**. The goal is to find a model that has just the right complexity to fit the data. The **fitting graph** is a tool to identify the sweet spot of model complexity.\n", "\n", "**Complexity**\n", "\n", "Model complexity comes in different forms and shapes. In our polynomial example the complexity is controlled by the `degree` parameter. For a Random Forest the complexity is given by several parameters such as `max_depth` and `n_estimators`. \n", "\n", "**Classification**\n", "\n", "We can also create fitting graphs for classification tasks. Instead of looking at RMSE we can look at the accuracy. The difference is that the graph will look inverted in the y-axis; instead of looking for the lowest validation error we will look for the highest validation accuracy.\n", "\n", "**More data**\n", "\n", "There is another way to reduce overfitting: get more data! In the following exercise you explore how more data influences the fitting graph.\n", "\n", "### Exercise #1\n", "Use the `make_polynomial_data` function to generate `100_000` samples. 
Create another fitting graph with cross-validation and compare it to the one with `100` samples.\n", "\n", "In the following section we will investigate overfitting in Random Forests.\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train/validation/test set\n", "Before we move to Random Forests we want to clarify some ambiguity about the terms train, validation and test set. So far we have concentrated on the train and validation set: We train a model on the train set and we evaluate the metrics on the validation set. However, sometimes we also come across the term test set. For example, it appears in the name of the function we used to split the dataset into two sets: `train_test_split`.\n", "\n", "In machine learning each of these sets has a distinct function:\n", "\n", "- The **train set** is used to train a model.\n", "- With the **validation set** the model is evaluated. With this information we tune the parameters.\n", "- We only evaluate the final, tuned model on the **test set**. We do not use it to tune the model parameters.\n", "\n", "The reason we make the distinction between validation and test set is that by tuning the parameters on the validation performance we might start to overfit the validation data. The test set gives a final sanity check that we actually have a performant model.\n", "\n", "In **cross-validation** the distinction between train and validation set is blurred: every sample is part of the training data in some folds and part of the validation data in one fold." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 2: Housing dataset\n", "In this section we want to use cross-validation to systematically tune the parameters of the Random Forest without overfitting. We do it in a linear fashion and tune one parameter after another. This does not guarantee that we find the globally best parameters, but runs much faster than a global grid search. 
We will tune the following parameters:\n", "\n", "* n_estimators\n", "* max_depth\n", "* min_samples_leaf\n", "* max_features\n", "\n", "First we load the processed housing data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset already exists at '../data/housing_processed.csv' and is not downloaded again.\n" ] } ], "source": [ "get_dataset('housing_processed.csv')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "autos.csv median_submission.csv\n", "churn.csv sample_submission.csv\n", "housing.csv solution.csv\n", "housing_addresses.csv test.csv\n", "housing_gmaps_data_raw.csv train.csv\n", "housing_processed.csv word2vec-google-news-300.pkl\n", "imdb.csv zero_submission.csv\n" ] } ], "source": [ "DATA = Path('../data/')\n", "!ls {DATA}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table>\n", "<thead>\n", "<tr><th></th><th>longitude</th><th>latitude</th><th>housing_median_age</th><th>total_rooms</th><th>total_bedrooms</th><th>population</th><th>households</th><th>median_income</th><th>median_house_value</th><th>city</th><th>postal_code</th><th>rooms_per_household</th><th>bedrooms_per_household</th><th>bedrooms_per_room</th><th>population_per_household</th><th>ocean_proximity_INLAND</th><th>ocean_proximity_&lt;1H OCEAN</th><th>ocean_proximity_NEAR BAY</th><th>ocean_proximity_NEAR OCEAN</th><th>ocean_proximity_ISLAND</th></tr>\n", "</thead>\n", "<tbody>\n",
"<tr><th>0</th><td>-122.23</td><td>37.88</td><td>41.0</td><td>880.0</td><td>129.0</td><td>322.0</td><td>126.0</td><td>8.3252</td><td>452600.0</td><td>69</td><td>94705</td><td>6.984127</td><td>1.023810</td><td>0.146591</td><td>2.555556</td><td>0</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"<tr><th>1</th><td>-122.22</td><td>37.86</td><td>21.0</td><td>7099.0</td><td>1106.0</td><td>2401.0</td><td>1138.0</td><td>8.3014</td><td>358500.0</td><td>620</td><td>94611</td><td>6.238137</td><td>0.971880</td><td>0.155797</td><td>2.109842</td><td>0</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"<tr><th>2</th><td>-122.24</td><td>37.85</td><td>52.0</td><td>1467.0</td><td>190.0</td><td>496.0</td><td>177.0</td><td>7.2574</td><td>352100.0</td><td>620</td><td>94618</td><td>8.288136</td><td>1.073446</td><td>0.129516</td><td>2.802260</td><td>0</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"<tr><th>3</th><td>-122.25</td><td>37.85</td><td>52.0</td><td>1274.0</td><td>235.0</td><td>558.0</td><td>219.0</td><td>5.6431</td><td>341300.0</td><td>620</td><td>94618</td><td>5.817352</td><td>1.073059</td><td>0.184458</td><td>2.547945</td><td>0</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"<tr><th>4</th><td>-122.25</td><td>37.85</td><td>52.0</td><td>1627.0</td><td>280.0</td><td>565.0</td><td>259.0</td><td>3.8462</td><td>342200.0</td><td>620</td><td>94618</td><td>6.281853</td><td>1.081081</td><td>0.172096</td><td>2.181467</td><td>0</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"</tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " longitude latitude housing_median_age total_rooms total_bedrooms \\\n", "0 -122.23 37.88 41.0 880.0 129.0 \n", "1 -122.22 37.86 21.0 7099.0 1106.0 \n", "2 -122.24 37.85 52.0 1467.0 190.0 \n", "3 -122.25 37.85 52.0 1274.0 235.0 \n", "4 -122.25 37.85 52.0 1627.0 280.0 \n", "\n", " population households median_income median_house_value city \\\n", "0 322.0 126.0 8.3252 452600.0 69 \n", "1 2401.0 1138.0 8.3014 358500.0 620 \n", "2 496.0 177.0 7.2574 352100.0 620 \n", "3 558.0 219.0 5.6431 341300.0 620 \n", "4 565.0 259.0 3.8462 342200.0 620 \n", "\n", " postal_code rooms_per_household bedrooms_per_household \\\n", "0 94705 6.984127 1.023810 \n", "1 94611 6.238137 0.971880 \n", "2 94618 8.288136 1.073446 \n", "3 94618 5.817352 1.073059 \n", "4 94618 6.281853 1.081081 \n", "\n", " bedrooms_per_room population_per_household ocean_proximity_INLAND \\\n", "0 0.146591 2.555556 0 \n", "1 0.155797 2.109842 0 \n", "2 0.129516 2.802260 0 \n", "3 0.184458 2.547945 0 \n", "4 0.172096 2.181467 0 \n", "\n", " ocean_proximity_<1H OCEAN ocean_proximity_NEAR BAY \\\n", "0 0 1 \n", "1 0 1 \n", "2 0 1 \n", "3 0 1 \n", "4 0 1 \n", "\n", " ocean_proximity_NEAR OCEAN ocean_proximity_ISLAND \n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "housing_data = pd.read_csv(DATA/'housing_processed.csv')\n", "housing_data.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X = housing_data.drop('median_house_value', axis=1)\n", "y = housing_data['median_house_value']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Untuned model\n", "We want to first evaluate the untuned model to get a baseline:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train score: 15834.9\n", "Validation score: 60426.9\n" ] } ], "source": 
[ "rf = RandomForestRegressor(n_jobs=-1)\n", "results = cross_validate(rf, X, y,\n", " cv=5,\n", " return_train_score=True,\n", " scoring='neg_root_mean_squared_error')\n", "print('Train score: %.1f' % -np.mean(results['train_score']))\n", "print('Validation score: %.1f' % -np.mean(results['test_score']))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we tune one parameter after another and always choose the best one of the round based on the fitting curve." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### n_estimators" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [03:18<00:00, 49.67s/it]\n" ] } ], "source": [ "rmse_train = []\n", "rmse_valid = []\n", "\n", "n_estimators = [25, 50, 100, 200]\n", "for n in tqdm(n_estimators):\n", " rf = RandomForestRegressor(n_estimators=n, n_jobs=-1)\n", " results = cross_validate(rf, X, y,\n", " cv=5,\n", " return_train_score=True,\n", " scoring='neg_root_mean_squared_error')\n", " \n", " # we average the scores and append them to the list\n", " rmse_train.append(-np.mean(results['train_score']))\n", " rmse_valid.append(-np.mean(results['test_score']))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZEAAAEXCAYAAABsyHmSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO3de5RU5Znv8e/TF7oBG1AwrYIJJGEM4I2LiKMmPTpj0ETRiYIZE03G0blojHFcGTLJGU1i5pgzlxw9MWYYJWKOjhIzGckZDTFKaeISFRQRRQMqSnO/yKXtBvrynD/2W927q6uquzdd1V3w+6xVq3a9+313PbW7u369L1Xb3B0REZEkyvq7ABERKV0KERERSUwhIiIiiSlEREQkMYWIiIgkphAREZHEFCIigZl92MwazKy8F2PONrM3C1lXsZiZm9nH+7sOKS0KETnsmNk6M2sKgZG+Hefu77n7Ee7eGvqlzOwvMsZ2eqN199+6+wnFfg0iA4VCRA5XF4bASN829ndBB8PMKvq7Bjk8KUREAjMbG7Y0Kszse8DZwA/DlsoPzeyZ0PWV0DbHzOrMrD62jHVmdrOZrTSz3Wb2sJlVx+Z/3cw2mdlGM/uLfLuQzGycmT1jZnvN7DdmdpeZ/d+MWq82s/eAp0L7z8xsc3juZ8xsUmx595nZj83sibDMp83sIxlP+8dmtsbMdoXnsz5ZuXLIUoiIZOHu3wR+C1wftlSud/dPhtmnhLaHcwyfDcwExgEnA18CMLOZwE3AHwMfB+q6KeNB4AVgJHAr8MUsfT4FTAA+HR4/DowHPgS8BDyQ0f8K4LvAKGBFlvmfBU4Ldc+OLVckK4WIHK7+K/y3vcvM/quPl32nu290953AL4FTQ/ts4Cfu/pq7NxIFQ1Zm9mGiN/N/cPcD7v47YFGWrre6+wfu3gTg7vPdfa+77w/LP8XMhsf6/7e7PxPmfxM4w8yOj82/3d13uft7wJJY7SJZKUTkcHWxu48It4v7eNmbY9ONwBFh+jhgfWxefDrTccDOEDb5+re3mVm5md1uZm+Z2R5gXZg1Klt/d28Adobn6q52kawUIiK59fVXXG8CxsQeH5+rY+h7lJkN6aZ/vMY/A2YR7S4bDowN7fHjGu3LMLMjgKOAkj6pQPqXQkQkty3AR3vQ1lMLgS+b2YQQDv8jV0d3fxdYBtxqZoPM7Azgwm6WXwPsB3YAQ4B/zNLnAjM7y8wGER0bWeru+baIRPJSiIjkdgdwqZm9b2Z3hrZbgQXhWMrs3izM3R8H7iQ61rAWWBpm7c8x5ArgDKJQuA14OE9fgPuBd4ENwOux5cc9CNxCtBtrKvCF3rwGkUymi1KJ9A8zmwCsAqrcvaUH/R8G3nD3WxI+331Avbt/K8l4kWy0JSJSRGZ2iZlVmdmRwPeBX+YKEDM7zcw+ZmZl4fTgWUBfn0kmclAUIiLF9ZfAVuAtoBX46zx9jwFSQAPRbrC/dveXC12gSG9od5aIiCSmLREREUnssPvStlGjRvnYsWMTjf3ggw8YOnRo3xZUIKVUK5RWvaVUK5RWvaVUK5RWvQdT6/Lly7e7+9FZZ7r7YXWbOnWqJ7VkyZLEY4utlGp1L616S6lW99Kqt5RqdS+teg+mVmCZ53hP1e4sERFJTCEiIiKJKURERCQxhYiIiCSmEBERkcQUIiIikphCREREEjvsPmyY2LKfULv5HXijEQaPgOrhUD0imq4cAmbdL0NE5BCjEOkJd3j860xoPQBv3NF1flllCJYRHffVw7u2ZZtXVaMAEpGSpRDpqZt/z9LUYmaccgI07YJ9u7re79sdTTduhx1rO9q8LfdyrTxL4AzPHj6ZIVQ1HMq0R1JE+o9CpCfMYPCR7Bt8DBw3uXdj29rgwN6uQZMthNL3u97rmG7Ld60ig+phWYJmOB/dtgfanoXKaqgYDJXhVlEdux8Smx/rV1GtcBKRHlGIFFpZWdiyGA58pHd
j3eHAB1mCZnfuENq7CZp2MbppF6w/iOsXlVflDqAuYZRuG9w5lCqH5OhX3fVeREqSQmQgM4OqI6Lb8DG9GvrbVIq6T30KWvZBc1PHfXy6U1sTNO/LuM/Rv3FnRltjNKY13+W/8zu7bBA8PzRPAOULpczw6ibQygfpONShzD3ahZy+J/aYbPM8z7xc46LnGfJBPWx7M8cy48sgz7xc47I9zjcu3+tr40Nb3gHq+nx1FzREzGwEcA9wItFq/HPgTeBhYCywDpjt7u+bmQF3ABcAjcCX3P2lsJyrgPR1oW9z9wWhfSpwHzAYeAz4avjGSYHojTL9RlsMbW1RqMSDpdtQivpteOf3fPiYo2P9YmMbt2eMDfNaDyQs1JIHUEU1o+vfhaWre/jGk+OPP+EbQed5nmdex7hJW7fAlnv6/E0pSS3dLf8PD+yHFyoy5tHD9Zfn2GOBTAd4sehPm8jHK0cAt/b5cgu9JXIH8Ct3v9TMBgFDgL8HnnT3281sLjAX+DvgfGB8uJ0O3A2cbmZHAbcA04h+nZab2SJ3fz/0uQZ4nihEZgKPF/g1SS5lZTBoSHTjqF4NfTuV4sN1db17vrbWnKHUoy2qXG37dndeRnpLLRyfGg+wtneldrAo3K0sTJdlPM6czjUvcxxd54XpIY2NsGN3bB45ltF5XPu89luWeV3G5agz52vtPG/bps2MPm50ntdKL9ZfWcZ66ck4erDMjppfX72aiZMm9eC1Jl1HCX5e8dcQm7ds6Qv8YdJf2zwKFiJmNhz4JPAlAHc/ABwws1l0bFMtILqG9N8Bs4D7w5bEUjMbYWbHhr5PuPvOsNwngJlmlgKGufvS0H4/cDEKkcNHWXnH7r5iaG2BliZ+98wSzjrzrF6+maXfMIvvxVSKut4GdD9Zk0oxukRqBdi6M8XEE+v6u4weOVCV+D+fvAq5JTIO2Ab8xMxOAZYDXwVq3X1T6LMZqA3To4H1sfH1oS1fe32W9i7M7FrgWoDa2lpSqVSiF9TQ0JB4bLGVUq1QWvU27C8j9cLK/i6jx0pq3ZZQrVBa9Raq1kKGSAUwBfiKuz9vZncQ7bpq5+5uZgU/huHu84B5ANOmTfOk/5WlSug/ulKqFUqr3lKqFUqr3lKqFUqr3kLVWsgPA9QD9e7+fHj8CFGobAm7qQj3W8P8DcDxsfFjQlu+9jFZ2kVEpEgKFiLuvhlYb2YnhKZzgdeBRcBVoe0q4NEwvQi40iIzgN1ht9di4DwzO9LMjgTOAxaHeXvMbEY4s+vK2LJERKQICn121leAB8KZWW8DXyYKroVmdjXwLjA79H2M6PTetUSn+H4ZwN13mtl36TiR7jvpg+zA39Bxiu/j6KC6iEhRFTRE3H0F0am5mc7N0teB63IsZz4wP0v7MqLPoIiISD/QFySJiEhiChEREUlMISIiIokpREREJDGFiIiIJKYQERGRxBQiIiKSmEJEREQSU4iIiEhiChEREUlMISIiIokpREREJDGFiIiIJKYQERGRxBQiIiKSmEJEREQSU4iIiEhiChEREUlMISIiIokpREREJDGFiIiIJKYQERGRxBQiIiKSmEJEREQSU4iIiEhiChEREUlMISIiIokpREREJDGFiIiIJKYQERGRxBQiIiKSmEJEREQSU4iIiEhiBQ0RM1tnZq+a2QozWxbajjKzJ8xsTbg/MrSbmd1pZmvNbKWZTYkt56rQf42ZXRVrnxqWvzaMtUK+HhER6awYWyJ/5O6nuvu08Hgu8KS7jweeDI8BzgfGh9u1wN0QhQ5wC3A6MB24JR08oc81sXEzC/9yREQkrT92Z80CFoTpBcDFsfb7PbIUGGFmxwKfBp5w953u/j7wBDAzzBvm7kvd3YH7Y8sSEZEiKHSIOPBrM1tuZteGtlp33xSmNwO1YXo0sD42tj605Wuvz9IuIiJFUlHg5Z/l7hvM7EPAE2b2Rnymu7uZeYFrIATYtQC1tbWkUqlEy2loaEg8tthKqVYorXp
LqVYorXpLqVYorXoLVWtBQ8TdN4T7rWb2C6JjGlvM7Fh33xR2SW0N3TcAx8eGjwltG4C6jPZUaB+TpX+2OuYB8wCmTZvmdXV12bp1K5VKkXRssZVSrVBa9ZZSrVBa9ZZSrVBa9Raq1oLtzjKzoWZWk54GzgNWAYuA9BlWVwGPhulFwJXhLK0ZwO6w22sxcJ6ZHRkOqJ8HLA7z9pjZjHBW1pWxZYmISBEUckukFvhFOOu2AnjQ3X9lZi8CC83sauBdYHbo/xhwAbAWaAS+DODuO83su8CLod933H1nmP4b4D5gMPB4uImISJEULETc/W3glCztO4Bzs7Q7cF2OZc0H5mdpXwaceNDFiohIIvrEuoiIJKYQERGRxBQiIiKSmEJEREQSU4iIiEhiChEREUlMISIiIokpREREJDGFiIiIJKYQERGRxBQiIiKSmEJEREQSU4iIiEhiChEREUlMISIiIokpREREJDGFiIiIJKYQERGRxBQiIiKSmEJEREQSU4iIiEhiChEREUlMISIiIokpREREJDGFiIiIJKYQERGRxBQiIiKSWEV/FyAiMpA1NzdTX1/Pvn37uswbPnw4q1ev7oeqeq8ntVZXVzNmzBgqKyt7vFyFiIhIHvX19dTU1DB27FjMrNO8vXv3UlNT00+V9U53tbo7O3bsoL6+nnHjxvV4udqdJSKSx759+xg5cmSXADnUmBkjR47MusWVj0JERKQbh3qApCV5nQoREZEBbNeuXfzoRz/q9bgLLriAXbt2FaCizvKGiJmdE5selzHvTwtVlIiIRHKFSEtLS95xjz32GCNGjChUWe262xL559j0zzPmfauPaxERkQxz587lrbfe4tRTT+W0007j7LPP5qKLLmLixIkAXHzxxUydOpVJkyYxb9689nFjx45l+/btrFu3jgkTJvCVr3yFSZMmcd5559HU1NRn9XV3dpblmM72OPsCzMqBZcAGd/9s2KJ5CBgJLAe+6O4HzKwKuB+YCuwA5rj7urCMbwBXA63ADe6+OLTPBO4AyoF73P32ntQkIpLEt3/5Gq9v3NP+uLW1lfLy8oNa5sTjhnHLhZNyzr/99ttZtWoVK1asIJVK8ZnPfIZVq1a1n0E1f/58jjrqKJqamjjttNP43Oc+x8iRIzstY82aNdxzzz3cd999zJ49m5///Od84QtfOKi607rbEvEc09ke5/JVIH5y8veBH7j7x4H3icKBcP9+aP9B6IeZTQQuByYBM4EfmVl5CKe7gPOBicDnQ18RkUPW9OnTO52Ce+edd3LKKacwY8YM1q9fz5o1a7qMGTduHCeffDIAU6dOZd26dX1WT3dbIh81s0VEWx3pacLjbk8kNrMxwGeA7wE3WXTo/xzgz0KXBcCtwN3ArDAN8Ajww9B/FvCQu+8H3jGztcD00G+tu78dnuuh0Pf17uoSEUkic4uhPz4nMnTo0PbpVCrFb37zG5577jmGDBlCXV1d1lN0q6qq2qfLy8uLujtrVmz6nzPmZT7O5n8DXwfSa3kksMvd00eE6oHRYXo0sB7A3VvMbHfoPxpYGltmfMz6jPbTsxVhZtcC1wLU1taSSqV6UHpXDQ0NiccWWynVCqVVbynVCqVV70Csdfjw4ezduzfrvNbW1pzz+tKePXvYu3cvjY2NtLS0tD/n5s2bqampobW1leXLl7N06VIaGxvZu3cv7k5DQwMNDQ20tbW117p//37279+fs+59+/b16meQN0Tc/en4YzOrBE4kOr6xNd9YM/sssNXdl5tZXY8rKgB3nwfMA5g2bZrX1SUrJ5VKkXRssZVSrVBa9ZZSrVBa9Q7EWlevXp1za6MYWyI1NTWcddZZnHHGGQwePJja2tr257zkkktYsGAB06dP54QTTmDGjBkMGTKEmpoazIwjjjgCgLKyMsrLy6mpqaGqqorm5uacdVdXVzN58uQe15c3RMzsx8D/cffXzGw48BzRwe2jzOxmd/+PPMPPBC4yswuAamAY0UH
wEWZWEbZGxgAbQv8NwPFAvZlVAMOJDrCn29PiY3K1i4gcMh588MGs7VVVVTz++ONZ56WPe4waNYpVq1a1b3ncfPPNfVpbdwfWz3b318L0l4Hfu/tJRGdQfT3fQHf/hruPcfexRAfGn3L3K4AlwKWh21XAo2F6UXhMmP+Uu3tov9zMqsKZXeOBF4AXgfFmNs7MBoXnSB+zERGRIujumMiB2PSfAD8DcPfNB/E1AH8HPGRmtwEvA/eG9nuBn4YD5zuJQoGwFbSQ6IB5C3Cdu7cCmNn1wGKiU3znxwJPRESKoLsQ2RWObWwg2j11NUDY3TS4p0/i7ikgFabfpuPsqniffcBlOcZ/j+gMr8z2x4DHelqHiIj0re5C5C+BO4FjgBvdfXNoPxf470IWJiIiA193Z2f9nugDfpnti4l2I4mIyGGsu7Oz7sw3391v6NtyRESklHR3dtZfAWcBG4m+/2p5xk1ERAaY9OdDNm7cyKWXXpq1T11dHcuWLTvo5+rumMixRAe75xCdGfUw8Ii7F/5L6kVE5KAcd9xxPPLIIwV9jrxbIu6+w91/7O5/RPQ5kRHA62b2xYJWJSIi7ebOnctdd93V/vjWW2/ltttu49xzz2XKlCmcdNJJPProo13GrVu3jhNPPBGApqYmLr/8ciZMmMAll1zSZ9+f1d2WCABmNgX4PNFnRR5Hu7JE5HD0+FzY/Gr7w8GtLVDeo7fR3I45Cc7PfxWLOXPmcOONN3LdddcBsHDhQhYvXswNN9zAsGHD2L59OzNmzOCiiy7KeYnbe++9lyFDhrB69WpWrlzJlClTDq7uoLsD698h+hbe1UTXAPlG7MsTRUSkCCZPnszWrVvZuHEj27Zt48gjj+SYY47ha1/7Gs888wxlZWVs2LCBLVu2cMwxx2RdxrPPPstNN90EwMknn9z+1fAHq7sI/RbwDnBKuP1jSDkD3N37pgoRkVKQscXQVMSvgr/ssst45JFH2Lx5M3PmzOGBBx5g27ZtLF++nMrKSsaOHZv1a+ALrbsQ6faaISIiUnhz5szhmmuuYfv27Tz99NMsXLiQD33oQ1RWVrJkyRLefffdvOPPPPNMHnzwQc455xxWrVrFypUr+6Su7j5smLUqMysjOkaSv2oREekTkyZNYu/evYwePZpjjz2WK664ggsvvJCTTjqJadOm8YlPfCLv+KuvvpobbriBCRMmMGHCBKZOndondXV3TGQYcB3RRaAWAU8A1wN/C7wCPNAnVYiISLdefbXjoP6oUaN47rnnsvZraGgAYOzYsaxatQqAwYMH89BDD/V5Td3tzvop0XXQnwP+Avh7ouMhF7v7ij6vRkRESkq311gP1w/BzO4BNgEfDt+4KyIih7nuvvakOT0RruFRrwAREZG07rZETjGzPWHagMHhcfoU32EFrU5EZABw95wf4juURBeT7Z3uzs4qT1yNiMghoLq6mh07djBy5MhDOkjcnR07dlBdXd2rcQf5eX0RkUPbmDFjqK+vZ9u2bV3m7du3r9dvuv2lJ7VWV1czZsyYXi1XISIikkdlZSXjxmX/3HUqlWLy5MlFriiZQtXa3YF1ERGRnBQiIiKSmEJEREQSU4iIiEhiChEREUlMISIiIokpREREJDGFiIiIJKYQERGRxBQiIiKSmEJEREQSU4iIiEhiChEREUlMISIiIokpREREJLGChYiZVZvZC2b2ipm9ZmbfDu3jzOx5M1trZg+b2aDQXhUerw3zx8aW9Y3Q/qaZfTrWPjO0rTWzuYV6LSIikl0ht0T2A+e4+ynAqcBMM5sBfB/4gbt/HHgfuDr0vxp4P7T/IPTDzCYClwOTgJnAj8ys3MzKgbuA84GJwOdDXxERKZKChYhHGsLDynBz4BzgkdC+ALg4TM8Kjwnzz7XogsazgIfcfb+7vwOsBaaH21p3f9vdDwAPhb4iIlIkBb08bthaWA58nGir4S1gl7u3hC71wOgwPRpYD+DuLWa2GxgZ2pfGFhsfsz6
j/fQcdVwLXAtQW1tLKpVK9HoaGhoSjy22UqoVSqveUqoVSqveUqoVSqveQtVa0BBx91bgVDMbAfwC+EQhny9PHfOAeQDTpk3zurq6RMtJpVIkHVtspVQrlFa9pVQrlFa9pVQrlFa9haq1KGdnufsuYAlwBjDCzNLhNQbYEKY3AMcDhPnDgR3x9owxudpFRKRICnl21tFhCwQzGwz8CbCaKEwuDd2uAh4N04vCY8L8p9zdQ/vl4eytccB44AXgRWB8ONtrENHB90WFej0iItJVIXdnHQssCMdFyoCF7v7/zOx14CEzuw14Gbg39L8X+KmZrQV2EoUC7v6amS0EXgdagOvCbjLM7HpgMVAOzHf31wr4ekREJEPBQsTdVwKTs7S/TXRmVWb7PuCyHMv6HvC9LO2PAY8ddLEiIpKIPrEuIiKJKURERCQxhYiIiCSmEBERkcQUIiIikphCREREElOIiIhIYgoRERFJTCEiIiKJKURERCQxhYiIiCSmEBERkcQUIiIikphCREREElOIiIhIYgoRERFJTCEiIiKJKURERCQxhYiIiCSmEBERkcQUIiIikphCREREElOIiIhIYgoRERFJTCEiIiKJKURERCQxhYiIiCSmEBERkcQUIiIikphCREREElOIiIhIYgoRERFJTCEiIiKJKURERCSxgoWImR1vZkvM7HUze83MvhrajzKzJ8xsTbg/MrSbmd1pZmvNbKWZTYkt66rQf42ZXRVrn2pmr4Yxd5qZFer1iIhIV4XcEmkB/tbdJwIzgOvMbCIwF3jS3ccDT4bHAOcD48PtWuBuiEIHuAU4HZgO3JIOntDnmti4mQV8PSIikqFgIeLum9z9pTC9F1gNjAZmAQtCtwXAxWF6FnC/R5YCI8zsWODTwBPuvtPd3weeAGaGecPcfam7O3B/bFkiIlIEFcV4EjMbC0wGngdq3X1TmLUZqA3To4H1sWH1oS1fe32W9mzPfy3R1g21tbWkUqlEr6OhoSHx2GIrpVqhtOotpVqhtOotpVqhtOotVK0FDxEzOwL4OXCju++JH7ZwdzczL3QN7j4PmAcwbdo0r6urS7ScVCpF0rHFVkq1QmnVW0q1QmnVW0q1QmnVW6haC3p2lplVEgXIA+7+n6F5S9gVRbjfGto3AMfHho8Jbfnax2RpFxGRIink2VkG3Ausdvd/jc1aBKTPsLoKeDTWfmU4S2sGsDvs9loMnGdmR4YD6ucBi8O8PWY2IzzXlbFliYhIERRyd9aZwBeBV81sRWj7e+B2YKGZXQ28C8wO8x4DLgDWAo3AlwHcfaeZfRd4MfT7jrvvDNN/A9wHDAYeDzcRESmSgoWIu/8OyPW5jXOz9HfguhzLmg/Mz9K+DDjxIMoUEZGDoE+si4hIYgoRERFJTCEiIiKJKURERCQxhYiIiCSmEBERkcQUIj308nvvU7+3jc2799F0oJXojGQRkcNbUb6A8VDwZ//+PE3NrXzr2ScBqCw3hg+uZFh1JcMGR7focUV03/44uo/aonk11ZWUl+nSJyJS+hQiPeDu/NsXp/Lc8lcY89Hx7G5qZk9TS3S/r5k9Tc3sbjzAezs+YM++qL21Lf+WSk1VRXv4ZA+eivbHmfOqK8vQ9bdEZCBQiPSAmfHJPziato0V1J3+kW77uzuNB1rbQ2Z3Y3N7uOxpamZ3uKUDaE9TC+/uaIz6NjXTeKA17/IHlZcxLIRMeksnCpmK9scb32tm67L1VJYbFWVlVJaXRdPhPnpcRkWZtc+rLC+jIj2vrIzKivRYU2iJSFYKkQIwM4ZWVTC0qoLjGNzr8c2tbVG4ZARPOmQyt4J2NR7g3WxbQa+v7LPXVFFmUcCUlVFZ0Tl8KspjIVXWEVCZoZUOpE5hFUKu/r0DvGlvUVFexqAwrqLMGFSRf1zWAExPh3HlZQpBkUJRiAxAleVljDyiipFHVPV6rLvzwYFWnljyDNOmz6ClzWlpbeNAaxstrU5LWxsHWqL75tY2mlu
dllYP0220tHmsvaNP53ltYUy6vWOZLa3OgdY2Gg+00NLmHGhpa6+huTVjfFvU1h56a97o4zXZYVAIlngwVZRbrD0Kx8qyLGFVHm8vY8vm/SzZvao9mMygzAwL02YW3ZO+D/PDNKFvvC3buLLYNF36W8eY9HgsGhOmCXW9uaGZnS/Vt9cJHeMya6B9GR3LLQsd0s8br6tTLbHpsrCsrq8/XXfn6fT629jQxtqtDV3XaXxcltrp1L9jPu3P1bX2bOs0/dql5xQihxgz44iqCo6sLuP4o4b0dzk90tbmPJVKccaZZ7eHUDyQehJyLW1tNLdEwdTRP4RcaxvNbU5zS5YgbG9v40BYZkur09TcmqMGZ9/+Fl7avpG2NscBHBxoc8cdnHAfn47N7xevvtJPT5zA757u7wqALP8ctAdzRxC1tbZSuWRx15DN/Kcitox4UGWGaK5xXcK/rHOQ0x6QWUI0LKPlg30U4vpZChHpd2Vl0RbC0KrS+HXsiyvEeQiUNo+CKFf4xOfT6bG3B1M6xNLj2rxj+QDPLV3K9Omnt49rC8mXGW49C8B0XR01eOZ0Ru2k62qjU+3pGtPj2tx57fXXmTBhAnSpK94/S+3RSu1SV1tYCZ3r6pgG2v8ZiK8/MpbR9WcU3a9fv57RY46PPY93qSveP7OGjp9dqD3b+un0c43/DOJ1ZVk/8T5t0Fqgf2BK469W5BDTvtsn59US+s5bQ8oYO2powZ+nL9S8/3vqTh3d32X0WCq1lbq6if1dRo8U6lrw+rChiIgkphAREZHEFCIiIpKYQkRERBJTiIiISGIKERERSUwhIiIiiSlEREQkMTvcLq5kZtuAdxMOHwVs78NyCqmUaoXSqreUaoXSqreUaoXSqvdgav2Iux+dbcZhFyIHw8yWufu0/q6jJ0qpViitekupViitekupViitegtVq3ZniYhIYgoRERFJTCHSO/P6u4BeKKVaobTqLaVaobTqLaVaobTqLUitOiYiIiKJaUtEREQSU4iIiEhiCpEszOx4M1tiZq+b2Wtm9tXQfquZbTCzFeF2QX/XmmZm68zs1VDXstB2lJk9YWZrwv2RA6DOE2Lrb4WZ7TGzGwfSujWz+Wa21U+2BjEAAAaoSURBVMxWxdqyrkuL3Glma81spZlNGQC1/pOZvRHq+YWZjQjtY82sKbaOf1zMWvPUm/Nnb2bfCOv2TTP79ACo9eFYnevMbEVo79d1m+c9q/C/t9ElFXWL34BjgSlhugb4PTARuBW4ub/ry1HzOmBURtv/AuaG6bnA9/u7zoz6yoHNwEcG0roFPglMAVZ1ty6BC4DHiS5lPQN4fgDUeh5QEaa/H6t1bLzfAFq3WX/24W/uFaAKGAe8BZT3Z60Z8/8F+IeBsG7zvGcV/PdWWyJZuPsmd38pTO8FVgOlc83ODrOABWF6AXBxP9aSzbnAW+6e9BsECsLdnwF2ZjTnWpezgPs9shQYYWbHFqfS7LW6+6/dvSU8XAqMKVY93cmxbnOZBTzk7vvd/R1gLTC9YMVlyFermRkwG/iPYtWTT573rIL/3ipEumFmY4HJwPOh6fqw+Td/IOweinHg12a23MyuDW217r4pTG8GavuntJwup/Mf4UBdt5B7XY4G1sf61TOw/uH4c6L/ONPGmdnLZva0mZ3dX0Vlke1nP5DX7dnAFndfE2sbEOs24z2r4L+3CpE8zOwI4OfAje6+B7gb+BhwKrCJaHN2oDjL3acA5wPXmdkn4zM92oYdMOdzm9kg4CLgZ6FpIK/bTgbauszFzL4JtAAPhKZNwIfdfTJwE/CgmQ3rr/piSuZnH/N5Ov8DNCDWbZb3rHaF+r1ViORgZpVEP4wH3P0/Adx9i7u3unsb8O8UcdO6O+6+IdxvBX5BVNuW9CZquN/afxV2cT7wkrtvgYG9boNc63IDcHys35jQ1q/M7EvAZ4ErwpsHYbfQjjC9nOgYwx/0W5FBnp/9QF23FcCfAg+n2wbCus32nkU
Rfm8VIlmE/Z33Aqvd/V9j7fF9hpcAqzLH9gczG2pmNelpogOrq4BFwFWh21XAo/1TYVad/pMbqOs2Jte6XARcGc52mQHsju0+6BdmNhP4OnCRuzfG2o82s/Iw/VFgPPB2/1TZIc/PfhFwuZlVmdk4onpfKHZ9Wfwx8Ia716cb+nvd5nrPohi/t/11NsFAvgFnEW32rQRWhNsFwE+BV0P7IuDY/q411PtRorNYXgFeA74Z2kcCTwJrgN8AR/V3raGuocAOYHisbcCsW6Jw2wQ0E+0rvjrXuiQ6u+Uuov88XwWmDYBa1xLt707/7v449P1c+P1YAbwEXDhA1m3Onz3wzbBu3wTO7+9aQ/t9wF9l9O3XdZvnPavgv7f62hMREUlMu7NERCQxhYiIiCSmEBERkcQUIiIikphCREREElOIiIhIYgoRkSIws1MzvuL8IjOb20fLvtHMhvTFskR6S58TESmC8DUk09z9+gIse11Y9vZejCl399a+rkUOP9oSEYkJFxdabWb/Hi7u82szG5yj78fM7Ffhm5N/a2afCO2XmdkqM3vFzJ4JXzb5HWBOuGDRHDP7kpn9MPS/z8zuNrOlZva2mdWFb7NdbWb3xZ7vbjNbFur6dmi7ATgOWGJmS0Lb5y26QNkqM/t+bHyDmf2Lmb0CnGFmt1t0EaOVZvbPhVmjcsgr9tce6KbbQL4RXVyoBTg1PF4IfCFH3yeB8WH6dOCpMP0qMDpMjwj3XwJ+GBvb/pjoazQeIvoqilnAHuAkon/ylsdqSX9lRTmQAk4Oj9cRLkhGFCjvAUcDFcBTwMVhngOzw/RIoq8SsXiduunW25u2RES6esfdV4Tp5UTB0kn4yu0/BH5m0SVS/43o6nIAzwL3mdk1RG/4PfFLd3eiANri7q969K22r8Wef7aZvQS8DEwiunJdptOAlLtv8+jCVA8QXaEPoJXoW14BdgP7gHvN7E+Bxi5LEumBiv4uQGQA2h+bbgWy7c4qA3a5+6mZM9z9r8zsdOAzwHIzm9qL52zLeP42oCJ8i+3NwGnu/n7YzVXdg+XG7fNwHMTdW8xsOtHVJS8FrgfO6eXyRLQlIpKERxf8ecfMLoPoq7jN7JQw/TF3f97d/wHYRnTdhr1E175OahjwAbDbzGqJrseSFl/2C8CnzGxU+GryzwNPZy4sbEkNd/fHgK8BpxxEbXIY05aISHJXAHeb2beASqLjGq8A/2Rm44mOcTwZ2t4D5oZdX/+zt0/k7q+Y2cvAG0Rf8/5sbPY84FdmttHd/yicOrwkPP9/u3u268jUAI+aWXXod1NvaxIBneIrIiIHQbuzREQkMe3OEumGmd0FnJnRfIe7/6Q/6hEZSLQ7S0REEtPuLBERSUwhIiIiiSlEREQkMYWIiIgk9v8B5NtLdBat9iUAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_fitting_graph(n_estimators, rmse_train, rmse_valid, metric_name='RMSE', xlabel='n_estimators')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can zoom in a little bit further to get a better picture of how the validation error behaves:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see that adding more estimaters further decreases the validation error. We do not overfit by adding more estimators. However, it takes longer to train and evaluate the model as we add more estimators. This is usually not a major concern when experimenting with models but it can be a major constraint when putting it in production. See [this](https://www.wired.com/2012/04/netflix-prize-costs/) example from Netflix.\n", "\n", "We choose `n_estimators=100` for the sake of speed but we keep in mind that we could probably further improve the model by adding more estimators. We continue tuning the other parameters:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### tree_depth" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 7/7 [03:29<00:00, 29.99s/it]\n" ] } ], "source": [ "rmse_train = []\n", "rmse_valid = []\n", "\n", "max_depths = [1, 2, 4, 8, 16, 32, 64]\n", "for d in tqdm(max_depths):\n", " rf = RandomForestRegressor(n_estimators=100, max_depth=d, n_jobs=-1)\n", " results = cross_validate(rf, X, y,\n", " cv=5,\n", " return_train_score=True,\n", " scoring='neg_root_mean_squared_error')\n", " \n", " # we average the scores and append them to the list\n", " rmse_train.append(-np.mean(results['train_score']))\n", " rmse_valid.append(-np.mean(results['test_score']))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_fitting_graph(max_depths, rmse_train, rmse_valid, metric_name='RMSE', xlabel='max_depths')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "16" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "max_depths[np.argmin(rmse_valid)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### min_samples_leaf" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 5/5 [02:39<00:00, 31.93s/it]\n" ] } ], "source": [ "rmse_train = []\n", "rmse_valid = []\n", "\n", "min_samples_leaf = [1, 3, 5, 10, 25]\n", "for s in tqdm(min_samples_leaf):\n", " rf = RandomForestRegressor(n_estimators=100, max_depth=16, min_samples_leaf=s, n_jobs=-1)\n", " results = cross_validate(rf, X, y,\n", " cv=5,\n", " return_train_score=True,\n", " scoring='neg_root_mean_squared_error')\n", " \n", " # we average the scores and append them to the list\n", " rmse_train.append(-np.mean(results['train_score']))\n", " rmse_valid.append(-np.mean(results['test_score']))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_fitting_graph(min_samples_leaf, rmse_train, rmse_valid, metric_name='RMSE', xlabel='min_samples_leaf')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "5" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "min_samples_leaf[np.argmin(rmse_valid)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### max_features" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 5/5 [01:22<00:00, 16.48s/it]\n" ] } ], "source": [ "rmse_train = []\n", "rmse_valid = []\n", "\n", "max_features = [.1, .25, .5, .75, 1.]\n", "for mf in tqdm(max_features):\n", " rf = RandomForestRegressor(n_estimators=100, max_depth=32, min_samples_leaf=5, max_features=mf, n_jobs=-1)\n", " results = cross_validate(rf, X, y,\n", " cv=5,\n", " return_train_score=True,\n", " scoring='neg_root_mean_squared_error')\n", " \n", " # we average the scores and append them to the list\n", " rmse_train.append(-np.mean(results['train_score']))\n", " rmse_valid.append(-np.mean(results['test_score']))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_fitting_graph(max_features, rmse_train, rmse_valid, metric_name='RMSE', xlabel='max_features')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.5" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "max_features[np.argmin(rmse_valid)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rf = RandomForestRegressor(n_estimators=100, max_depth=32, min_samples_leaf=5, max_features=0.5, n_jobs=-1)\n", "results = cross_validate(rf, X, y,\n", " cv=5,\n", " return_train_score=True,\n", " scoring='neg_root_mean_squared_error')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train score: 28674.0\n", "Validation score: 58935.2\n" ] } ], "source": [ "print('Train score: %.1f' % -np.mean(results['train_score']))\n", "print('Validation score: %.1f' % -np.mean(results['test_score']))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see that we improved the RMSE on the validation set by more than $2000 which corresponds to 3-4% improvement!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 3: Churn\n", "Now we want to follow the same procedure to optimise the parameters of the Random Forest classifier on the churn dataset. Since this is a classification task there are two things we need to change in our code:\n", "\n", "1. We replace the `RandomForestRegressor` with the `RandomForestClassifier``\n", "2. During cross-validation we want to measure the accuracy instead of the RMSE." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset already exists at '../data/churn.csv' and is not downloaded again.\n" ] } ], "source": [ "get_dataset('churn.csv')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "churn_data = pd.read_csv(DATA/'churn.csv')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X, y, nas = proc_df(churn_data, \"Churn\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train score: 1.000\n", "Validation score: 0.794\n" ] } ], "source": [ "rf = RandomForestClassifier(n_jobs=-1)\n", "results = cross_validate(rf, X, y,\n", " cv=5,\n", " return_train_score=True,\n", " scoring='accuracy')\n", "print('Train score: %.3f' % np.mean(results['train_score']))\n", "print('Validation score: %.3f' % np.mean(results['test_score']))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see that we have about 79% accuracy on the validation set. Let's see if we can improve on that." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### n_estimators" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:15<00:00, 3.83s/it]\n" ] } ], "source": [ "acc_train = []\n", "acc_valid = []\n", "\n", "n_estimators = [25, 50, 100, 200]\n", "for n in tqdm(n_estimators):\n", " rf = RandomForestClassifier(n_estimators=n, n_jobs=-1)\n", " results = cross_validate(rf, X, y,\n", " cv=5,\n", " return_train_score=True,\n", " scoring='accuracy')\n", " \n", " # we average the scores and append them to the list\n", " acc_train.append(np.mean(results['train_score']))\n", " acc_valid.append(np.mean(results['test_score']))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_fitting_graph(n_estimators, acc_train, acc_valid, metric_name='Accuracy', xlabel='max_features')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note: since we want to maximise the accuracy we need to change `np.argmin` to `np.argmax`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "200" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n_estimators[np.argmax(acc_valid)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Following the same argument as with the housing dataset we will stick to `n_estimators=100` for now but you can experiment with more estimators if you want." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Exercise 2\n", "Run the same optimisation steps as we did with housing dataset to tune the Random Forest for churn classification.\n", "\n", "Optimise the following parameters:\n", "* max_depth\n", "* min_samples_leaf\n", "* max_features\n", "\n", "Note: You will probably see a ~1% improvement. While this not might seem like very much keep in mind that Netflix paid \\\\$50'000 for a 1% improvement in predictions and $1'000'000 for a 10% improvement (see [link](https://en.wikipedia.org/wiki/Netflix_Prize)). In Kaggle challenges this can make the difference between being top-10 and top-100." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }