{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Recitation Week 2: Bias Variance Tradeoff + Ridge and Lasso regularization" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Toy regression data: noisy samples of the quadratic t = x + x^2 on [-2, 1]\n", "import numpy as np\n", "import matplotlib.pyplot as plt \n", "\n", "# training inputs with their noiseless and noisy targets (Gaussian noise, std 0.1)\n", "xtrue = np.linspace(-2,1,10)\n", "ttrue = xtrue+xtrue**2\n", "tnoisy = ttrue+ np.random.normal(0,0.1, len(xtrue))\n", "\n", "# denser test grid on the same interval; the test targets are evaluated on xtest\n", "# so that ttest lines up one-to-one with the test inputs\n", "xtest = np.linspace(-2,1,100)\n", "ttest = xtest + xtest**2\n", "\n", "# noisy training samples (scatter) against the noiseless curve (line)\n", "plt.scatter(xtrue, tnoisy, c='r')\n", "plt.plot(xtrue, ttrue)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# step 1) generate polynomial features with PolynomialFeatures for a given max degree\n", "# step 2) fit a regression model using LinearRegression\n", "# step 3) apply our model on test data generated on [-2,1]\n", "\n", "from sklearn.preprocessing import PolynomialFeatures\n", "from sklearn.linear_model import LinearRegression\n", "\n", "# maximum degree of the polynomial features\n", "degree = 20\n", "\n", "poly = PolynomialFeatures(degree)\n", "# matrix encoding the augmented feature vectors (one row per training input)\n", "Xpoly = poly.fit_transform(xtrue.reshape(-1,1))\n", "\n", "reg = LinearRegression()\n", "reg.fit(Xpoly, tnoisy)\n", "\n", "# the transformer has already been fitted on the training inputs:\n", "# the test grid is only transformed, never used for fitting\n", "Xtestpoly = poly.transform(xtest.reshape(-1,1))\n", "\n", "prediction = reg.predict(Xtestpoly)\n", "\n", "# noisy training samples (scatter) against the fitted model on the test grid (line)\n", "plt.scatter(xtrue, tnoisy, c='r')\n", "plt.plot(xtest, prediction)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Self-contained version of the experiment: generate the data, then fit a polynomial model\n", "import numpy as np\n", "import matplotlib.pyplot as plt \n", "\n", "from sklearn.preprocessing import PolynomialFeatures\n", "from sklearn.linear_model import LinearRegression\n", "\n", "# training inputs with their noiseless (t) and noisy (tnoisy) targets, t = x + x^2\n", "xtrue = np.linspace(-2,1,10)\n", "t = xtrue+xtrue**2\n", "tnoisy = t+ np.random.normal(0,0.1, len(xtrue))\n", "\n", "# denser test grid on the same interval; the test targets are evaluated on xtest\n", "# so that ttest lines up one-to-one with the test inputs\n", "xtest = np.linspace(-2,1,100)\n", "ttest = xtest + xtest**2\n", "\n", "plt.scatter(xtrue, tnoisy, c='r')\n", "plt.plot(xtrue, t)\n", "plt.show()\n", "\n", "\n", "# step 1) generate polynomial features with PolynomialFeatures for a given max degree\n", "# step 2) fit a regression model using LinearRegression\n", "# step 3) apply our model on test data generated on [-2,1]\n", "\n", "# maximum degree of the polynomial features\n", "degree = 20\n", "\n", "poly = PolynomialFeatures(degree)\n", "# matrix encoding the augmented feature vectors (one row per training input)\n", "Xpoly = poly.fit_transform(xtrue.reshape(-1,1))\n", "\n", "reg = LinearRegression()\n", "reg.fit(Xpoly, tnoisy)\n", "\n", "# the transformer has already been fitted on the training inputs:\n", "# the test grid is only transformed, never used for fitting\n", "Xtestpoly = poly.transform(xtest.reshape(-1,1))\n", "\n", "prediction = reg.predict(Xtestpoly)\n", "\n", "plt.scatter(xtrue, tnoisy, c='r')\n", "plt.plot(xtest, prediction)\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Bias/variance demo: a cubic ground truth sampled at 10 points of [-2, 1]\n", "x = np.linspace(-2, 1, 10)\n", "\n", "# noiseless targets\n", "ttrue = 0.1 * x**3 - 0.1 * x**2 + x + 1\n", "\n", "# one noisy realisation of the targets (Gaussian noise, standard deviation 0.25)\n", "tnoisy = ttrue + np.random.normal(loc=0, scale=.25, size=len(x))\n", "\n", "# noisy samples (scatter) on top of the noiseless curve (line)\n", "plt.scatter(x, tnoisy, c='r')\n", "plt.plot(x, ttrue)\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Low Bias, High variance ((almost) no regularization)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.linear_model import Ridge\n", "from sklearn.linear_model import Lasso\n", "from sklearn.linear_model import LinearRegression\n", "import warnings\n", "# NOTE(review): this silences every warning for the rest of the session,\n", "# presumably to hide solver warnings of the ridge fits below - confirm\n", "warnings.filterwarnings('ignore')\n", "\n", "\n", "numXp = 20   # number of noisy training sets that are drawn and fitted\n", "degree = 10  # maximum degree of the polynomial features\n", "\n", "xtest = np.linspace(-2,1,100)\n", "\n", "# tnoisy is not redefined before this call: the points shown are the noisy\n", "# sample left behind by the previously executed cell\n", "plt.scatter(x, tnoisy, c='r')\n", "\n", "# The feature maps only depend on the inputs, which are the same for every\n", "# experiment: build them once, fitting the transformer on the training inputs\n", "# and only transforming the test grid.\n", "poly = PolynomialFeatures(degree)\n", "Xpoly = poly.fit_transform(x.reshape(-1,1))\n", "XpolyTest = poly.transform(xtest.reshape(-1,1))\n", "\n", "# one column of predictions on xtest per experiment\n", "predictionMat = np.zeros((len(xtest), numXp))\n", "\n", "for xp in np.arange(numXp):\n", "\n", "    # generate the subsets D^i: same inputs, fresh Gaussian noise on the targets\n", "    tnoisy = ttrue+ np.random.normal(0,.25, len(x))\n", "\n", "    # ridge regression with a very small penalty (alpha=1e-4)\n", "    reg = Ridge(alpha=1e-4,tol=1e-2)\n", "    reg.fit(Xpoly, tnoisy)\n", "\n", "    # represent the model\n", "    prediction = reg.predict(XpolyTest)\n", "    predictionMat[:,xp] = prediction\n", "\n", "    plt.plot(xtest, prediction, alpha=.1, c='b')\n", "\n", "# average of the fitted models over all the experiments\n", "familyAverage = np.mean(predictionMat, axis=1)\n", "\n", "plt.plot(xtest, familyAverage, '--',linewidth=2, c='black')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Low Variance, high bias (heavy regularization)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAVjElEQVR4nO3df3DcdZ3H8dd7kzRpmtpC09KWkkQocpzaUScq6hw4UrUqR9UZZtD14E6YyDjeiDKMnDnPO5neHecPxpvjRtdD5YY9mHGOQucOheKdx4iKpEi1pQhISZs20EBpoN0mTTaf++O722w2302+m/1mdz/J8zGz8/1+P9/vfr+fz37hlU8+38+m5pwTAMBfiVpXAABQGYIcADxHkAOA5whyAPAcQQ4AnmusxUXb29tdV1dXLS4NAN7atWvXS8651cXlNQnyrq4u9fX11eLSAOAtM+sPK2doBQA8R5ADgOcqDnIzO8fM/tfM9pnZXjP7fBwVAwBEE8cY+bikG5xzj5vZckm7zGync+7JGM4NAJhFxT1y59ygc+7x3PprkvZJOrvS8wIAool1jNzMuiS9VdKjIft6zKzPzPqGhobivCwA1L90WurqkhKJYJlOx3bq2ILczNok/aek651zrxbvd86lnHPdzrnu1aunTYMEgIUrnZZ6eqT+fsm5YNnTE1uYxxLkZtakIMTTzrl74jgnAMRiHnvCkfX2SpnM1LJMJiiPQcUPO83MJN0uaZ9z7luVVwkAYpLvCedDNN8TlqRksnr1OHCgvPIyxdEjf4+kP5P0PjN7Ivf6cAznBYDKzHNPOLKOjvLKy1Rxj9w593NJFkNdACBe89wTjmzbtqm/GUhSa2tQHgO+2Qlg4ZrnnnBkyaSUSkmdnZJZsEylYhveIcgBLFzbtgU930Ix9oTLkkxKzz8vTUwEyxjH6AlyAAvXPPeE60VN/owtAFRNMrnggrsYPXIA8BxBDgCeI8gBwHMEOQB4jiAHAM8R5ADgOYIcADxHkAOA5whyAPAcQQ4AniPIAcBzBDkAeI4gBwDPEeQA4DmCHAA8R5ADgOcIcgDwHEEOAJ4jyAHAcwQ5AHiOIAcAzxHkAOA5ghwAPEeQA4DnCHIA8BxBDgCeI8gBwHMEOQB4jiAHAM8R5ADgOYIcWGjSaamrS0okgmU6XesaYZ411roCAGKUTks9PVImE2z39wfbkpRM1q5emFf0yIGFpLd3MsTzMpmgHAsWQQ4sJAcOlFeOBYEgBxaSjo7yyucTY/VVE0uQm9n3zeyIme2J43wA5mjbNqm1dWpZa2tQXk35sfr+fsm5ybF6wnxexNUj/6GkLTGdC8BcJZNSKiV1dkpmwTKVqv6DTsbqqyqWWSvOuYfNrCuOcwGoUDJZ+xkqjNVXVdXGyM2sx8z6zKxvaGioWpcFUAv1NFa/CFQtyJ1zKedct3Oue/Xq1dW6LIBaqJex+kWCWSsA4lcvY/WLBN/sBDA/6mGsfpGIa/rhXZJ+KekCMxsws2viOC8AYHZxzVr5RBznAQCUjzFyIC58kxE1whg5EAf+6iBqiB45EAe+yYgaIsiBOPBNRtQQQQ7EgW8yooYIciAOfJMRNUSQA3Hgm4yoIWatAHHhm4yoEXrkAOA5ghwAPEeQA4DnCHIA8BxBDgCeI8gBwHMEOQB4jiAHAM8R5ADgOYIcADxHkAOA5whyAPAcQQ4AniPIAcBzBDkAeG7R/z1y5yaX+VfhdvExMy2L16Nsz1Yedf98vdcLP/qRdPPN0sCAtGGD9JWvSFdcUetaAaHa2qSmpnjP6WWQOydNTASvsPXCZfF68SvueoWtF2+bhV/bLPy95R4zV/lzz7dY6779HulLX5ZGTkpqlQaOSp//sjTSIH3s4zFeCIjH0qWLPMhffVU6cWLmIMgHdHFYlgrPfFlhiOXfX7hefM6w3njxe0vVr5zyqPvnYkH01G+5VRrJSloyWTaSDco/QpCj/kxMxH9Or4I83/POK+4B5/cV9sC
Lt4t74mHBO1MY58sL94eVFZ9jtu0o187vD6tPVAsivAsdfkGhj3oOvyBlq14bYFbz8f+gV0E+MiIND4fvSySCl1mwLGQ2/TWTwl56lA89ylBI8XnDzlGtoY1CtbhmrNatkQYHQso3SDH/+grEoTif4uBVkDc1Sa2twQdRGJ7F4+Hj49N7yWbTe+ZSeLgWDqUUKlVefEz+vGHvL77mTOPqpcqi/iAq15xC/b57pW9+Sxo8JK07W7rhi9LWj86tAnNxw41Sb680enKyrHlpUE6PHHVo0ffIR0el48eD9XwoF/ayC3vlefnhllJj3vmAL3xImhc2oyVsZkv+fcXnKrx+8TBP4TXCbmypcbTicf+Zjg17T7lmfO//PCR9+3ZpvEXSedKgpJtulwbbpPdtnvtFy3HhR6XPtUk/vEN6+UVp1VnSn18tXbhZero6VQDKsXx50CGNk1dB3tQUPPGVpgZoNhu8Mpmg7NSpoFeezQbL8fHguFOnxjU6elInT2a0bFm7pAZNTEgHDz6pV14Z1NjYqE6dGtX4+KjGxk5pbGxUK1as06ZNl+XOe0IPPPANZbNjymbHNTERvLLZcTk3rosv7lFn59tkJv3mN/fp0UfvlnMTci6riYmJ3PqEmpuX6bOfvet0u7773U/p2LFBOeckOTnncsc6vfvdn9Kll35GkvTss4/qzjs/n2t/cOzkunT99ffqzDPPliTdddeN2rv3p7krTE3j8857p6655juSpExmWDfffMm0Y/I+8Ymva9OmD0iSfvaz2/WTn3x7cufhw1J2XJK0VE36qt4njUv6wb26Zec39corh0LPeckln9aHPnS9JOmZZ36p22//TKlbrhtvvF+rVm2QJN155w3as2dn6HEbN12ka6998HSbvnbdppLn/OQnv1G6TQWWLn2dvvrVn5/evuWWD1W3TRsv0rXXpibb9LU/oU0LqE1x8irIDx6UHnvsoF5++YheffVlDQ+/rNdeO6rjx1/RyZPHtHHjxeruvlyS9Nxzv9APftCj0dHjGhk5rtHR4xobGz19ru9973mtW9epxkbpttv+Wo88sj30mm9/+2ZdeeVlSiSkY8dGtWPH35as35Ytl+rNbw6CvK/vSf3613eHHrd8+UpdeGGwPjEh7d//iAYHnw899qKL3qMLLgjWjx4d1h/+8GjJ63d0nNL69cH6yZP71d//m9DjzjqrXRs3BuvDw1kdOLC75DlXrBjW+ecH67/4xZAOHvxd6HGvU4PO1+9zFX1aQ83HSrYpkXjh9DmPHj1e8pySdM45Yzo7+Nmk0dEDJY9dt+6s0+ccHs7OeM7IbXrdGaePk6ShoadoE22quE3LlpXcPWfm5mPAZhbd3d2ur6+v7Pft2iV95CPn6sUX94fuTyav1w033CpJevzxh3XttZdM2W9mWrp0mZqbW3XbbY9o/fqNymalO+74O+3e/X9asqRFjY3NampqVkPDEjU0NKmj4026/PIvSJJGRkZ0zz3/ILNGNTQ0KpFoUkNDo8walEg0atOmD2jt2vMkSQcP7lV//24lEg1KJBIySyiRaJBZQg0NS/SWt2w5Xa+nnvq5xsZGJJkkUyKROL2+atU5WrPm9ZKkEyeO6dChfbLTY0c2Zb2jY5OampolSUeOPKcTJ45NaXteS8tyrV0bJHk2O66BgT3Tjslrb+/UsmUrJUnDw0c0PPzC5M6/3ya9FlwjIdMGrQjKz2jX4S//pcbHT4Xep+XLV+uMM9ZJkk6efE1DQ+H3U5LWr/8jNTYGUwuHhvp18mT40+6WljatWXPu6TYdOvRkyXOuWtVRuk0FEokGbdjwxtPbhw7tUzY7RptoU0Vt+uAHV6qjo+QhMzKzXc657mnlPgX5/v3SVVddrhdeOKjly1eprW2VWlvPVGvrGWppWanzznun3vCGS+SclMkc10sv7deSJW1qbm5TS0ublixpURB+wfkKZ7AUPkluaJgc8y4cc88fk3/Y2tAw+Z7C9+fP69zU8sL
tKA9Fi89bbLYpkvPuoZ3Srd+SxkYmy5papC98Udr8/ipVAvDLW98qtbfP7b2lgtyroZX+fumqq3ZImnyw2dQkNTYGodrUFCwTCWnNmjade+6b1dCgKa/8+/IKw7xwWRjSxcfONoe8cDZK8Qyb4mPDRAniSueTz8W0n/l/8X6pPSPdeqs0OCitWyd94TrpTwlxoJT8c744eRXkXV3SihWTveSwV35f8QyT/APPvOKZJ3n5WSXZ7OyzVgrPUyjs26UzTWcsJcovS1G+sj+vv3S9fav0H1unloX/BgxA0urVin2c3Ksgz2aDGSnF88HzM1cK54pLUwM/bDusN154XNgPjML3RemdF5eHTR+UZp4vXuoLRDOJMtcdQPU1N8d/Tu+CfGQkCNj8cEp+mR82KeyZS9N76sVhXHjcfHzjCgAKxf0Hs6SYgtzMtkj6tqQGSf/mnPvHOM5bbNWqYHwpbPx6si7RX4XHh62HLaOuh22XKouyL873AKid+fh/tuIgN7MGSbdJer+kAUmPmdkO51zpOThz1NYmtbRM/7sqxb1rAFhM4uiRv0PSs8655yTJzO6WtFVS7EHe3Dw/40sA4LM4+rBnSzpYsD2QK5vCzHrMrM/M+oaGhmK4LABAiifIw0Z8ps2pcM6lnHPdzrnu1atXx3BZAIAUT5APSDqnYHuDpMMxnBcAEEEcQf6YpPPN7PVmtkTSlZJ2xHBeAEAEFT/sdM6Nm9nnJD2gYPrh951zeyuuGQAgkljmkTvn7pd0fxznAgCUh5nXAOA5ghwAPEeQA4DnCHIA8BxBDgCeI8gBwHMEOQB4jiAHAM8R5ADgOYIcADxHkAOA5whyAPAcQQ4AniPIfZVOS11dwb843dUVbANYlGL5M7aosnRa6umRMplgu78/2JakZLJ29QJQE/TIfdTbOxnieZlMUA5g0SHIfXTgQHnlABY0gtxHHR3llQNY0AhyH23bJrW2Ti1rbQ3KASw6BLmPkkkplZI6OyWzYJlK8aATWKSYteKrZJLgBiCJHjkAeI8gBwDPEeQA4DmCHAA8R5ADgOcIcgDwHEEOAJ4jyAHAcwQ5AHiOIAcAzxHkAOA5ghwAPEeQA4DnCHIA8BxBDgCeI8gBwHMEOQB4jiAHAM9VFORmdoWZ7TWzCTPrjqtSAIDoKu2R75H0cUkPx1AXAMAcVPSPLzvn9kmSmcVTGwBA2ao2Rm5mPWbWZ2Z9Q0ND1bosACx4s/bIzewhSWtDdvU65+6LeiHnXEpSSpK6u7td5BoCAGY0a5A75zZXoyIAgLlh+iEAeK7S6YcfM7MBSe+S9N9m9kA81QIARFXprJXtkrbHVBcAwBwwtAIAniPIAcBzBDkAeI4gBwDPEeQA4DmCHAA8R5ADgOcIcgDwHEEOAJ4jyAHAcwQ5AHiOIAcAzxHk5Uqnpa4uKZEIlul0rWsEYJGr6K8fLjrptNTTI2UywXZ/f7AtSclk7eoFYFGjR16O3t7JEM/LZIJyAKgRgrwcBw6UVw4AVUCQl6Ojo7xyAKgCgrwc27ZJra1Ty1pbg3IAqBGCvBzJpJRKSZ2dklmwTKV40Amgppi1Uq5kkuAGUFfokQOA5whyAPAcQQ4AniPIAcBzBDkAeI4gBwDPEeQA4DmCHAA8R5ADgOcIcgDwHEEOAJ4jyAHAcwQ5AHiOIAcAzxHkAOA5ghwAPEeQA4DnCHIA8FxFQW5mXzezp8zst2a23cxWxlQvAEBElfbId0p6k3Nuk6SnJf1V5VUCAJSjoiB3zj3onBvPbf5K0obKqwQAKEecY+SflvTjGM8HAIigcbYDzOwhSWtDdvU65+7LHdMraVxSeobz9EjqkaSOjo45VRYAMN2sQe6c2zzTfjO7WtJlki51zrkZzpOSlJKk7u7ukscBAMoza5DPxMy2SPqSpEucc5l4qgQAKEelY+T/Imm5pJ1m9oSZfSeGOgEAylBRj9w5tzGuigAA5oZvdgKA5whyAPCcP0GeTktdXVIiESzTJWc
6AsCiUtEYedWk01JPj5TJTYzp7w+2JSmZrF29AKAO+NEj7+2dDPG8TCYoB4BFzo8gP3CgvHIAWET8CPJSX+nnq/4A4EmQb9smtbZOLWttDcoBYJHzI8iTSSmVkjo7JbNgmUrxoBMA5MusFSkIbYIbAKbxo0cOACiJIAcAzxHkAOA5ghwAPEeQA4DnbIZ/nW3+Lmo2JKl/jm9vl/RSjNWpJdpSfxZKOyTaUq8qaUunc251cWFNgrwSZtbnnOuudT3iQFvqz0Jph0Rb6tV8tIWhFQDwHEEOAJ7zMchTta5AjGhL/Vko7ZBoS72KvS3ejZEDAKbysUcOAChAkAOA5+o+yM3s62b2lJn91sy2m9nKEsdtMbPfm9mzZnZTlasZiZldYWZ7zWzCzEpOPzKz583sd2b2hJn1VbOOUZXRlrq+L2Z2ppntNLNncsszShxXt/dkts/YAv+c2/9bM3tbLeo5mwjteK+ZDefuwRNm9je1qGcUZvZ9MztiZntK7I/3njjn6vol6QOSGnPrt0i6JeSYBkl/kHSupCWSdkv641rXPaSeF0q6QNLPJHXPcNzzktprXd9K2+LDfZH0T5Juyq3fFPbfVz3fkyifsaQPS/qxJJN0kaRHa13vObbjvZL+q9Z1jdieiyW9TdKeEvtjvSd13yN3zj3onBvPbf5K0oaQw94h6Vnn3HPOuVOS7pa0tVp1jMo5t8859/ta1yMOEdviw33ZKumO3Podkj5au6rMSZTPeKukf3eBX0laaWbrql3RWfjw30pkzrmHJR2d4ZBY70ndB3mRTyv4KVbsbEkHC7YHcmW+cpIeNLNdZtZT68pUwIf7cpZzblCScss1JY6r13sS5TP24T5EreO7zGy3mf3YzN5YnarNi1jvSV38C0Fm9pCktSG7ep1z9+WO6ZU0LikddoqQsprMq4zSlgje45w7bGZrJO00s6dyP+GrKoa21MV9makdZZymLu5JiCifcV3ch1lEqePjCv7WyHEz+7CkeyWdP98Vmyex3pO6CHLn3OaZ9pvZ1ZIuk3Spyw0wFRmQdE7B9gZJh+OrYXSztSXiOQ7nlkfMbLuCXzurHhoxtKUu7stM7TCzF81snXNuMPer7ZES56iLexIiymdcF/dhFrPW0Tn3asH6/Wb2r2bW7pzz8Y9pxXpP6n5oxcy2SPqSpMudc5kShz0m6Xwze72ZLZF0paQd1apjnMxsmZktz68reNgb+uTbAz7clx2Srs6tXy1p2m8adX5PonzGOyRdlZspcZGk4fxwUh2ZtR1mttbMLLf+DgX59XLVaxqPeO9JrZ/uRnj6+6yCsaQncq/v5MrXS7q/6Cnw0wqefPfWut4l2vIxBT+JRyW9KOmB4rYoeGq/O/fa63NbfLgvklZJ+qmkZ3LLM327J2GfsaTrJF2XWzdJt+X2/04zzJiq83Z8Lvf571Yw8eHdta7zDG25S9KgpLHc/yfXzOc94Sv6AOC5uh9aAQDMjCAHAM8R5ADgOYIcADxHkAOA5whyAPAcQQ4Anvt/V/cv2aVQ5Q8AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.linear_model import Ridge\n", "from sklearn.linear_model import Lasso\n", "from sklearn.linear_model import LinearRegression\n", "import warnings\n", "# NOTE(review): this silences every warning for the rest of the session,\n", "# presumably to hide solver warnings of the ridge fits below - confirm\n", "warnings.filterwarnings('ignore')\n", "\n", "\n", "numXp = 20   # number of noisy training sets that are drawn and fitted\n", "degree = 10  # maximum degree of the polynomial features\n", "\n", "xtest = np.linspace(-2,1,100)\n", "\n", "# tnoisy is not redefined before this call: the points shown are the noisy\n", "# sample left behind by the previously executed cell\n", "plt.scatter(x, tnoisy, c='r')\n", "\n", "# The feature maps only depend on the inputs, which are the same for every\n", "# experiment: build them once, fitting the transformer on the training inputs\n", "# and only transforming the test grid.\n", "poly = PolynomialFeatures(degree)\n", "Xpoly = poly.fit_transform(x.reshape(-1,1))\n", "XpolyTest = poly.transform(xtest.reshape(-1,1))\n", "\n", "# one column of predictions on xtest per experiment\n", "predictionMat = np.zeros((len(xtest), numXp))\n", "\n", "for xp in np.arange(numXp):\n", "\n", "    # generate the subsets D^i: same inputs, fresh Gaussian noise on the targets\n", "    tnoisy = ttrue+ np.random.normal(0,.25, len(x))\n", "\n", "    # ridge regression with a very large penalty (alpha=1e8)\n", "    reg = Ridge(alpha=1e8,tol=1e-2)\n", "    reg.fit(Xpoly, tnoisy)\n", "\n", "    # represent the model\n", "    prediction = reg.predict(XpolyTest)\n", "    predictionMat[:,xp] = prediction\n", "\n", "    plt.plot(xtest, prediction, alpha=.1, c='b')\n", "\n", "# average of the fitted models over all the experiments\n", "familyAverage = np.mean(predictionMat, axis=1)\n", "\n", "plt.plot(xtest, familyAverage, '--',linewidth=2, c='black')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evolution of the MSE, Bias and Variance as a function of the model complexity" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "numXp = 20     # number of noisy training sets drawn for every degree\n", "maxDegree = 7  # the degrees 0, ..., maxDegree-1 are fitted\n", "\n", "xtest = np.linspace(-2,1,100)\n", "\n", "# noiseless targets on the test grid\n", "ttest = 0.1 * xtest**3 - 0.1* xtest**2 + xtest +1\n", "\n", "degrees = np.arange(maxDegree)\n", "\n", "# one entry per degree, indexed by the degree itself\n", "bias_squared = np.zeros((maxDegree, 1))\n", "variance = np.zeros((maxDegree, 1))\n", "MSE = np.zeros((maxDegree, 1))\n", "\n", "for degree in degrees:\n", "\n", "    # one column of predictions on xtest per experiment\n", "    predictionMat = np.zeros((len(xtest), numXp))\n", "\n", "    # The feature maps only depend on the inputs and on the degree: build them\n", "    # once per degree, fitting the transformer on the training inputs and only\n", "    # transforming the test grid.\n", "    poly = PolynomialFeatures(degree)\n", "    Xpoly = poly.fit_transform(x.reshape(-1,1))\n", "    XpolyTest = poly.transform(xtest.reshape(-1,1))\n", "\n", "    for xp in np.arange(numXp):\n", "\n", "        # generate the subsets D^i: same inputs, fresh Gaussian noise on the targets\n", "        tnoisy = ttrue+ np.random.normal(0,.25, len(x))\n", "\n", "        reg = LinearRegression()\n", "        reg.fit(Xpoly, tnoisy)\n", "\n", "        # represent the model\n", "        prediction = reg.predict(XpolyTest)\n", "        predictionMat[:,xp] = prediction\n", "\n", "    # squared gap between the average model and the noiseless targets\n", "    familyAverage = np.mean(predictionMat, axis=1)\n", "    bias_vec = (ttest.reshape(-1,1) - familyAverage.reshape(-1,1))**2\n", "\n", "    # spread of the individual models around their average; broadcasting the\n", "    # column of averages against the prediction matrix\n", "    variance_vec = np.mean((predictionMat - familyAverage.reshape(-1,1))**2, axis=1)\n", "\n", "    bias_squared[degree] = np.mean(bias_vec)\n", "    variance[degree] = np.mean(variance_vec)\n", "\n", "    # measured against the noiseless targets, so no noise term is included\n", "    MSE[degree] = bias_squared[degree] + variance[degree]\n", "\n", "# degree 0 is computed but left out of the figure; the curves are plotted\n", "# against the actual degrees so that the horizontal axis reads as the degree\n", "plt.plot(degrees[1:], bias_squared[1:], label='bias^2')\n", "plt.plot(degrees[1:], variance[1:], label = 'variance')\n", "plt.plot(degrees[1:], MSE[1:], label = 'MSE')\n", "plt.xlabel('polynomial degree')\n", "plt.legend()\n", "\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }