{ "cells": [ { "cell_type": "code", "execution_count": 222, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import statsmodels.formula.api as smf\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from sklearn import model_selection, metrics, linear_model, decomposition, cross_decomposition\n", "from itertools import combinations\n", "import numpy as np\n", "import time " ] }, { "cell_type": "code", "execution_count": 236, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>AtBat</th>\n", " <th>Hits</th>\n", " <th>HmRun</th>\n", " <th>Runs</th>\n", " <th>RBI</th>\n", " <th>Walks</th>\n", " <th>Years</th>\n", " <th>CAtBat</th>\n", " <th>CHits</th>\n", " <th>CHmRun</th>\n", " <th>CRuns</th>\n", " <th>CRBI</th>\n", " <th>CWalks</th>\n", " <th>League</th>\n", " <th>Division</th>\n", " <th>PutOuts</th>\n", " <th>Assists</th>\n", " <th>Errors</th>\n", " <th>Salary</th>\n", " <th>NewLeague</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>1</th>\n", " <td>315</td>\n", " <td>81</td>\n", " <td>7</td>\n", " <td>24</td>\n", " <td>38</td>\n", " <td>39</td>\n", " <td>14</td>\n", " <td>3449</td>\n", " <td>835</td>\n", " <td>69</td>\n", " <td>321</td>\n", " <td>414</td>\n", " <td>375</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>632</td>\n", " <td>43</td>\n", " <td>10</td>\n", " <td>475.0</td>\n", " <td>1.0</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>479</td>\n", " <td>130</td>\n", " <td>18</td>\n", " <td>66</td>\n", " <td>72</td>\n", " <td>76</td>\n", " <td>3</td>\n", " <td>1624</td>\n", " <td>457</td>\n", " <td>63</td>\n", " <td>224</td>\n", " <td>266</td>\n", " <td>263</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>880</td>\n", " <td>82</td>\n", " <td>14</td>\n", " <td>480.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>496</td>\n", " <td>141</td>\n", " <td>20</td>\n", " <td>65</td>\n", " <td>78</td>\n", " <td>37</td>\n", " <td>11</td>\n", " <td>5628</td>\n", " <td>1575</td>\n", " <td>225</td>\n", " <td>828</td>\n", " <td>838</td>\n", " <td>354</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>200</td>\n", " <td>11</td>\n", " <td>3</td>\n", " <td>500.0</td>\n", " <td>1.0</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>321</td>\n", " <td>87</td>\n", " <td>10</td>\n", " <td>39</td>\n", " <td>42</td>\n", " <td>30</td>\n", " <td>2</td>\n", " <td>396</td>\n", " <td>101</td>\n", " <td>12</td>\n", " <td>48</td>\n", " <td>46</td>\n", " <td>33</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>805</td>\n", " <td>40</td>\n", " <td>4</td>\n", " <td>91.5</td>\n", " <td>1.0</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>594</td>\n", " <td>169</td>\n", " <td>4</td>\n", " <td>74</td>\n", " <td>51</td>\n", " <td>35</td>\n", " <td>11</td>\n", " <td>4408</td>\n", " <td>1133</td>\n", " <td>19</td>\n", " <td>501</td>\n", " <td>336</td>\n", " <td>194</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>282</td>\n", " <td>421</td>\n", " <td>25</td>\n", " <td>750.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun CRuns \\\n", "1 315 81 7 24 38 39 14 3449 835 69 321 \n", "2 479 130 18 66 72 76 3 1624 457 63 224 \n", "3 496 141 20 65 78 37 11 5628 1575 225 828 \n", "4 321 87 10 39 42 30 2 396 101 12 48 \n", "5 594 169 4 74 51 35 11 4408 1133 19 501 \n", "\n", " CRBI CWalks League Division PutOuts Assists Errors Salary NewLeague \n", "1 414 375 1.0 1.0 632 43 10 475.0 1.0 \n", "2 266 263 0.0 1.0 880 82 14 480.0 0.0 \n", "3 838 354 1.0 0.0 200 11 3 500.0 1.0 \n", "4 46 33 1.0 0.0 805 40 4 91.5 1.0 \n", "5 336 194 0.0 1.0 282 421 25 750.0 0.0 " ] }, "execution_count": 236, "metadata": {}, "output_type": "execute_result" } ], "source": [ "hitters_dat = pd.read_csv('hitters.csv')\n", "hitters_dat = hitters_dat.drop(hitters_dat.columns[0], axis=1)\n", "hitters_dat = hitters_dat.dropna(axis=0)\n", "hitters_dat['League'] = hitters_dat['League'].map({'N' : 1., 'A' : 0.})\n", "hitters_dat['Division'] = hitters_dat['Division'].map({'W' : 1., 'E' : 0.})\n", "hitters_dat['NewLeague'] = hitters_dat['NewLeague'].map({'N' : 1., 'A' : 0.})\n", "#hitters_dat = (hitters_dat - hitters_dat.mean()) / hitters_dat.std()\n", "hitters_dat.head()" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Best subset selection\n", "# There are 2 ^ 19 ~= 5 * 10^5 possible subsets of all features.\n", "# Trying them all takes ages.\n", "# For 4 predictors there are (19 choose 1) + .. + (19 choose 4) ~= 5 * 10^3 possible models.\n", "\n", "def n_choose_k(n, k):\n", " return int(np.math.factorial(n) / (np.math.factorial(n - k) * np.math.factorial(k)))\n", "\n", "def best_subset(n_predictors, target_column, data):\n", " i = 1\n", " predictors = data.drop(target_column, axis=1).columns\n", " top_models = []\n", " while i <= n_predictors:\n", " tick = time.time()\n", " cmbs = list(combinations(predictors, i))\n", " formulae = ['Salary ~ ' + ' + '.join(s) for s in cmbs]\n", " models = [smf.ols(formula=f, data=data).fit() for f in formulae]\n", " top_model = sorted(models, key=lambda m: m.rsquared, reverse=True)[0]\n", " top_models.append(top_model)\n", " tock = time.time()\n", " print('{}/{} : {} combinations, {} seconds'.format(i, n_predictors, n_choose_k(n_predictors, i), tock-tick))\n", " i += 1\n", " return pd.DataFrame({'model' : top_models, 'n_predictors' : range(1, n_predictors + 1)})\n" ] }, { "cell_type": "code", "execution_count": 238, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1/3 : 3 combinations, 0.07082295417785645 seconds\n", "2/3 : 3 combinations, 0.7477953433990479 seconds\n", "3/3 : 1 combinations, 5.140793085098267 seconds\n" ] } ], "source": [ "# Select between models with different numbers of predictors using adjusted training metrics\n", "best = best_subset(3, 'Salary', hitters_dat)\n", "best['adj_R_sq'] = best['model'].map(lambda m: m.rsquared_adj)\n", "best['bic'] = best['model'].map(lambda m: m.bic)\n", "best['aic'] = best['model'].map(lambda m: m.aic)" ] }, { "cell_type": "code", "execution_count": 239, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.axes._subplots.AxesSubplot at 0x12211c278>" ] }, "execution_count": 239, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 1080x1080 with 3 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the error metrics\n", "_, (ax1, ax2, ax3) = plt.subplots(nrows=3, figsize=(15,15))\n", "sns.barplot(y='n_predictors', x='adj_R_sq', data=best, ax=ax1, orient='h')\n", "sns.barplot(y='n_predictors', x='bic', data=best, ax=ax2, orient='h')\n", "sns.barplot(y='n_predictors', x='aic', data=best, ax=ax3, orient='h')" ] }, { "cell_type": "code", "execution_count": 249, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1/3 : 3 combinations, 0.07045507431030273 seconds\n", "2/3 : 3 combinations, 0.7818019390106201 seconds\n", "3/3 : 1 combinations, 5.664794921875 seconds\n" ] }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>model</th>\n", " <th>n_predictors</th>\n", " <th>MSE</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>1</td>\n", " <td>199054.342456</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>2</td>\n", " <td>146475.652896</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>3</td>\n", " <td>140876.193344</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " model n_predictors \\\n", "0 <statsmodels.regression.linear_model.Regressio... 1 \n", "1 <statsmodels.regression.linear_model.Regressio... 2 \n", "2 <statsmodels.regression.linear_model.Regressio... 3 \n", "\n", " MSE \n", "0 199054.342456 \n", "1 146475.652896 \n", "2 140876.193344 " ] }, "execution_count": 249, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Select between models with different numbers of predictors using the validation set approach\n", "train, test = model_selection.train_test_split(hitters_dat, test_size=0.2)\n", "best = best_subset(3, 'Salary', train)\n", "best['MSE'] = best['model'].map(lambda m: metrics.mean_squared_error(test['Salary'], m.predict(test)))\n", "best" ] }, { "cell_type": "code", "execution_count": 251, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.axes._subplots.AxesSubplot at 0x114b395c0>" ] }, "execution_count": 251, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEKCAYAAAD+XoUoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEStJREFUeJzt3XuQZGV5x/Hv4+5yUVZAWa11VxiwVIJKBLYsDJcqqUjARKq8JIFSwUtJEi/B0iQlkjLR/4zRqqhJBBXRBI0SJa4Gg6gEIyboLi67XOUSvJAVBKOgUsjikz/OO0vvZGanz0y/0z28309V15x+z+lznjnd85szb7/9TmQmkqQ2PGrcBUiSlo6hL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWrIynEXMOiAAw7IqampcZchScvK5s2b787MNcNsO1GhPzU1xaZNm8ZdhiQtKxHx3WG3tXtHkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNWSihmze8IN7OOpPPz7uMiRpSW1+9+lLdiyv9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ2pGvoRcX5E3BUR19Y8jiRpOLWv9C8ATqp8DEnSkKqGfmZ+DfhxzWNIkoZnn74kNWTsoR8RZ0bEpojYtOMX9427HEl6RBt76GfmeZm5ITM3rHz06nGXI0mPaGMPfUnS0qk9ZPOTwH8CT4+IH0TEa2oeT5K0eytr7jwzT6u5f0lSP3bvSFJDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGrJy3AUM+rX1j2fTu08fdxmS9Ijllb4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDhp5lMyIeA9yfmb+KiKcBhwJfzMwHR1XML7dfx/fe+axR7U5SJQe+fdu4S9AC9bnS/xqwV0SsA74EvAK4oEZRkqQ6+oR+ZOYvgBcDf5eZvws8o05ZkqQaeoV+RDwXeBnwr6VtxehLkiTV0if0zwLOBi7OzOsi4hDg8jplSZJqGOqN3IhYAZySmadMt2XmbcAf1ypMkjR6Q13pZ+ZDwLGVa5EkVdbnH6N/OyI2AhcBP59uzMzPjrwqSVIVfUJ/L+Ae4ISBtgQMfUlaJoYO/cx8Vc1CJEn1DT16JyLWR8TFEXFXuX0mItbXLE6SNFp9hmx+FNgIPKncPl/aJEnLRJ/QX5OZH83MHeV2AbCmUl2SpAr6hP49EfHyiFhRbi+ne2NXkrRM9An9VwO/B/wQ2A68FHhlhZokSZX0GbK5fvATuQARcQzw/dGWJEmqpc+V/vuHbJMkTah5r/TLzJq/AayJiDcPrHoszrIpScvKMN07ewD7lG1XD7TfS9evL0laJuYN/cy8ArgiIi7IzO8uQU2SpEr69Ol/OCL2m74TEftHxKUVapIkVdIn9A/IzJ9M38nM/wWeMPqSJEm19An9X0XEgdN3IuIgulk2JUnLRJ9x+ucAX4+IK4AAjgPOrFKVJKmKPlMr/1tEHAkcXZrelJl31ylLklTDvN07EXFo+XokcCDwP+V2YGnb3WOfHBGXR8T1EXFdRJw1iqIlSQszzJX+W4DXAu+ZZV2y63/SmmkH8JbMvDoiVgObI+KyzLy+f6mSpMUaZpz+a8vX5/XdeWZup5ucjcy8LyJuANYBhr4kjcEw0zC8eHfrh/3H6BExBRwBXDXM9pKk0Rume+eF5esT6Obg+Wq5/zzgGwzxj9EjYh/gM3Rv/t47Y92ZlFFA6/ZdNVzVkqQFGaZ751UAEfEl4LDSZUNErAUumO/xEbGKLvAvnO2vgsw8DzgP4PB1ezvuX5Iq6vPhrCdPB35xJ91onjlFRAAfAW7IzPcuoD5J0gj1+XDWV8pcO58s938f+PI8jzkGeAWwLSK2lLa3ZeYl/cqUJI1Cnw9nvSEiXgQcX5rOy8yL53nM1+k+vStJmgB9rvQBrgbuy8wvR8SjI2J1Zt5XozBJ0ugN3acfEa8F/hk4tzStA/6lRlGSpDr6vJH7ero++nsBMvNmnFpZkpaVPqH/QGb+cvpORKzEqZUlaVnpE/pXRMTbgL0j4vnARcDn65QlSaqhT+i/FfgRsA34A+AS4M9rFCVJqmOo0TsRsQL4eGa+DPhQ3ZIkSbUMdaWfmQ8BB0XEHpXrkSRV1Gec/m3AlRGxEfj5dKPTK0jS8tEn9G8tt0cBq+uUI0mqqc80DO8AiIjHdnf9JK4kLTd9PpG7ISK2AVvpJlC7JiKOqleaJGnU+nTvnA+8LjP/AyAijgU+ChxeozBJ0uj1Gaf/0HTgw84ZNHeMviRJUi19rvSviIhz6ebTT7r59P89Io4EyMyrK9QnSRqhPqH/6+XrX8xoP4Lul8AJI6lIklRNn9E7z9vd+og4IzM/tviSJEm19OnTn89ZI9yXJKmCUYa+/xZRkibcKEPfufUlacJ5pS9JDRn6jdyI2BN4CTA1+LjMfGdZvHKklUmSRq7PkM3PAT8FNgMPzFyZmW8YVVGSpDr6hP76zDypWiWSpOr6hP43IuJZmbmtVjF7rH0GB759U63dS1Lz+oT+scArI+K/6bp3gm6KZSdck6Rlok/on1ytCknSkugzDcN3axYiSapvlOP0JUkTztCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1JA+E65Vd+NdN3LM+48ZdxlS0658o/8E75HMK31JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDqoZ+ROwVEd+MiGsi4rqIeEfN40mSdm9l5f0/AJyQmT+LiFXA1yPii5n5X5WPK0maRdXQz8wEflburiq3rHlMSdLcqvfpR8SKiNgC3AVclplX1T6mJGl21UM/Mx/KzGcD64HnRMQzB9dHxJkRsSkiNj34swdrlyNJTVuy0TuZ+RPgcuCkGe3nZeaGzNywap9VS1WOJDWp9uidNRGxX1neG3g+cGPNY0qS5lZ79M5a4GMRsYLuF8ynM/MLlY8pSZpD7dE7W4Ejah5DkjQ8P5ErSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqyMpxFzDo0CccypVvvHLcZUjSI5ZX+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0JakhkZnjrmGniLgPuGncdczjAODucRexG9a3ONa3ONa3OAut76DMXDPMhhM1Th+4KTM3jLuI3YmITZNco/UtjvUtjvUtzlLUZ/eOJDXE0Jekhkxa6J837gKGMOk1Wt/iWN/iWN/iVK9vot7IlSTVNWlX+pKkmjJzIm7ASXTDNW8B3lr5WE8GLgeuB64DzirtfwncAWwptxcMPObsUttNwG/NVzdwMHBVaf8UsEfPGm8HtpU6NpW2xwGXATeXr/uX9gDeV461FThyYD9nlO1vBs4YaD+q7P+W8tjoUdvTB87RFuBe4E3jPH/A+cBdwLUDbdXP11zHGLK+dwM3lhouBvYr7VPA/QPn8YMLrWN33+sQ9VV/PoE9y/1byvqpHvV9aqC224EtYzx/c2XKxLwGd+5nIaE56huwArgVOATYA7gGOKzi8dZOn2RgNfAd4LDyIv+TWbY/rNS0Z3nx3lpqnrNu4NPAqWX5g8Af9azxduCAGW1/Nf2DBLwVeFdZfgHwxfJCOhq4auDFcFv5un9Znn7RfbNsG+WxJy/iufshcNA4zx9wPHAku4ZC9fM11zGGrO9EYGVZftdAfVOD283YT6865vpeh6yv+vMJvI4SysCpwKeGrW/G+vcAbx/j+ZsrUybmNbiz1oX8oI/6BjwXuHTg/tnA2Ut4/M8Bz9/Ni3yXeoBLS82z1l2elLt5+Ad6l+2GrOl2/n/o3wSsHXiR3VSWzwVOm7kdcBpw7kD7uaVtLXDjQPsu2/Ws80TgyrI81vPHjB/2pThfcx1jmPpmrHsRcOHutltIHXN9r0Oev+rP5/Rjy/LKst2sf3Xu5rwE8H3gqeM8fzOONZ0pE/UazMyJ6dNfR/ekTftBaasuIqaAI+j+tAR4Q0RsjYjzI2L/eeqbq/3xwE8yc8eM9j4S+FJEbI6IM0vbEzNze1n+IfDEBda3rizPbF+IU4FPDtyflPMHS3O+5jpGX6+mu3qbdnBEfDsiroiI4wbq7lvHYn+2aj+fOx9T1v+0bN/HccCdmXnzQNvYzt+MTJm41+CkhP5YRMQ+wGeAN2XmvcDfA08Bng1sp/uTcVyOzcwjgZOB10fE8YMrs/u1nmOprIiIPYBTgItK0ySdv10sxfla6DEi4hxgB3BhadoOHJiZRwBvBj4REY+tXccsJvb5nOE0dr3wGNv5myVTRrLfYQ1zjEkJ/Tvo3giZtr60VRMRq+ienAsz87MAmXlnZj6Umb8CPgQ8Z5765mq/B9gvIlbOaB9aZt5Rvt5F9ybfc4A7I2JtqX8t3RtbC6nvjrI8s72vk4GrM/POUuvEnL9iKc7XXMcYSkS8Evgd4GXlB5bMfCAz7ynLm+n6yZ+2wDoW/LO1RM/nzseU9fuW7YdSHvNiujd1p+sey/mbLVMWsN/qr8FJCf1vAU+NiIPL1eOpwMZaB4uIAD4C3JCZ7x1oXzuw2YuAa8vyRuDUiNgzIg4Gnkr3psqsdZcf3suBl5bHn0HXxzdsfY+JiNXTy3T95teWOs6YZZ8bgdOjczTw0/Ln3qXAiRGxf/nT/ES6vtTtwL0RcXQ5F6f3qW/ALldYk3L+BizF+ZrrGPOKiJOAPwNOycxfDLSviYgVZfkQuvN12wLrmOt7Haa+pXg+B+t+KfDV6V9+Q/pNur7unV0f4zh/c2XKAvZb/zU43xsSS3Wjezf7O3S/lc+pfKxj6f4E2srAcDTgH+iGRG0tJ3LtwGPOKbXdxMBIl7nqphvB8E264VUXAXv2qO8QupEP19AN/zqntD8e+Ard0KwvA4/Lh9/I+ttSwzZgw8C+Xl1quAV41UD7Brof4luBD9BjyGZ5/GPorsj2HWgb2/mj++WzHXiQrr/zNUtxvuY6xpD13ULXf7vL0ELgJeV53wJcDbxwoXXs7nsdor7qzyewV7l/S1l/yLD1lfYLgD+cse04zt9cmTIxr8Hpm5/IlaSGTEr3jiRpCRj6ktQQQ1+SGmLoS1JDDH1Jaoihr+ZFREbEPw7cXxkRP4qIL5T7T4yIL0TENRFxfURcUtqnIuL+iNgycDt9XN+HNIxJ+8fo0jj8HHhmROydmffTTZQ1+KnLdwKXZebfAETE4QPrbs3MZy9dqdLieKUvdS4Bfrssz5zLZS0Dk11l5tYlrEsaKUNf6vwT3dQCewGH8/Csq9B9cvIjEXF5RJwTEU8aWPeUGd07xyFNMLt3JLqr9+imxD2N7qp/cN2lZQ6Xk+gmmft2RDyzrLZ7R8uKV/rSwzYCf82uXTsAZOaPM/MTmfkKuonFjp+5jbQcGPrSw84H3pGZ2wYbI+KEiHh0WV5NN8f898ZQn7Rodu9IRXbT875vllVHAR+IiB10F0ofzsxvle6gp0TEloFtz8/M2fYhTQRn2ZSkhti9I0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWrI/wE2f6A5uyp6/wAAAABJRU5ErkJggg==\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the error metric\n", "sns.barplot(y='n_predictors', x='MSE', data=best, orient='h')" ] }, { "cell_type": "code", "execution_count": 309, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1/3 : 3 combinations, 0.07561683654785156 seconds\n", "2/3 : 3 combinations, 0.7820789813995361 seconds\n", "3/3 : 1 combinations, 5.312049150466919 seconds\n", "1/3 : 3 combinations, 0.06390810012817383 seconds\n", "2/3 : 3 combinations, 0.7506320476531982 seconds\n", "3/3 : 1 combinations, 5.409984111785889 seconds\n", "1/3 : 3 combinations, 0.07043266296386719 seconds\n", "2/3 : 3 combinations, 0.7198760509490967 seconds\n", "3/3 : 1 combinations, 5.0563881397247314 seconds\n", "1/3 : 3 combinations, 0.06885409355163574 seconds\n", "2/3 : 3 combinations, 0.6995189189910889 seconds\n", "3/3 : 1 combinations, 4.871701955795288 seconds\n", "1/3 : 3 combinations, 0.06180310249328613 seconds\n", "2/3 : 3 combinations, 0.6945500373840332 seconds\n", "3/3 : 1 combinations, 5.0047760009765625 seconds\n", "1/3 : 3 combinations, 0.059381961822509766 seconds\n", "2/3 : 3 combinations, 0.8878071308135986 seconds\n", "3/3 : 1 combinations, 5.055046796798706 seconds\n", "1/3 : 3 combinations, 0.06460094451904297 seconds\n", "2/3 : 3 combinations, 0.6986589431762695 seconds\n", "3/3 : 1 combinations, 4.979748725891113 seconds\n", "1/3 : 3 combinations, 0.06681489944458008 seconds\n", "2/3 : 3 combinations, 0.7132260799407959 seconds\n", "3/3 : 1 combinations, 5.130325078964233 seconds\n", "1/3 : 3 combinations, 0.06234002113342285 seconds\n", "2/3 : 3 combinations, 0.6919701099395752 seconds\n", "3/3 : 1 combinations, 4.884754180908203 seconds\n", "1/3 : 3 combinations, 0.060920000076293945 seconds\n", "2/3 : 3 combinations, 0.7059929370880127 seconds\n", "3/3 : 1 combinations, 5.235835313796997 seconds\n" ] }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>model</th>\n", " <th>n_predictors</th>\n", " <th>MSE</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>1</td>\n", " <td>149077.418140</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>2</td>\n", " <td>131847.483984</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>3</td>\n", " <td>137281.842881</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " model n_predictors \\\n", "0 <statsmodels.regression.linear_model.Regressio... 1 \n", "1 <statsmodels.regression.linear_model.Regressio... 2 \n", "2 <statsmodels.regression.linear_model.Regressio... 3 \n", "\n", " MSE \n", "0 149077.418140 \n", "1 131847.483984 \n", "2 137281.842881 " ] }, "execution_count": 309, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Select between models with different numbers of predictors using cross validation\n", "results = pd.DataFrame()\n", "for train_idx, test_idx in model_selection.KFold(n_splits=10).split(hitters_dat):\n", " train = hitters_dat.iloc[train_idx]\n", " test = hitters_dat.iloc[test_idx]\n", " best = best_subset(3, 'Salary', train)\n", " mse = best['model'].map(lambda m: metrics.mean_squared_error(test['Salary'], m.predict(test)))\n", " results = results.append(mse, ignore_index=True)\n", " \n", "best['MSE'] = results.mean()\n", "best" ] }, { "cell_type": "code", "execution_count": 310, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.axes._subplots.AxesSubplot at 0x11e11f278>" ] }, "execution_count": 310, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEKCAYAAAARnO4WAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAENhJREFUeJzt3XuQZGV9xvHvk12uggKCZmXFBasiAUWBrUQCWkrFiMRoaW5S3lAjqXgJRispkJQJ/qckKa+JEOVignhBUSQYRENQMUEX5H4Jd8WgXIyCaCngL3/0GbZ33GXnDH2me3i/n6quOf2e7nN+8/bMM2fec/rtVBWSpEe+X5l2AZKkpWHgS1IjDHxJaoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhqxctoFjNt5551rzZo10y5DkpaViy666M6q2mVzj5upwF+zZg3r1q2bdhmStKwkuWUhj3NIR5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+JDVipi7LvPrWu9j/Lz867TIkaUlddNyrlmQ/HuFLUiMMfElqhIEvSY0w8CWpEQa+JDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJaoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWqEgS9JjTDwJakRBr4kNcLAl6RGGPiS1AgDX5IaMWjgJzkxye1JrhhyP5KkzRv6CP9k4JCB9yFJWoBBA7+qvgL8YMh9SJIWxjF8SWrE1AM/yRFJ1iVZd/9P7pl2OZL0iDX1wK+qE6pqbVWtXbnt9tMuR5IesaYe+JKkpTH0ZZmnAf8FPCXJrUleN+T+JEmbtnLIjVfVYUNuX5K0cA7pSFIjDHxJaoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWqEgS9JjTDwJakRBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+JDXCwJekRhj4ktQIA1+SGrFy2gWM+/XVj2Xdca+adhmS9IjkEb4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+JDXCwJekRhj4ktQIA1+SGmHgS1IjFjxbZpJHAT+tql8k+TVgT+ALVXXfpIr5+W1X8u13Pm1Sm5OkX7LbOy6fdglT0+cI/yvA1kl2Bb4IvBI4eYiiJEmT1yfwU1U/AV4K/GNV/SGw9zBlSZImrVfgJzkAeDnwb13bismXJEkaQp/APxI4Gjijqq5Msgdw3jBlSZImbUEnbZOsAF5UVS+aa6uqG4E/H6owSdJkLegIv6oeAA4auBZJ0oD6fIj5t5KcCXwKuHeusao+M/GqJEkT1yfwtwbuAg4eayvAwJekZWDBgV9VrxmyEEnSsBZ8lU6S1UnOSHJ7d/t0ktVDFidJmpw+l2WeBJwJPKG7fb5rkyQtA30Cf5eqOqmq7u9uJwO7DFSXJGnC+gT+XUlekWRFd3sFo5O4kqRloE/gvxb4I+B7wG3AHwCHD1CTJGkAfS7LXD3+TluAJAcC35lsSZKkIfQ5wn//AtskSTNos0f43QyZvwXskuStY6sejbNlStKysZAhnS2B7brHbj/WfjejcXxJ0jKw2cCvqvOB85OcXFW3LEFNkqQB9BnD/3CSHebuJNkxyTkD1CRJGkCfwN+5qn44d6eq/g943ORLkiQNoU/g/yLJbnN3kjyJ0WyZkqRloM91+McAX0tyPhDgWcARg1QlSZq4PtMj/3uS/YBndk1vqao7hylLkjRpmx3SSbJn93U/YDfgf7vbbl3bQz33iUnOS3JVkiuTHDmJoiVJ/S3kCP9twOuBv9/IumLDT8Ca737gbVV1cZLtgYuSnFtVV/UvVZL0cCzkOvzXd1+f23fjVXUbo4nWqKp7klwN7AoY+JK0xBYytcJLH2r9Qj/EPMkaYF/gwoU8XpI0WQsZ0vm97uvjGM2p8x/d/ecCX2cBH2KeZDvg04xO9N49b90RdFf77PqYLRZWtSSpt4UM6bwGIMkXgb26YRqSrAJO3tzzk2zBKOxP3dh/A1V1AnACwD67buN1/ZI0kD5vvHriXNh3vs/oqp1NShLgI8DVVfUPi6hPkjQhfd549eVu7pzTuvt/DHxpM885EHglcHmSS7q2t1fV2f3KlCQ9XH3eePWmJC8Bnt01nVBVZ2zmOV9j9K5cSdKU9TnCB7gYuKeqvpRk2yTbV9U9QxQmSZqsBY/hJ3k9cDpwfNe0K/DZIYqSJE1en5O2b2Q0Jn83QFVdh9MjS9Ky0Sfwf1ZVP5+7k2QlTo8sSctGn8A/P8nbgW2SPA/4FPD5YcqSJE1an8A/CrgDuBz4U+Bs4K+HKEqSNHkLukonyQrgo1X1cuCfhy1JkjSEBR3hV9UDwJOSbDlwPZKkgfS5Dv9G4IIkZwL3zjU6ZYIkLQ99Av+G7vYrwPbDlCNJGkqfqRWOBUjy6NFd32ErSctJn3fark1yOXAZo8nQLk2y/3ClSZImqc+QzonAG6rqqwBJDgJOAvYZojBJ0mT1uQ7/gbmwhwdnwrx/8iVJkobQ5wj//CTHM5oPvxjNh/+fSfYDqKqLB6hPkjQhfQL/6d3Xv5nXvi+jPwAHT6QiSdIg+lyl89yHWp/k1VV1ysMvSZI0hD5j+Jtz5AS3JUmasEkGvh9lKEkzbJKB79z4kjTDPMKXpEYs+KRtkq2A3wfWjD+vqt7ZLV4w0cokSRPV57LMzwE/Ai4CfjZ/ZVW9aVJFSZImr0/gr66qQwarRJI0qD6B//UkT6uqy4cqZstVe7PbO9YNtXlJalqfwD8IODzJTYyGdMJommQnT5OkZaBP4L9gsCokSYPrM7XCLUMWIkka1iSvw5ckzTADX5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+JDXCwJekRhj4ktQIA1+SGmHgS1Ij+kyeNrhrbr+GA99/4LTLkKQHXfDmR86H+XmEL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWqEgS9JjTDwJakRBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+JDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJasSggZ9k6yTfSHJpkiuTHDvk/iRJm7Zy4O3/DDi4qn6cZAvga0m+UFX/PfB+JUnzDBr4VVXAj7u7W3S3GnKfkqSNG3wMP8mKJJcAtwPnVtWFQ+9TkvTLBg/8qnqgqp4BrAZ+I8lTx9cnOSLJuiTr7vvxfUOXI0nNWrKrdKrqh8B5wCHz2k+oqrVVtXaL7bZYqnIkqTlDX6WzS5IduuVtgOcB1wy5T0nSxg19lc4q4JQkKxj9cflkVZ018D4lSRsx9FU6lwH7DrkPSdLC+E5bSWqEgS9JjTDwJakRBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+JDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJaoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhph4EtSI1ZOu4Bxez5uTy548wXTLkOSHpE8wpekRhj4ktQIA1+SGmHgS1IjDHxJaoSBL0mNSFVNu4YHJbkHuHbadSzQzsCd0y5igax1GMul1uVSJ1jrYj2pqnbZ3INm6jp84NqqWjvtIhYiyTprnTxrnbzlUidY69Ac0pGkRhj4ktSIWQv8E6ZdQA/WOgxrnbzlUidY66Bm6qStJGk4s3aEL0kayMwEfpJDklyb5PokRy3RPp+Y5LwkVyW5MsmRXftOSc5Ncl33dceuPUne19V4WZL9xrb16u7x1yV59Vj7/kku757zviR5mDWvSPKtJGd193dPcmG3/U8k2bJr36q7f323fs3YNo7u2q9N8vyx9om9Bkl2SHJ6kmuSXJ3kgFnt1yR/0b3+VyQ5LcnWs9KvSU5McnuSK8baBu/HTe2jZ53Hda//ZUnOSLLDYvtqMa9Hn1rH1r0tSSXZedp9OoiqmvoNWAHcAOwBbAlcCuy1BPtdBezXLW8P/A+wF/Bu4Kiu/SjgXd3yocAXgADPBC7s2ncCbuy+7tgt79it+0b32HTPfcHDrPmtwMeAs7r7nwRe1i1/CPizbvkNwIe65ZcBn+iW9+r6dytg967fV0z6NQBOAf6kW94S2GEW+xXYFbgJ2GasPw+flX4Fng3sB1wx1jZ4P25qHz3r/B1gZbf8rrE6e/dV39ejb61d+xOBc4BbgJ2n3adD3JZ0Zw/xAhwAnDN2/2jg6CnU8TngeYze/LWqa1vF6P0BAMcDh409/tpu/WHA8WPtx3dtq4Brxto3eNwi6lsNfBk4GDir+4G6c+yX6sF+7H5wD+iWV3aPy/y+nXvcJF8D4DGMQjTz2meuXxkF/ne6X9yVXb8+f5b6FVjDhkE6eD9uah996py37iXAqRvrg8311WJ+zhdTK3A68HTgZtYH/lT7dNK3WRnSmfulm3Nr17Zkun8F9wUuBB5fVbd1q74HPL5b3lSdD9V+60baF+s9wF8Bv+juPxb4YVXdv5HtP1hTt/5H3eP7fg+LsTtwB3BSRsNPH07yKGawX6vqu8DfAd8GbmPUTxcxm/06Zyn6cVP7WKzXMjraXUydi/k57yXJi4HvVtWl81bNcp/2NiuBP1VJtgM+Dbylqu4eX1ejP8dTv5QpyQuB26vqomnXsgArGf3L/E9VtS9wL6N/YR80Q/26I/BiRn+kngA8CjhkqkX1sBT9+HD3keQY4H7g1IkVNUFJtgXeDrxjqfY5rZ//WQn87zIaP5uzumsbXJItGIX9qVX1ma75+0lWdetXAbdvps6Hal+9kfbFOBB4UZKbgY8zGtZ5L7BDkrkpMsa3/2BN3frHAHct4ntYjFuBW6vqwu7+6Yz+AMxiv/42cFNV3VFV9wGfYdTXs9ivc5aiHze1j16SHA68EHh5F3KLqfMu+r8efTyZ0R/8S7vfr9XAxUl+dRG1Dt6nD8tSjyFtYjxtJaOTHruz/mTN3kuw3wAfBd4zr/04Njy58u5u+XfZ8ATON7r2nRiNWe/Y3W4CdurWzT+Bc+gE6n4O60/afooNT2a9oVt+IxuezPpkt7w3G54wu5HRybKJvgbAV4GndMt/2/XpzPUr8JvAlcC23bZOAd48S/3KL4/hD96Pm9pHzzoPAa4Cdpn3uN591ff16FvrvHU3s34Mf6p9Ounbku5sMy/AoYyukrkBOGaJ9nkQo3+rLgMu6W6HMhoD/DJwHfClsRcywAe7Gi8H1o5t67XA9d3tNWPta4Eruud8gAWcUFpA3c9hfeDv0f2AXd/9UmzVtW/d3b++W7/H2POP6eq5lrGrWyb5GgDPANZ1ffvZ7pdiJvsVOBa4ptvevzAKopnoV+A0RucW7mP0n9PrlqIfN7WPnnVez2ice+5360OL7avFvB59ap23/mbWB/7U+nSIm++0laRGzMoYviRpYAa+JDXCwJekRhj4ktQIA1+SGmHgq3nd7Ij/OnZ/ZZI7sn5G0scnOSvJpRnNrHp2174myU+TXDJ2e9W0vg9pc2btQ8ylabgXeGqSbarqp4wm0Bt/N+w7gXOr6r0ASfYZW3dDVT1j6UqVFs8jfGnkbEbvqoTRDIenja1bxdiEWFV12RLWJU2MgS+NfBx4WZKtgX0YzZo654PARzL6sJxjkjxhbN2T5w3pPGspi5b6cEhHYnTU3k2RfRijo/3xdeck2YPR3DAvAL6V5Kndaod0tGx4hC+tdyajufFPm7+iqn5QVR+rqlcC32T0qUnSsmLgS+udCBxbVZePNyY5uJsznSTbM5pO99tTqE96WBzSkTpVdSvwvo2s2h/4QJL7GR0kfbiqvtkNAT05ySVjjz2xqja2DWnqnC1TkhrhkI4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEf8PK2y/6dD+e/YAAAAASUVORK5CYII=\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the error metric\n", "sns.barplot(y='n_predictors', x='MSE', data=best, orient='h')" ] }, { "cell_type": "code", "execution_count": 311, "metadata": {}, "outputs": [], "source": [ "# Forward selection:\n", "# There are 1 + p(1 + p) / 2 possible models in this space \n", "\n", "def fwd_selection(target_column, data):\n", " predictors = data.drop('Salary', axis=1).columns\n", " formulae_s1 = [(p, 'Salary ~ {}'.format(p)) for p in predictors]\n", " models_s1 = [(p, smf.ols(formula=f, data=data).fit()) for p, f in formulae_s1]\n", " predictor_s1, model_s1 = sorted(models_s1, key=lambda tup: tup[1].rsquared, reverse=True)[0]\n", " predictors = predictors.drop(predictor_s1)\n", " formula = 'Salary ~ {}'.format(predictor_s1)\n", " step_models = [model_s1]\n", " while len(predictors) > 0:\n", " models = []\n", " for p in predictors:\n", " f = formula + ' + ' + p\n", " m = smf.ols(formula=f, data=data).fit()\n", " models.append((p, m))\n", " \n", " predictor, model = sorted(models, key=lambda tup: tup[1].rsquared, reverse=True)[0]\n", " step_models.append(model)\n", " formula = formula + ' + ' + predictor\n", " predictors = predictors.drop(predictor)\n", " \n", " return pd.DataFrame({'model': step_models, 'n_predictors' : range(1, 20)})" ] }, { "cell_type": "code", "execution_count": 312, "metadata": {}, "outputs": [], "source": [ "# Select between models with different numbers of predictors using adjusted training metrics\n", "best = fwd_selection('Salary', hitters_dat)\n", "best['adj_R_sq'] = best['model'].map(lambda m: m.rsquared_adj)\n", "best['bic'] = best['model'].map(lambda m: m.bic)\n", "best['aic'] = best['model'].map(lambda m: m.aic)" ] }, { "cell_type": "code", "execution_count": 313, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.axes._subplots.AxesSubplot at 0x2884715f8>" ] }, "execution_count": 313, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 1080x1080 with 3 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the error metrics\n", "_, (ax1, ax2, ax3) = plt.subplots(nrows=3, figsize=(15,15))\n", "sns.barplot(y='n_predictors', x='adj_R_sq', data=best, ax=ax1, orient='h')\n", "sns.barplot(y='n_predictors', x='bic', data=best, ax=ax2, orient='h')\n", "sns.barplot(y='n_predictors', x='aic', data=best, ax=ax3, orient='h')" ] }, { "cell_type": "code", "execution_count": 314, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>model</th>\n", " <th>n_predictors</th>\n", " <th>MSE</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>1</td>\n", " <td>135508.216515</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>2</td>\n", " <td>147033.581930</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>3</td>\n", " <td>147669.747336</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>4</td>\n", " <td>128926.294401</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>5</td>\n", " <td>127644.577719</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>6</td>\n", " <td>125550.266428</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>7</td>\n", " <td>121290.957269</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>8</td>\n", " <td>116963.477452</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>9</td>\n", " <td>120501.056237</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>10</td>\n", " <td>115206.332819</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>11</td>\n", " <td>121145.228980</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>12</td>\n", " <td>128229.611960</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>13</td>\n", " <td>128106.132018</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>14</td>\n", " <td>127551.688712</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>15</td>\n", " <td>128313.047544</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>16</td>\n", " <td>128277.486556</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>17</td>\n", " <td>128137.435745</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>18</td>\n", " <td>128193.974535</td>\n", " </tr>\n", " <tr>\n", " <th>18</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>19</td>\n", " <td>128017.675903</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " model n_predictors \\\n", "0 <statsmodels.regression.linear_model.Regressio... 1 \n", "1 <statsmodels.regression.linear_model.Regressio... 2 \n", "2 <statsmodels.regression.linear_model.Regressio... 3 \n", "3 <statsmodels.regression.linear_model.Regressio... 4 \n", "4 <statsmodels.regression.linear_model.Regressio... 5 \n", "5 <statsmodels.regression.linear_model.Regressio... 6 \n", "6 <statsmodels.regression.linear_model.Regressio... 7 \n", "7 <statsmodels.regression.linear_model.Regressio... 8 \n", "8 <statsmodels.regression.linear_model.Regressio... 9 \n", "9 <statsmodels.regression.linear_model.Regressio... 10 \n", "10 <statsmodels.regression.linear_model.Regressio... 11 \n", "11 <statsmodels.regression.linear_model.Regressio... 12 \n", "12 <statsmodels.regression.linear_model.Regressio... 13 \n", "13 <statsmodels.regression.linear_model.Regressio... 14 \n", "14 <statsmodels.regression.linear_model.Regressio... 15 \n", "15 <statsmodels.regression.linear_model.Regressio... 16 \n", "16 <statsmodels.regression.linear_model.Regressio... 17 \n", "17 <statsmodels.regression.linear_model.Regressio... 18 \n", "18 <statsmodels.regression.linear_model.Regressio... 19 \n", "\n", " MSE \n", "0 135508.216515 \n", "1 147033.581930 \n", "2 147669.747336 \n", "3 128926.294401 \n", "4 127644.577719 \n", "5 125550.266428 \n", "6 121290.957269 \n", "7 116963.477452 \n", "8 120501.056237 \n", "9 115206.332819 \n", "10 121145.228980 \n", "11 128229.611960 \n", "12 128106.132018 \n", "13 127551.688712 \n", "14 128313.047544 \n", "15 128277.486556 \n", "16 128137.435745 \n", "17 128193.974535 \n", "18 128017.675903 " ] }, "execution_count": 314, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Select between models with different numbers of predictors using the validation set approach\n", "train, test = model_selection.train_test_split(hitters_dat, test_size=0.2)\n", "best = fwd_selection('Salary', train)\n", "best['MSE'] = best['model'].map(lambda m: metrics.mean_squared_error(test['Salary'], m.predict(test)))\n", "best" ] }, { "cell_type": "code", "execution_count": 315, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.axes._subplots.AxesSubplot at 0x1301f4cf8>" ] }, "execution_count": 315, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the error metric\n", "sns.barplot(y='n_predictors', x='MSE', data=best, orient='h')" ] }, { "cell_type": "code", "execution_count": 316, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>model</th>\n", " <th>n_predictors</th>\n", " <th>MSE</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>1</td>\n", " <td>149077.418140</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>2</td>\n", " <td>134035.444866</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>3</td>\n", " <td>130954.147117</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>4</td>\n", " <td>123682.462505</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>5</td>\n", " <td>119248.631710</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>6</td>\n", " <td>113154.213964</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>7</td>\n", " <td>115580.660634</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>8</td>\n", " <td>111934.205427</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>9</td>\n", " <td>112854.617360</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>10</td>\n", " <td>114497.887408</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>11</td>\n", " <td>114714.879564</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>12</td>\n", " <td>116153.597903</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>13</td>\n", " <td>115401.230308</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>14</td>\n", " <td>114398.000171</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>15</td>\n", " <td>116273.710693</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>16</td>\n", " <td>116361.112111</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>17</td>\n", " <td>116308.889187</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>18</td>\n", " <td>116595.910924</td>\n", " </tr>\n", " <tr>\n", " <th>18</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>19</td>\n", " <td>116599.013674</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " model n_predictors \\\n", "0 <statsmodels.regression.linear_model.Regressio... 1 \n", "1 <statsmodels.regression.linear_model.Regressio... 2 \n", "2 <statsmodels.regression.linear_model.Regressio... 3 \n", "3 <statsmodels.regression.linear_model.Regressio... 4 \n", "4 <statsmodels.regression.linear_model.Regressio... 5 \n", "5 <statsmodels.regression.linear_model.Regressio... 6 \n", "6 <statsmodels.regression.linear_model.Regressio... 7 \n", "7 <statsmodels.regression.linear_model.Regressio... 8 \n", "8 <statsmodels.regression.linear_model.Regressio... 9 \n", "9 <statsmodels.regression.linear_model.Regressio... 10 \n", "10 <statsmodels.regression.linear_model.Regressio... 11 \n", "11 <statsmodels.regression.linear_model.Regressio... 12 \n", "12 <statsmodels.regression.linear_model.Regressio... 13 \n", "13 <statsmodels.regression.linear_model.Regressio... 14 \n", "14 <statsmodels.regression.linear_model.Regressio... 15 \n", "15 <statsmodels.regression.linear_model.Regressio... 16 \n", "16 <statsmodels.regression.linear_model.Regressio... 17 \n", "17 <statsmodels.regression.linear_model.Regressio... 18 \n", "18 <statsmodels.regression.linear_model.Regressio... 19 \n", "\n", " MSE \n", "0 149077.418140 \n", "1 134035.444866 \n", "2 130954.147117 \n", "3 123682.462505 \n", "4 119248.631710 \n", "5 113154.213964 \n", "6 115580.660634 \n", "7 111934.205427 \n", "8 112854.617360 \n", "9 114497.887408 \n", "10 114714.879564 \n", "11 116153.597903 \n", "12 115401.230308 \n", "13 114398.000171 \n", "14 116273.710693 \n", "15 116361.112111 \n", "16 116308.889187 \n", "17 116595.910924 \n", "18 116599.013674 " ] }, "execution_count": 316, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Select between models with different numbers of predictors using cross validation\n", "results = pd.DataFrame()\n", "for train_idx, test_idx in model_selection.KFold(n_splits=10).split(hitters_dat):\n", " train = hitters_dat.iloc[train_idx]\n", " test = hitters_dat.iloc[test_idx]\n", " best = fwd_selection('Salary', train)\n", " mse = best['model'].map(lambda m: metrics.mean_squared_error(test['Salary'], m.predict(test)))\n", " results = results.append(mse, ignore_index=True)\n", " \n", "best['MSE'] = results.mean()\n", "best" ] }, { "cell_type": "code", "execution_count": 317, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.axes._subplots.AxesSubplot at 0x129eafb38>" ] }, "execution_count": 317, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the error metric\n", "sns.barplot(y='n_predictors', x='MSE', data=best, orient='h')" ] }, { "cell_type": "code", "execution_count": 334, "metadata": {}, "outputs": [], "source": [ "# Backward selection\n", "# There are 1 + p(1 + p) / 2 possible models in this space \n", "\n", "def bwd_selection(target_column, data):\n", " predictors = data.drop('Salary', axis=1).columns\n", " formula_s1 = 'Salary ~ ' + ' + '.join(predictors)\n", " model_s1 = smf.ols(formula=formula_s1, data=data).fit()\n", " step_models = [model_s1]\n", " while len(predictors) > 1:\n", " models = []\n", " for p in predictors:\n", " f = 'Salary ~ ' + ' + '.join(predictors.drop(p))\n", " m = smf.ols(formula=f, data=data).fit()\n", " models.append((p, m))\n", " p, m = sorted(models, key=lambda tup: tup[1].rsquared, reverse=True)[0]\n", " step_models.append(m)\n", " predictors = predictors.drop(p)\n", " return pd.DataFrame({'model': step_models, 'n_predictors' : range(1, 20)[::-1]})" ] }, { "cell_type": "code", "execution_count": 331, "metadata": {}, "outputs": [], "source": [ "# Select between models with different numbers of predictors using adjusted training metrics\n", "best = bwd_selection('Salary', hitters_dat)\n", "best['adj_R_sq'] = best['model'].map(lambda m: m.rsquared_adj)\n", "best['bic'] = best['model'].map(lambda m: m.bic)\n", "best['aic'] = best['model'].map(lambda m: m.aic)" ] }, { "cell_type": "code", "execution_count": 333, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.axes._subplots.AxesSubplot at 0x2a1e25710>" ] }, "execution_count": 333, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 1080x1080 with 3 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the error metrics\n", "_, (ax1, ax2, ax3) = plt.subplots(nrows=3, figsize=(15,15))\n", "sns.barplot(y='n_predictors', x='adj_R_sq', data=best, ax=ax1, orient='h')\n", "sns.barplot(y='n_predictors', x='bic', data=best, ax=ax2, orient='h')\n", "sns.barplot(y='n_predictors', x='aic', data=best, ax=ax3, orient='h')" ] }, { "cell_type": "code", "execution_count": 336, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>model</th>\n", " <th>n_predictors</th>\n", " <th>MSE</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>19</td>\n", " <td>133352.083840</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>18</td>\n", " <td>133284.236773</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>17</td>\n", " <td>133233.527027</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>16</td>\n", " <td>133792.636968</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>15</td>\n", " <td>133156.193247</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>14</td>\n", " <td>135624.311929</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>13</td>\n", " <td>135391.355291</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>12</td>\n", " <td>135395.399373</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>11</td>\n", " <td>133589.596734</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>10</td>\n", " <td>132451.355454</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>9</td>\n", " <td>133550.524887</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>8</td>\n", " <td>130130.135009</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>7</td>\n", " <td>139175.543438</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>6</td>\n", " <td>145552.220218</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>5</td>\n", " <td>155594.110409</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>4</td>\n", " <td>160845.201250</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>3</td>\n", " <td>182224.634993</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>2</td>\n", " <td>185567.253911</td>\n", " </tr>\n", " <tr>\n", " <th>18</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>1</td>\n", " <td>236588.256708</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " model n_predictors \\\n", "0 <statsmodels.regression.linear_model.Regressio... 19 \n", "1 <statsmodels.regression.linear_model.Regressio... 18 \n", "2 <statsmodels.regression.linear_model.Regressio... 17 \n", "3 <statsmodels.regression.linear_model.Regressio... 16 \n", "4 <statsmodels.regression.linear_model.Regressio... 15 \n", "5 <statsmodels.regression.linear_model.Regressio... 14 \n", "6 <statsmodels.regression.linear_model.Regressio... 13 \n", "7 <statsmodels.regression.linear_model.Regressio... 12 \n", "8 <statsmodels.regression.linear_model.Regressio... 11 \n", "9 <statsmodels.regression.linear_model.Regressio... 10 \n", "10 <statsmodels.regression.linear_model.Regressio... 9 \n", "11 <statsmodels.regression.linear_model.Regressio... 8 \n", "12 <statsmodels.regression.linear_model.Regressio... 7 \n", "13 <statsmodels.regression.linear_model.Regressio... 6 \n", "14 <statsmodels.regression.linear_model.Regressio... 5 \n", "15 <statsmodels.regression.linear_model.Regressio... 4 \n", "16 <statsmodels.regression.linear_model.Regressio... 3 \n", "17 <statsmodels.regression.linear_model.Regressio... 2 \n", "18 <statsmodels.regression.linear_model.Regressio... 1 \n", "\n", " MSE \n", "0 133352.083840 \n", "1 133284.236773 \n", "2 133233.527027 \n", "3 133792.636968 \n", "4 133156.193247 \n", "5 135624.311929 \n", "6 135391.355291 \n", "7 135395.399373 \n", "8 133589.596734 \n", "9 132451.355454 \n", "10 133550.524887 \n", "11 130130.135009 \n", "12 139175.543438 \n", "13 145552.220218 \n", "14 155594.110409 \n", "15 160845.201250 \n", "16 182224.634993 \n", "17 185567.253911 \n", "18 236588.256708 " ] }, "execution_count": 336, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Select between models with different numbers of predictors using the validation set approach\n", "train, test = model_selection.train_test_split(hitters_dat, test_size=0.2)\n", "best = bwd_selection('Salary', train)\n", "best['MSE'] = best['model'].map(lambda m: metrics.mean_squared_error(test['Salary'], m.predict(test)))\n", "best" ] }, { "cell_type": "code", "execution_count": 338, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.axes._subplots.AxesSubplot at 0x12ce60828>" ] }, "execution_count": 338, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the error metric\n", "sns.barplot(y='n_predictors', x='MSE', data=best, orient='h')" ] }, { "cell_type": "code", "execution_count": 339, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>model</th>\n", " <th>n_predictors</th>\n", " <th>MSE</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>19</td>\n", " <td>116599.013674</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>18</td>\n", " <td>116263.927026</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>17</td>\n", " <td>115939.235073</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>16</td>\n", " <td>115616.664737</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>15</td>\n", " <td>115373.619649</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>14</td>\n", " <td>114599.797519</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>13</td>\n", " <td>115082.285922</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>12</td>\n", " <td>112655.258048</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>11</td>\n", " <td>110741.661641</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>10</td>\n", " <td>110090.954229</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>9</td>\n", " <td>112044.513156</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>8</td>\n", " <td>108997.685343</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>7</td>\n", " <td>115301.385440</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>6</td>\n", " <td>118540.196656</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>5</td>\n", " <td>122845.644224</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>4</td>\n", " <td>128180.461938</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>3</td>\n", " <td>136141.808701</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>2</td>\n", " <td>137656.003772</td>\n", " </tr>\n", " <tr>\n", " <th>18</th>\n", " <td><statsmodels.regression.linear_model.Regressio...</td>\n", " <td>1</td>\n", " <td>148995.519722</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " model n_predictors \\\n", "0 <statsmodels.regression.linear_model.Regressio... 19 \n", "1 <statsmodels.regression.linear_model.Regressio... 18 \n", "2 <statsmodels.regression.linear_model.Regressio... 17 \n", "3 <statsmodels.regression.linear_model.Regressio... 16 \n", "4 <statsmodels.regression.linear_model.Regressio... 15 \n", "5 <statsmodels.regression.linear_model.Regressio... 14 \n", "6 <statsmodels.regression.linear_model.Regressio... 13 \n", "7 <statsmodels.regression.linear_model.Regressio... 12 \n", "8 <statsmodels.regression.linear_model.Regressio... 11 \n", "9 <statsmodels.regression.linear_model.Regressio... 10 \n", "10 <statsmodels.regression.linear_model.Regressio... 9 \n", "11 <statsmodels.regression.linear_model.Regressio... 8 \n", "12 <statsmodels.regression.linear_model.Regressio... 7 \n", "13 <statsmodels.regression.linear_model.Regressio... 6 \n", "14 <statsmodels.regression.linear_model.Regressio... 5 \n", "15 <statsmodels.regression.linear_model.Regressio... 4 \n", "16 <statsmodels.regression.linear_model.Regressio... 3 \n", "17 <statsmodels.regression.linear_model.Regressio... 2 \n", "18 <statsmodels.regression.linear_model.Regressio... 1 \n", "\n", " MSE \n", "0 116599.013674 \n", "1 116263.927026 \n", "2 115939.235073 \n", "3 115616.664737 \n", "4 115373.619649 \n", "5 114599.797519 \n", "6 115082.285922 \n", "7 112655.258048 \n", "8 110741.661641 \n", "9 110090.954229 \n", "10 112044.513156 \n", "11 108997.685343 \n", "12 115301.385440 \n", "13 118540.196656 \n", "14 122845.644224 \n", "15 128180.461938 \n", "16 136141.808701 \n", "17 137656.003772 \n", "18 148995.519722 " ] }, "execution_count": 339, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Select between models with different numbers of predictors using cross validation\n", "results = pd.DataFrame()\n", "for train_idx, test_idx in model_selection.KFold(n_splits=10).split(hitters_dat):\n", " train = hitters_dat.iloc[train_idx]\n", " test = hitters_dat.iloc[test_idx]\n", " best = bwd_selection('Salary', train)\n", " mse = best['model'].map(lambda m: metrics.mean_squared_error(test['Salary'], m.predict(test)))\n", " results = results.append(mse, ignore_index=True)\n", " \n", "best['MSE'] = results.mean()\n", "best" ] }, { "cell_type": "code", "execution_count": 340, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.axes._subplots.AxesSubplot at 0x1d903e9e8>" ] }, "execution_count": 340, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the error metric\n", "sns.barplot(y='n_predictors', x='MSE', data=best, orient='h')" ] }, { "cell_type": "code", "execution_count": 382, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.axes._subplots.AxesSubplot at 0x1e3e4f8d0>" ] }, "execution_count": 382, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Ridge Regression - Cross-Validation with test MSE\n", "\n", "results = pd.DataFrame()\n", "alpha_range = np.linspace(10**-3, 3 * (10**2), num=1000)\n", "for train_idx, test_idx in model_selection.KFold(n_splits=10).split(hitters_dat):\n", " train, test = hitters_dat.iloc[train_idx], hitters_dat.iloc[test_idx]\n", " errors = []\n", " for alpha in alpha_range:\n", " reg = linear_model.Ridge(alpha=alpha)\n", " reg.fit(train.drop('Salary', axis=1), train['Salary'])\n", " error = metrics.mean_squared_error(test['Salary'], reg.predict(test.drop('Salary', axis=1)))\n", " errors.append(error)\n", " results = pd.concat([results, pd.DataFrame(errors, index=alpha_range).T], axis=0, ignore_index=True)\n", "\n", "df = pd.DataFrame({'lambda' : alpha_range, 'MSE' : results.mean()})\n", "sns.scatterplot(x='lambda', y='MSE', data=df)" ] }, { "cell_type": "code", "execution_count": 384, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.axes._subplots.AxesSubplot at 0x1e77340f0>" ] }, "execution_count": 384, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# LASSO Regression - Cross-Validation with test MSE\n", "\n", "results = pd.DataFrame()\n", "alpha_range = np.linspace(10**-2, 10**3, num=500)\n", "for train_idx, test_idx in model_selection.KFold(n_splits=10).split(hitters_dat):\n", " train, test = hitters_dat.iloc[train_idx], hitters_dat.iloc[test_idx]\n", " errors = []\n", " for alpha in alpha_range:\n", " reg = linear_model.Lasso(alpha=alpha, max_iter=10000)\n", " reg.fit(train.drop('Salary', axis=1), train['Salary'])\n", " error = metrics.mean_squared_error(test['Salary'], reg.predict(test.drop('Salary', axis=1)))\n", " errors.append(error)\n", " results = pd.concat([results, pd.DataFrame(errors, index=alpha_range).T], axis=0, ignore_index=True)\n", " \n", "\n", "df = pd.DataFrame({'lambda' : alpha_range, 'MSE' : results.mean()})\n", "sns.scatterplot(x='lambda', y='MSE', data=df)" ] }, { "cell_type": "code", "execution_count": 387, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.axes._subplots.AxesSubplot at 0x1f0dd6160>" ] }, "execution_count": 387, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# PCR Regression - Cross-Validation with test MSE\n", "\n", "results = pd.DataFrame()\n", "predictors = hitters_dat.drop('Salary', axis=1).columns\n", "c_range = range(1, len(predictors) + 1)\n", "for train_idx, test_idx in model_selection.KFold(n_splits=10).split(hitters_dat):\n", " train, test = hitters_dat.iloc[train_idx], hitters_dat.iloc[test_idx]\n", " errors = []\n", " for c in c_range:\n", " pca = decomposition.PCA(n_components=c)\n", " X = pca.fit_transform(train.drop('Salary', axis=1))\n", " reg = linear_model.LinearRegression()\n", " reg.fit(X, train['Salary'])\n", " \n", " test_X = pca.transform(test.drop('Salary', axis=1))\n", " error = metrics.mean_squared_error(test['Salary'], reg.predict(test_X))\n", " #error = metrics.mean_squared_error(test['Salary'], reg.predict(test_X))\n", " errors.append(error)\n", " results = pd.concat([results, pd.DataFrame(errors, index=c_range).T], axis=0, ignore_index=True)\n", "\n", "df = pd.DataFrame({'n_components' : c_range, 'MSE' : results.mean()})\n", "sns.barplot(x='n_components', y='MSE', data=df)" ] }, { "cell_type": "code", "execution_count": 388, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.axes._subplots.AxesSubplot at 0x1e7ba3128>" ] }, "execution_count": 388, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# PLS Regression - Cross-Validation with test MSE\n", "\n", "results = pd.DataFrame()\n", "predictors = hitters_dat.drop('Salary', axis=1).columns\n", "c_range = range(1, len(predictors) + 1)\n", "for train_idx, test_idx in model_selection.KFold(n_splits=10).split(hitters_dat):\n", " train, test = hitters_dat.iloc[train_idx], hitters_dat.iloc[test_idx]\n", " errors = []\n", " for c in c_range:\n", " reg = cross_decomposition.PLSRegression(n_components=c)\n", " reg.fit(train.drop('Salary', axis=1), train['Salary'])\n", " error = metrics.mean_squared_error(test['Salary'], reg.predict(test.drop('Salary', axis=1)))\n", " errors.append(error)\n", " results = pd.concat([results, pd.DataFrame(errors, index=c_range).T], axis=0, ignore_index=True)\n", "\n", "df = pd.DataFrame({'n_components' : c_range, 'MSE' : results.mean()})\n", "sns.barplot(x='n_components', y='MSE', data=df)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }