{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Cross-Validation and the Test Set\n", "\n", "In the last lecture, we saw how keeping some data hidden from our model could help us to get a clearer understanding of whether or not the model was overfitting. This time, we'll introduce a common automated framework for handling this task, called **cross-validation**. We'll also incorporate a designated **test set**, which we won't touch until the very end of our analysis to get an overall view of the performance of our model. " ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from matplotlib import pyplot as plt\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th>Survived</th><th>Pclass</th><th>Name</th><th>Sex</th><th>Age</th><th>Siblings/Spouses Aboard</th><th>Parents/Children Aboard</th><th>Fare</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>0</td><td>3</td><td>Mr. Owen Harris Braund</td><td>male</td><td>22.0</td><td>1</td><td>0</td><td>7.2500</td></tr>\n",
"    <tr><th>1</th><td>1</td><td>1</td><td>Mrs. John Bradley (Florence Briggs Thayer) Cum...</td><td>female</td><td>38.0</td><td>1</td><td>0</td><td>71.2833</td></tr>\n",
"    <tr><th>2</th><td>1</td><td>3</td><td>Miss. Laina Heikkinen</td><td>female</td><td>26.0</td><td>0</td><td>0</td><td>7.9250</td></tr>\n",
"    <tr><th>3</th><td>1</td><td>1</td><td>Mrs. Jacques Heath (Lily May Peel) Futrelle</td><td>female</td><td>35.0</td><td>1</td><td>0</td><td>53.1000</td></tr>\n",
"    <tr><th>4</th><td>0</td><td>3</td><td>Mr. William Henry Allen</td><td>male</td><td>35.0</td><td>0</td><td>0</td><td>8.0500</td></tr>\n",
"    <tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
"    <tr><th>882</th><td>0</td><td>2</td><td>Rev. Juozas Montvila</td><td>male</td><td>27.0</td><td>0</td><td>0</td><td>13.0000</td></tr>\n",
"    <tr><th>883</th><td>1</td><td>1</td><td>Miss. Margaret Edith Graham</td><td>female</td><td>19.0</td><td>0</td><td>0</td><td>30.0000</td></tr>\n",
"    <tr><th>884</th><td>0</td><td>3</td><td>Miss. Catherine Helen Johnston</td><td>female</td><td>7.0</td><td>1</td><td>2</td><td>23.4500</td></tr>\n",
"    <tr><th>885</th><td>1</td><td>1</td><td>Mr. Karl Howell Behr</td><td>male</td><td>26.0</td><td>0</td><td>0</td><td>30.0000</td></tr>\n",
"    <tr><th>886</th><td>0</td><td>3</td><td>Mr. Patrick Dooley</td><td>male</td><td>32.0</td><td>0</td><td>0</td><td>7.7500</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"<p>887 rows × 8 columns</p>\n",
"</div>
" ], "text/plain": [ " Survived Pclass Name \\\n", "0 0 3 Mr. Owen Harris Braund \n", "1 1 1 Mrs. John Bradley (Florence Briggs Thayer) Cum... \n", "2 1 3 Miss. Laina Heikkinen \n", "3 1 1 Mrs. Jacques Heath (Lily May Peel) Futrelle \n", "4 0 3 Mr. William Henry Allen \n", ".. ... ... ... \n", "882 0 2 Rev. Juozas Montvila \n", "883 1 1 Miss. Margaret Edith Graham \n", "884 0 3 Miss. Catherine Helen Johnston \n", "885 1 1 Mr. Karl Howell Behr \n", "886 0 3 Mr. Patrick Dooley \n", "\n", " Sex Age Siblings/Spouses Aboard Parents/Children Aboard Fare \n", "0 male 22.0 1 0 7.2500 \n", "1 female 38.0 1 0 71.2833 \n", "2 female 26.0 0 0 7.9250 \n", "3 female 35.0 1 0 53.1000 \n", "4 male 35.0 0 0 8.0500 \n", ".. ... ... ... ... ... \n", "882 male 27.0 0 0 13.0000 \n", "883 female 19.0 0 0 30.0000 \n", "884 female 7.0 1 2 23.4500 \n", "885 male 26.0 0 0 30.0000 \n", "886 male 32.0 0 0 7.7500 \n", "\n", "[887 rows x 8 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# assumes that you have run the function retrieve_data() \n", "# from \"Introduction to ML in Practice\" in ML_3.ipynb\n", "titanic = pd.read_csv(\"data.csv\")\n", "titanic" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are again going to use the `train_test_split` function to divide our data in two. This time, however, we are not going to be using the holdout data to determine the model complexity. Instead, we are going to hide the holdout data until the very end of our analysis. We'll use a different technique for handling the model complexity. 
" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "np.random.seed(1234) # for reproducibility\n", "train, test = train_test_split(titanic, test_size = 0.2) # hold out 20% of data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We again need to clean our data: " ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from sklearn import preprocessing\n", "def prep_titanic_data(data_df):\n", " df = data_df.copy()\n", " le = preprocessing.LabelEncoder()\n", " df['Sex'] = le.fit_transform(df['Sex'])\n", " df = df.drop(['Name'], axis = 1)\n", " \n", " X = df.drop(['Survived'], axis = 1).values\n", " y = df['Survived'].values\n", " \n", " return(X, y)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "X_train, y_train = prep_titanic_data(train)\n", "X_test, y_test = prep_titanic_data(test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## K-fold Cross-Validation\n", "\n", "The idea of k-fold cross validation is to take a small piece of our training data, say 10%, and use that as a mini test set. We train the model on the remaining 90%, and then evaluate on the 10%. We then take a *different* 10%, train on the remaining 90%, and so on. We do this many times, and finally average the results to get an overall average picture of how the model might be expected to perform on the real test set. Cross-validation is a highly efficient tool for estimating the optimal complexity of a model. \n", "\n", "
\n", "Figure: K-fold cross-validation. Source: scikit-learn docs.\n", "
\n", "\n", "The good folks at `scikit-learn` have implemented a function called `cross_val_score` which automates this entire process. It repeatedly selects holdout data; trains the model; and scores the model against the holdout data. While exceptions apply, you can often use `cross_val_score` as a plug-and-play replacement for `model.fit()` and `model.score()` during your model selection phase. " ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.8028169 , 0.73239437, 0.76056338, 0.81690141, 0.83098592,\n", " 0.8028169 , 0.81690141, 0.78873239, 0.85915493, 0.84285714])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.model_selection import cross_val_score\n", "from sklearn import tree\n", "\n", "# make a model\n", "T = tree.DecisionTreeClassifier(max_depth = 3)\n", "\n", "# 10-fold cross validation: hold out 10%, train on the 90%, repeat 10 times. \n", "cv_scores = cross_val_score(T, X_train, y_train, cv=10)\n", "cv_scores" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8054124748490945" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cv_scores.mean()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYoAAAEWCAYAAAB42tAoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAfxklEQVR4nO3de5RdZZnn8e+vAlEKMIApEELqHBjxgigRS7w2zTQqgSVD0zMjwfLSjG0Zl3HAC0Okultau0ZsvPUasJlCaRwpYWjBblBawrQIyLSYBBMhYOwYqipFGCiMiFB0Y5Jn/ti7wqFyatc5VWfXuf0+a53F2e++PfvslXrY77vf91VEYGZmNp2OegdgZmaNzYnCzMwyOVGYmVkmJwozM8vkRGFmZpmcKMzMLJMThVmTkHSypLF6x2Htx4nCmpqkYUnPSHpK0q8lfU/S0hod920Z60+WtDs971OSxiRdL+n1cz13yTlC0ktrdbwyx/+hpH8tuYbNeZ3LmpsThbWCMyLiAOBw4FHgf8zTeben5z0QeCPwc+AuSafM0/lrYVVEHJB+Xl7vYKwxOVFYy4iIfwW+DRw7WSbpBZK+IGlU0qOSrpC0X7pusaTvSnpC0g5Jd0nqkPRNoBu4Of0/7f82w3kjIsYi4s+BrwGfLzn/KyTdlh5/s6R3lay7Oo3nNkm/lXSHpEK67s50s41pDGeX7PcJSY9JekTSuXP+4cxm4ERhLUNSJ3A28OOS4s8DLwOWAS8FlgB/nq77BDAGdAGHAReR/N1/LzBK+qQSEX9VRRg3AidI2l/S/sBtwLeAQ4FzgK9KelXJ9r3AZ4HFwAZgiCSIk9L1x6cx/O90+SXAovQ6PgBcLungaX6P1ZK+O0O8n5P0uKS7JZ1cxXVaG9mn3gGY1cDfS9oJHAA8BpwKIEnAB4HXRMSOtOy/k/zh/hTwO5LqqkJEbAHuqkEs2wEBBwFvBYYj4m/TdfdKugH4T8CmtOx7EXFnGls/8BtJSyNi2zTH/x3wmYjYCdwi6Sng5Tw/OQIQEZfMEOuFwAPAs8AKkieoZRHxywqv1dqEnyisFfxhRBwEvABYBdwh6SUkTwqdwPq0eukJ4PtpOcClwBZgjaStklbXIJYlQABPAAXgDZPnTs/fS/JUMGlPQoiIp4AdwBEZx/9VmiQmTZAkyKpFxD0R8duI+LeI+AZwN3D6bI5lrc2JwlpGROyKiBuBXST/N/848Azwqog4KP0sShugSf9IfiIijgbOAD5e0hA922GVzwLujYinSZLAHSXnPiitRvpwyfZ73tCSdABwCMlTST0EydOQ2fM4UVjLUOJM4GDgwYjYDVwJfFnSoek2SyRNVk29U9JL0yqqJ0kSzK70cI8CR1dx3iWSPg38CUlbB8B3gZdJeq+kfdPP6yW9smT30yW9VdJCkraKe0qqnSqOoVqSDpJ0qqQXStpHUi9wEnBrHuez5uZEYa3g5rSu/klgAHh/REy2AVxIUr30Y0lPAv+HpE4f4Jh0+Sngn4GvRsQP03WfA/40rTL65DTnPSI971PAWuDVwMkRsQaSJxbgHST1/9uB/0fSuP6CkmN8C/g0SZXT60iqpiZdDHwjjeFdVEnSRZL+cZrV+wJ/CYyTPHl9lKQKz30pbC/yxEVm9SHpamAsIv603rGYZfEThZmZZXKiMDOzTLkmCknL096oW8q9eihpkaSbJW2UtGmyl6mkpZJul/RgWn5ennGa1UNE/LGrnawZ5NZGIWkB8Avg7SS9X9cC50TEAyXbXAQsiogLJXUBm0neMX8xcHhE3CvpQGA9SUPbA1PPY2Zm+cqzZ/aJwJaI2Aog6TrgTJKeoJMCODB9PfEAkjc/dkbEI8AjkLw5IulBko5MmYli8eLFUSwWa30dZmYta/369Y9HRFfWNnkmiiWU9Doleap4w5RtLgNuInl18EDg7PTd9z0kFYHXAveUO4mkPqAPoLu7m3Xr1tUgdDOz9iB
pZKZt8myjKNfDc2o916kkA6EdQTJo22WSXrTnAElP1RuA8yPiyXIniYjBiOiJiJ6ursykaGZms5BnohijZHgC4Ej2HprgXODGdJjmLcBDwCsAJO1LkiSG0mEZzMysDvJMFGuBYyQdlQ5PsIKkmqnUKHAKgKTDSHrMbk3bLL5OMgzDl3KM0czMZpBbokhHuFxFMnbMg8D1EbFJ0kpJK9PNPgu8WdJ9wD8BF0bE48BbgPcCfyBpQ/rxqJZmZnWQ63wUEXELcMuUsitKvm8nGQtn6n4/wqNYmpk1BPfMbiJDQ0MUi0U6OjooFosMDQ3VOyQzawOe4a5JDA0N0dfXx8TEBAAjIyP09fUB0Nvbm7Wrmdmc+ImiSfT39+9JEpMmJibo7++vU0Rm1i6cKJrE6OhoVeVmZrXiRNEkuru7qyo3M6sVJ4omMTAwQGdn5/PKOjs7GRgYqFNEZtYunCiaRG9vL4ODgxQKBSRRKBQYHBx0Q7aZ5a6lpkLt6ekJDwpoZlY5SesjoidrGz9RmJlZJicKMzPL5ERhZmaZnCjMzCyTE4WZmWVyojAzs0xOFGZmlsmJwszMMjlRmJlZJicKMzPL5ERhZmaZck0UkpZL2ixpi6TVZdYvknSzpI2SNkk6t2TdVZIek3R/njGamVm23BKFpAXA5cBpwLHAOZKOnbLZR4AHIuJ44GTgi5IWpuuuBpbnFZ+ZmVUmzyeKE4EtEbE1Ip4FrgPOnLJNAAdKEnAAsAPYCRARd6bLZmZWR3kmiiXAtpLlsbSs1GXAK4HtwH3AeRGxu5qTSOqTtE7SuvHx8bnEa2ZmZeSZKFSmbOrkF6cCG4AjgGXAZZJeVM1JImIwInoioqerq2t2kZqZ2bTyTBRjwNKS5SNJnhxKnQvcGIktwEPAK3KMqeEMDQ1RLBbp6OigWCwyNDTU0Mc1s/azT47HXgscI+ko4GFgBfDuKduMAqcAd0k6DHg5sDXHmBrK0NAQfX19TExMADAyMkJfXx/AnKY4zeu4Ztaecp0KVdLpwFeABcBVETEgaSVARFwh6QiSt5sOJ6mquiQirkn3vZbkTajFwKPApyPi61nna7apUIvFIiMjI3uVFwoFhoeHG+64ZtZ6KpkK1XNm11FHRwflfn9J7N5dVZv+vBzXzFqP58xucN3d3VWV1/u4ZtaenCjqaGBggM7OzueVdXZ2MjAw0JDHNbP25ERRR729vQwODlIoFJBEoVBgcHBwzg3OeR3XzNqT2yjMzNqY2yjqyP0YzKxV5NmPom25H4OZtRI/UeSgv79/T5KYNDExQX9/f50iMjObPSeKHIyOjlZVbmbWyJwocuB+DGbWSpwocuB+DGbWSpwocuB+DGbWStyPwsysjbkfhZmZzZkThZmZZXKiMDOzTE4UZmaWyYnCzMwyOVGYmVkmJwozM8uUa6KQtFzSZklbJK0us36RpJslbZS0SdK5le5rZmbzI7dEIWkBcDlwGnAscI6kY6ds9hHggYg4HjgZ+KKkhRXua2Zm8yDPJ4oTgS0RsTUingWuA86csk0AB0oScACwA9hZ4b5mZjYP8kwUS4BtJctjaVmpy4BXAtuB+4DzImJ3hfsCIKlP0jpJ68bHx2sVu5mZpfJMFCpTNnVgqVOBDcARwDLgMkkvqnDfpDBiMCJ6IqKnq6trLvGamVkZeSaKMWBpyfKRJE8Opc4FbozEFuAh4BUV7mtmZvMgz0SxFjhG0lGSFgIrgJumbDMKnAIg6TDg5cDWCvc1M7N5sE9eB46InZJWAbcCC4CrImKTpJXp+iuAzwJXS7qPpLrpwoh4HKDcvnnFamZm0/N8FGZmbczzUdTY0NAQxWKRjo4OisUiQ0ND9Q7JzCx3uVU9tZqhoSH6+vqYmJgAYGRkhL6+PgBPcWpmLc1PFBXq7+/fkyQmTUxM0N/fX6eIzMzmhxNFhUZHR6sqNzNrFU4UFeru7q6q3MysVThRVGhgYIDOzs7nlXV2djIwMFCniMzM5ocTRYV
6e3sZHBykUCggiUKhwODgoBuyzazluR+FmVkbcz8KMzObMycKMzPL5ERhZmaZnCjMzCyTE4V5DCszy+Sxntqcx7Ays5n4iaLNeQwrM5uJE0Wbq3YMK1dTmbUfJ4o2V80YVpPVVCMjI0TEnmoqJwuz1uZE0eaqGcPK1VRm7cmJos1VM4aVh1o3a0+5JgpJyyVtlrRF0uoy6y+QtCH93C9pl6RD0nXnpWWbJJ2fZ5ztrre3l+HhYXbv3s3w8PC0bzt5qHWz9pRbopC0ALgcOA04FjhH0rGl20TEpRGxLCKWAZ8C7oiIHZKOAz4InAgcD7xT0jF5xWqV8VDrZu0pzyeKE4EtEbE1Ip4FrgPOzNj+HODa9PsrgR9HxERE7ATuAM7KMVargIdaN2tPeXa4WwJsK1keA95QbkNJncByYFVadD8wIOnFwDPA6UDZ8cMl9QF94CqQ+dDb2+vEYNZm8nyiUJmy6Sa/OAO4OyJ2AETEg8DngduA7wMbgZ3ldoyIwYjoiYierq6uqoN0vwAzs2x5JooxYGnJ8pHA9mm2XcFz1U4ARMTXI+KEiDgJ2AH8S60DdL8AM7OZVTzDnaT9I+Lpig8s7QP8AjgFeBhYC7w7IjZN2W4R8BCwtPT4kg6NiMckdQNrgDdFxK+zzlntDHfFYpGRkZG9yguFAsPDwxUfx8ysWdVkhjtJb5b0APBguny8pK/OtF/aCL0KuDXd9/qI2CRppaSVJZueBawpk4RuSM97M/CRmZLEbLhfgJnZzCppzP4ycCpwE0BEbJR0UiUHj4hbgFumlF0xZflq4Ooy+/5eJeeYi+7u7rJPFG4UNzN7TkVtFBGxbUrRrhximXfuF2BmNrNKEsU2SW8GQtJCSZ8krYZqdu4XYGY2sxkbsyUtBv4aeBvJK69rgPMi4lf5h1edahuzzcza3Zwbs9NhOL4SEb0RcVhEHBoR72nEJGGNx31UzFpDZmN2ROyS1CVpYToMh1lFPMWqWeuopOrpfwInkLz1tOcV1oj4Ur6hVc9VT43DfVTMmkMlVU+VvB67Pf10AAfWIjBrfe6jYtY6ZkwUEfEXAJIOTBbjqdyjsqbnPipmraOSntnHSfopyYiumyStl/Sq/EOzZuY+Kmato5J+FIPAxyOiEBEF4BPAlfmGZc3OfVTMWkcljdkbI+L4mcoagRuzzcyqU6vG7K2S/gz4Zrr8HpLRXs3MrA1UUvX0X4Au4Mb0sxg4N8+gzMyscVTy1tOvgf86D7GYmVkDquStp9skHVSyfLCkW/MNy8zMGkUlVU+LI+KJyYX0CePQ/EIyM7NGUkmi2J1ORwqApAJQ2fypZmbW9Cp566kf+JGkO9Llk4C+/EIyM7NGUklj9vclnQC8MS36WEQ8nm9YZmbWKKatepJUkLQIIE0MTwNvB94naWElB5e0XNJmSVskrS6z/gJJG9LP/ZJ2STokXfcxSZvS8mslvXBWV2hmZnOS1UZxPbA/gKRlwN8Bo8DxwFdnOnA66dHlwGnAscA5ko4t3SYiLo2IZRGxDPgUcEdE7JC0hOSV3J6IOA5YAKyo9uLMzGzusqqe9ouI7en39wBXRcQXJXUAGyo49onAlojYCiDpOuBM4IFptj8HuHZKbPtJ+h3QSTLUuZmZzbOsJwqVfP8D4J8AImJ3hcdeAmwrWR5Ly/Y+kdQJLAduSM/xMPAFkieYR4DfRMSaafbtk7RO0rrx8fEKQzMzs0plJYofSLpe0l8DBwM/AJB0OFDJtKgqUzbda7VnAHdHxI70HAeTPH0cBRwB7C/pPeV2jIjBiOiJiJ6urq4KwjIzs2pkJYrzScZ2GgbeGhG/S8tfQvLK7EzGgKUly0cyffXRCp5f7fQ24KGIGE/PeyPw5grOaWZmNTZtG0Uk449fV6b8pxUeey1wjKSjgIdJksG7p26Uvln1+yTtIJNGgTemVVLPAKcAHj/czKwOKulwNysRsVPSKuBWkreWroqITZJWpuuvSDc9C1gTEU+X7Hu
PpG8D9wI7gZ+STKBkZmbzbMaJi5qJJy4yM6tOJRMXZXW4+6SkpdOtNzOz9pDVmL0E+L+S7pT0YUmL5ysoMzNrHNMmioj4GNAN/BnwGuBnkv5R0vskHThfAZqZWX1lDjMeiTsi4sMkr7p+BfgY8Oh8BGdmZvVX0VtPkl5N8nrr2cCvgIvyDMrMzBrHtIlC0jEk4y+tAHaR9Kl4x+TYTWZm1h6ynihuJektfXZE3DdP8ZiZWYPJShSnAodNTRKSfg/YHhG/zDUyMzNrCFmN2V8GnixT/gxJo7aZmbWBrERRjIifTS2MiHVAMbeIzMysoWQliqypR/erdSBmZtaYshLFWkkfnFoo6QPA+vxCsnY0NDREsViko6ODYrHI0NBQvUMys1RWY/b5wHck9fJcYugBFpKM+GpWE0NDQ/T19TExMQHAyMgIfX19APT29tYzNDOjgtFjJf174Lh0cVNE/CD3qGbJo8c2p2KxyMjIyF7lhUKB4eHh+Q/IrI1UMnrsjD2zI+J24PaaRWU2xejoaFXlZja/Msd6MpsP3d3dVZWb2fxyorC6GxgYoLOz83llnZ2dDAwM1CkiMyvlRGF119vby+DgIIVCAUkUCgUGBwfdkG3WIDwVqplZG5vTVKg1CmC5pM2StkhaXWb9BZI2pJ/7Je2SdIikl5eUb5D0pKTz84zVzMzKq2g+itmQtAC4HHg7MEbSge+miHhgcpuIuBS4NN3+DOBjEbED2AEsKznOw8B38orVzMyml+cTxYnAlojYGhHPksxncWbG9ueQDGs+1SnALyNi7xftzcwsd3kmiiXAtpLlsbRsL5I6geXADWVWr6B8Apnct0/SOknrxsfH5xCumZmVk2eiUJmy6VrOzwDuTqudnjuAtBD4D8DfTXeSiBiMiJ6I6Onq6pp1sNYcPCaU2fzLrY2C5AliacnykcD2abad7qnhNODeiHi0xrFZE/KYUGb1kecTxVrgGElHpU8GK4Cbpm4kaRHw+8A/lDnGdO0W1ob6+/v3JIlJExMT9Pf31ykis/aQ2xNFROyUtIpk7u0FwFURsUnSynT9FemmZwFrIuLp0v3Tdou3Ax/KK0ZrLh4Tyqw+8qx6IiJuAW6ZUnbFlOWrgavL7DsBvDjH8KzJdHd3lx1l1mNCmeXLQ3hY0/CYUGb14URhTcNjQpnVh8d6MjNrY3Uf68msFbkvh7WbXBuzzVqN+3JYO/IThVkV3JfD2pEThRmVVye5L4e1IycKa3uT1UkjIyNExJ7qpHLJwvN7WztyorC2V011kvtyWDtyorC2V011kvtyWDtyPwpre8VisezQIIVCgeHh4fkPyGweuR+FWQVcnWSWzYnC2p6rk8yyuerJzKyNuerJ2pqH2jCrDQ/hYS3JQ22Y1Y6fKKwleagNs9pxorCW5KE2zGrHicJakofaaE7VtCu5DWoeRURuH2A5sBnYAqwus/4CYEP6uR/YBRySrjsI+Dbwc+BB4E0zne91r3tdmEVEXHPNNdHZ2RnAnk9nZ2dcc8019Q7NplHNPfP9rR1gXcz0t3ymDWb7ARYAvwSOBhYCG4FjM7Y/A/hByfI3gD9Jvy8EDprpnE4UVuqaa66JQqEQkqJQKPiPSIMrFArP+8M/+SkUCnPa1rJVkihy60ch6U3AxRFxarr8KYCI+Nw0238LuD0irpT0IpLEcnRUEaD7UZg1r46ODsr9c5fE7t27Z72tZat3P4olwLaS5bG0bC+SOkmqqW5Ii44GxoG/lfRTSV+TtP80+/ZJWidp3fj4eO2iN7N5VU27Up5tUG772FueiUJlyqZ7OjgDuDsidqTL+wAnAH8TEa8FngZWl9sxIgYjoicierq6uuYas5nVSTVjbuU1Plc1c5O0kzwTxRiwtGT5SGD7NNuuAK6dsu9YRNyTLn+bJHGYWYuqZsytvMbncv+b8vJso9gH+AVwCvAwsBZ4d0RsmrLdIuAhYGlEPF1SfhdJY/ZmSRcD+0fEBVnndBuFmc1FO7Z
9VNJGkdsQHhGxU9Iq4FaSN6CuiohNklam669INz0LWFOaJFIfBYYkLQS2AufmFauZGSRtHOXmJmn3/jcePdbMLDV1jDBI2j5aedj5er/1ZGbWVDw3SXl+ojAza2N+ojAzw30j5srzUZhZS/PcJHPnJwoza2nuGzF3ThRm1tI8N8ncOVGYWUtrhLlJmr2NxInCzFpaXuNCVaoVxo9yojCzllbvvhGt0EbifhRmZjlq9PGj3I/CzHLR7HXutVLJ71BtG0lD/rYzTYHXTB9PhWqWP89Xnaj0d2j0ucCp55zZ9fg4UZjlz/NVJ6r5HSqdv70ev20licJtFGZWlUavc58vefwO9fht3UZhZjWv826EfgmNII/foVF/WycKsxaWxzv89e6X0Cjy+B0a9redqW6qmT5uozB7vrzqvCutc291efwO8/3b4jYKs/bm9gSbidsozJpIHu/PV1Pnndf7+w3ZL8CqM9Mjx1w+wHJgM7AFWF1m/QXAhvRzP7ALOCRdNwzcl66b8dEoXPVkTSyv9+fzeNc/j/Nb/VTy9zXPJLEA+CVwNLAQ2Agcm7H9GcAPSpaHgcXVnNOJwppVnu/PV1Lnndf53eei8VWSKHJro5D0JuDiiDg1Xf4UQER8bprtvwXcHhFXpsvDQE9EPF7pOd1GYc2q3m0JeZ2/3tdlM6t3G8USYFvJ8lhathdJnSTVVDeUFAewRtJ6SX3TnURSn6R1ktaNj4/XIGyz+Vfv9+fzOn+9r8tqI89EoTJl0z2+nAHcHRE7SsreEhEnAKcBH5F0UrkdI2IwInoioqerq2tuEZvVSb3fn8/r/PW+LquNPBPFGLC0ZPlIYPs0264Ari0tiIjt6X8fA74DnJhDjGYNod5zJuR1/npfl9VGnm0U+wC/AE4BHgbWAu+OiE1TtlsEPAQsjYin07L9gY6I+G36/TbgMxHx/axzuo3CzKw6dW2jiIidwCrgVuBB4PqI2CRppaSVJZueBayZTBKpw4AfSdoI/AT43kxJwqwRuW+CtQL3zDbLyeQ4S6XTYHZ2ds656iWv41p7quSJwonCLCfFYpGRkZG9yguFAsPDww13XGtP9X491qytjY6OVlVe7+OaTceJwiwn7ptgeZuvtionCrOcuG+C5SmPuUamNdMYH8308VhP1mjymlvA80FYrcbRwvNRmJm1plqNo+XGbDOzFjWfbVVOFGZmTWg+26qcKMzMmtB8jqPlNgozszbmNgozM5szJwozM8vkRGFmZpmcKMzMLJMThZmZZWqpt54kjQOl4y8vBh6vUzh5atXrgta9Nl9X82nVa5t6XYWI6MraoaUSxVSS1s302lczatXrgta9Nl9X82nVa5vNdbnqyczMMjlRmJlZplZPFIP1DiAnrXpd0LrX5utqPq16bVVfV0u3UZiZ2dy1+hOFmZnNkROFmZllaslEIWm5pM2StkhaXe94aknSsKT7JG2Q1LRD5Uq6StJjku4vKTtE0m2S/iX978H1jHG2prm2iyU9nN63DZJOr2eMsyFpqaTbJT0oaZOk89Lypr5vGdfV1PdM0gsl/UTSxvS6/iItr/p+tVwbhaQFwC+AtwNjwFrgnIh4oK6B1YikYaAnIpq6I5Ckk4CngP8VEcelZX8F7IiIS9IEf3BEXFjPOGdjmmu7GHgqIr5Qz9jmQtLhwOERca+kA4H1wB8Cf0wT37eM63oXTXzPJAnYPyKekrQv8CPgPOCPqPJ+teITxYnAlojYGhHPAtcBZ9Y5JpsiIu4EdkwpPhP4Rvr9GyT/WJvONNfW9CLikYi4N/3+W+BBYAlNft8yrqupReKpdHHf9BPM4n61YqJYAmwrWR6jBW56iQDWSFovqa/ewdTYYRHxCCT/eIFD6xxPra2S9LO0aqqpqmemklQEXgvcQwvdtynXBU1+zyQtkLQBeAy4LSJmdb9aMVGoTFkr1a+9JSJOAE4DPpJWc1jj+xvg3wHLgEeAL9Y3nNmTdABwA3B+RDxZ73hqpcx
1Nf09i4hdEbEMOBI4UdJxszlOKyaKMWBpyfKRwPY6xVJzEbE9/e9jwHdIqtpaxaNpffFkvfFjdY6nZiLi0fQf7W7gSpr0vqV13TcAQxFxY1rc9Pet3HW1yj0DiIgngB8Cy5nF/WrFRLEWOEbSUZIWAiuAm+ocU01I2j9tbEPS/sA7gPuz92oqNwHvT7+/H/iHOsZSU5P/MFNn0YT3LW0c/TrwYER8qWRVU9+36a6r2e+ZpC5JB6Xf9wPeBvycWdyvlnvrCSB9je0rwALgqogYqHNINSHpaJKnCIB9gG8167VJuhY4mWTI40eBTwN/D1wPdAOjwH+OiKZrFJ7m2k4mqcIIYBj40GQ9cbOQ9FbgLuA+YHdafBFJfX7T3reM6zqHJr5nkl5D0li9gOSh4PqI+IykF1Pl/WrJRGFmZrXTilVPZmZWQ04UZmaWyYnCzMwyOVGYmVkmJwozM8vkRGFWBUm70pFEN6Wjcn5c0qz/HUm6qOR7sXTEWbNG4URhVp1nImJZRLyKZITi00n6SczWRTNvYlZfThRms5QOo9JHMnCc0gHYLpW0Nh1I7kMAkk6WdKek70h6QNIVkjokXQLslz6hDKWHXSDpyvSJZU3ao9asrpwozOYgIraS/Ds6FPgA8JuIeD3weuCDko5KNz0R+ATwapKB5v4oIlbz3BNKb7rdMcDl6RPLE8B/nL+rMSvPicJs7iZHLH4H8L50WOd7gBeT/OEH+Ek6R8ou4FrgrdMc66GI2JB+Xw8U8wnZrHL71DsAs2aWjr+1i2QETgEfjYhbp2xzMnsPdT/d2Dn/VvJ9F+CqJ6s7P1GYzZKkLuAK4LJIBk27FfhwOmQ1kl6WjvILyVwAR6VvSJ1NMi0lwO8mtzdrVH6iMKvOfmnV0r7ATuCbwOTQ1F8jqSq6Nx26epznppn8Z+ASkjaKO3luFOBB4GeS7gX65+MCzKrl0WPNcpZWPX0yIt5Z71jMZsNVT2ZmlslPFGZmlslPFGZmlsmJwszMMjlRmJlZJicKMzPL5ERhZmaZ/j/Dqxo3fY8dxgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots(1)\n", "\n", "best_score = 0\n", "\n", "for d in range(1,30):\n", " T = tree.DecisionTreeClassifier(max_depth = d)\n", " cv_score = cross_val_score(T, X_train, y_train, cv=10).mean()\n", " ax.scatter(d, cv_score, color = \"black\")\n", " if cv_score > best_score:\n", " best_depth = d\n", " best_score = cv_score\n", "\n", "l = ax.set(title = \"Best Depth : \" + str(best_depth),\n", " xlabel = \"Depth\", \n", " ylabel = \"CV Score\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have a reasonable estimate of the optimal depth, we can try evaluating against the unseen testing data. " ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8426966292134831" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "T = tree.DecisionTreeClassifier(max_depth = best_depth)\n", "T.fit(X_train, y_train)\n", "T.score(X_test, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Great! We even got slightly higher accuracy on the test set than we did in validation, although this is rare." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Machine Learning Workflow: The Big Picture\n", "\n", "We now have all of the elements that we need to execute the core machine learning workflow. At a high-level, here's what should go into a machine task:\n", "\n", "1. Separate out the test set from your data. \n", "2. Clean and prepare your data if needed. It is best practice to clean your training and test data separately. It's convenient to write a function for this. \n", "3. Identify a set of candidate models (e.g. decision trees with depth up to 30, logistic models with between 1 and 3 variables, etc). \n", "4. 
Use a validation technique (k-fold cross-validation is usually sufficient) to estimate how your models will perform on the unseen test data. Select the best model as measured by validation. \n", "5. Finally, score the best model against the test set and report the result. \n", "\n", "Of course, this isn't all there is to data science -- you still need to do exploratory analysis; interpret your model; etc. etc. \n", "\n", "We'll discuss model interpretation further in a coming lecture. " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 }