{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Linear Regression Tutorial\n", "by Marc Deisenroth" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The purpose of this notebook is to practice implementing some linear algebra (equations provided) and to explore some properties of linear regression." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import scipy.linalg\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We consider a linear regression problem of the form\n", "$$\n", "y = \\boldsymbol x^T\\boldsymbol\\theta + \\epsilon\\,,\\quad \\epsilon \\sim \\mathcal N(0, \\sigma^2)\n", "$$\n", "where $\\boldsymbol x\\in\\mathbb{R}^D$ is an input and $y\\in\\mathbb{R}$ is a noisy observation. The parameter vector $\\boldsymbol\\theta\\in\\mathbb{R}^D$ parametrizes the function.\n", "\n", "We assume we have a training set $(\\boldsymbol x_n, y_n)$, $n=1,\\ldots, N$. We collect the training inputs in $\\mathcal X = \\{\\boldsymbol x_1, \\ldots, \\boldsymbol x_N\\}$ and the corresponding training targets in $\\mathcal Y = \\{y_1, \\ldots, y_N\\}$.\n", "\n", "In this tutorial, we are interested in finding good parameters $\\boldsymbol\\theta$." 
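] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "To make the model concrete, the next cell is a minimal sketch (not part of the original exercises) that draws synthetic data from $y = \\boldsymbol x^T\\boldsymbol\\theta + \\epsilon$ with $\\epsilon \\sim \\mathcal N(0, \\sigma^2)$ and evaluates the Gaussian log-likelihood implied by this noise model, $\\log p(\\mathcal Y | \\mathcal X, \\boldsymbol\\theta) = \\sum_{n=1}^N \\log \\mathcal N(y_n | \\boldsymbol x_n^T\\boldsymbol\\theta, \\sigma^2)$. The values of `theta_true`, `sigma`, `N_demo`, `D_demo` and the random seed are hypothetical choices for illustration only. The cell uses its own variable names so that it does not overwrite the training set `X`, `y` defined below." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"# Hypothetical ground truth, used only in this illustration\n",
"rng = np.random.default_rng(0)\n",
"N_demo, D_demo = 50, 2\n",
"theta_true = np.array([[0.5], [-1.0]])  # D_demo x 1 parameter vector\n",
"sigma = 0.1                              # noise standard deviation\n",
"\n",
"# Each row of X_demo is one input x_n in R^D\n",
"X_demo = rng.uniform(-3.0, 3.0, size=(N_demo, D_demo))\n",
"\n",
"# y_n = x_n^T theta + eps_n with eps_n ~ N(0, sigma^2)\n",
"y_demo = X_demo @ theta_true + rng.normal(0.0, sigma, size=(N_demo, 1))\n",
"\n",
"def gaussian_log_likelihood(theta, X, y, sigma):\n",
"    # sum_n log N(y_n | x_n^T theta, sigma^2) for an N x D input matrix X\n",
"    residuals = y - X @ theta\n",
"    N = y.shape[0]\n",
"    return -0.5 * N * np.log(2.0 * np.pi * sigma**2) - 0.5 * np.sum(residuals**2) / sigma**2\n",
"\n",
"# The log-likelihood should be far higher at theta_true than at theta = 0\n",
"print(gaussian_log_likelihood(theta_true, X_demo, y_demo, sigma))\n",
"print(gaussian_log_likelihood(np.zeros((D_demo, 1)), X_demo, y_demo, sigma))"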
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEGCAYAAAB2EqL0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAP6klEQVR4nO3df6zdd13H8efLbvwQVNBeZGyTO+OdYcovPVkkGrPYVQcjKygkW4wOlSwYB3XRSOMSpkOSERPr+BGwymQYZBB+SLXF0Q1wEDPc6TJgW1lbl5GVLuyyCThBSOHtH/cM7+7ObT+9P873nHuej+Sk3x+ffvv6pu195fvjfL+pKiRJOpEf6DqAJGkyWBiSpCYWhiSpiYUhSWpiYUiSmpzSdYD1snnz5pqdne06hiRNlP3793+1qmaGrduwhTE7O0u/3+86hiRNlCRfWm6dp6QkSU0sDElSEwtDktTEwpCkDWbnvoPrsl0LQ5I2mGtvPrQu27UwJElNLAxJUhMLQ5LUZMN+cU+SpsHOfQeHXrOY3bHnMfPbt8xxxdazV/VnZaO+QKnX65Xf9JY0jWZ37OG+ay5c0e9Nsr+qesPWeUpKktTEwpAkNbEwJElNLAxJ2mC2b5lbl+1aGJK0waz2bqjlWBiSpCYWhiSpyVgURpLrkjyY5M5l1p+X5OtJ7hh83jDqjJI07cblm97vBt4GvOc4Yz5dVS8dTRxJ0lJjcYRRVbcAD3edQ5K0vLEojEYvSvK5JB9L8jPDBiS5LEk/SX9+fn7U+SRpQ5uUwrgdeHZVPR94K/BPwwZV1a6q6lVVb2ZmZqQBJWmjm4jCqKpvVNUjg+m9wKlJNnccS5KmykQURpJnJslg+lwWcj/UbSpJmi5jcZdUkvcB5wGbkxwBrgJOBaiqdwKvAH4/yTHgW8DFtVGfyy5JY2osCqOqLjnB+rexcNutJKkjE3FKSpLUPQtDktTEwpAkNbEwJElNLAxJUhMLQ5LUxMKQJDWxMCRJTSwMSVITC0OS1MTCkCQ1sTAkSU0sDElSEwtDktTEwpAkNbEwJElNLAxJUhMLQ5LUxMKQJDWxMCRJTSwMSVITC0OS1MTCkCQ1sTAkSU0sDElSk7EojCTXJXkwyZ3LrE+StyQ5nOTzSX5u1BkladqNRWEA7wYuOM76FwNzg89lwDtGkEmStMhYFEZV3QI8fJwh24D31IJbgaclOW006SRJMCaF0eB04P5F80cGyx4jyWVJ+kn68/PzIwsnSdNgUgojQ5bV4xZU7aqqXlX1ZmZmRhBLkqbHpBTGEeDMRfNnAEc7yiJJU2lSCmM38NuDu6V+Afh6VT3QdShJmiandB0AIMn7gPOAzUmOAFcBpwJU1TuBvcBLgMPAN4Hf6SapJE2vsSiMqrrkBOsL+IMRxZEkDTEpp6QkSR2zMCRJTSwMSVITC0OS1MTCkCQ1sTAkSU0sDElSEwtDktTEwpAkNbEwJElNLAxJUhMLQ5LUxMKQtCo79x3sOoJGxMKQtCrX3nyo6wgaEQtDktTEwpAkNbEwJElNxuKNe5Imw859B4des5jdsecx89u3zHHF1rNHFUsjkoW3n248vV6v+v1+1zGkDW92xx7uu+bCrmNojSTZX1W9Yes8JSVJamJhSJKaWBiSpCYWhqRV2b5lrusIGhELQ9KqeDfU9LAwJElNLAxJUpOxKIwkFyS5J8nhJDuGrH9Vkvkkdww+r+4ipyRNs86/6Z1kE/B2YCtwBLgtye6qunvJ0PdX1eUjDyhJAsbjCONc4HBV3VtV3wFuALZ1nEmStMQ4FMbpwP2L5o8Mli31G0k+n+SDSc4ctqEklyXpJ+nPz8+vR1ZJmlrjUBgZsmzpA67+GZitqucBNwHXD9t
QVe2qql5V9WZmZtY4piRNt3EojCPA4iOGM4CjiwdU1UNV9e3B7N8CPz+ibJKkgXEojNuAuSRnJXkCcDGwe/GAJKctmr0IODDCfJIkxuAuqao6luRy4EZgE3BdVd2V5GqgX1W7gdcluQg4BjwMvKqzwJI0pXwfhiTp+3wfhiRp1SwMSVITC0OS1MTCkCQ1sTAkSU0sDElSEwtDktTEwpAkNbEwJElNLAxJUhMLQ5LU5ISFkeSmJM8fRRhJ0vhqOcL4E2Bnkr9f8phxSdIUOWFhVNXtVfUrwL8A/5rkqiRPXv9okqRx0nQNI0mAe4B3AK8FDiX5rfUMJm1kO/cd7DqCdNJarmF8BvgysBM4nYWXF50HnJtk13qGkzaqa28+1HUE6aS1vHHvNcBd9fg3Lb02ia9KlaQpccLCqKo7j7P6wjXMIkkaY6v6HkZV3btWQSRJ463llJSkVdi57+DQaxazO/Y8Zn77ljmu2Hr2qGJJJy2PvzSxMfR6ver3+13HkIaa3bGH+67xjK7GT5L9VdUbts5Hg0iSmlgYkqQmFoYkqYmFIXVg+5a5riNIJ20sCiPJBUnuSXI4yY4h65+Y5P2D9Z9NMjv6lNLa8W4oTaLOCyPJJuDtwIuBc4BLkpyzZNjvAf9VVT/FwiNK3jzalJKkzgsDOBc4XFX3VtV3gBuAbUvGbAOuH0x/ENgyeCCiJGlExqEwTgfuXzR/ZLBs6JiqOgZ8HfixpRtKclmSfpL+/Pz8OsWVpOk0DoUx7Ehh6bcJW8ZQVbuqqldVvZmZmTUJJ0laMA6FcQQ4c9H8GcDR5cYkOQX4EeDhkaSTJAHjURi3AXNJzkryBOBiYPeSMbuBSwfTrwA+MeRx65KkddT5wwer6liSy4EbgU3AdVV1V5KrgX5V7QbeBfxDksMsHFlc3F1iSZpOnRcGQFXtBfYuWfaGRdP/C7xy1LkkSf9vHE5JSZImgIUhSWpiYUiSmlgYkqQmFoYkqYmFIUlqYmFIkppYGJKkJhaGJKmJhSFJamJhSJKaWBiSpCYWhiSpiYUhSWpiYUiSmlgYkqQmFoYkqYmFIUlqYmFIkppYGJKkJhaGJKmJhSFJamJhSJKaWBiSpCYWhiSpiYUhSWrSaWEk+dEk+5IcGvz69GXGfTfJHYPP7lHnlCR1f4SxA7i5quaAmwfzw3yrql4w+Fw0uniSpEd1XRjbgOsH09cDL+swiyTpOLoujB+vqgcABr8+Y5lxT0rST3JrkmVLJcllg3H9+fn59cgrSVPrlPX+A5LcBDxzyKorT2IzP1FVR5P8JPCJJF+oqv9cOqiqdgG7AHq9Xq0osCRpqHUvjKo6f7l1Sb6S5LSqeiDJacCDy2zj6ODXe5N8Cngh8LjCkCStn65PSe0GLh1MXwp8dOmAJE9P8sTB9GbgF4G7R5ZQkgR0XxjXAFuTHAK2DuZJ0kvyd4MxzwH6ST4HfBK4pqosDEkasXU/JXU8VfUQsGXI8j7w6sH0vwPPHXE0SdISXR9hSJImhIUhSWpiYWhi7Nx3sOsI0lSzMDQxrr35UNcRpKlmYUiSmlgYkqQmFoYkqUmn38OQlrNz38Gh1yxmd+x5zPz2LXNcsfXsUcWSplqqNuYz+nq9XvX7/a5jaA3N7tjDfddc2HUMaUNLsr+qesPWeUpKktTEwpAkNbEwJElNLAxNjO1b5rqOIE01C0MTw7uhpG5ZGJKkJhaGJKmJhSFJamJhSJKaWBiSpCYWhiSpiYUhSWpiYUiSmlgYkqQmFoYkqYmFIUlq0mlhJHllkruSfC/J0Bd2DMZdkOSeJIeT7BhlRknSgq6PMO4Efh24ZbkBSTYBbwdeDJwDXJLknNHEkyQ9qtN3elfVAYAkxxt2LnC4qu4djL0B2Abcve4BJUnf1/URRovTgfsXzR8ZLHucJJcl6Sfpz8/PjyScJE2LdT/CSHIT8Mwhq66sqo+2bGLIsho2sKp2AbsAer3e0DGSpJVZ98KoqvNXuYkjwJmL5s8
Ajq5ym5KkkzQJp6RuA+aSnJXkCcDFwO6OM0nS1On6ttqXJzkCvAjYk+TGwfJnJdkLUFXHgMuBG4EDwAeq6q6uMkvStOr6LqmPAB8Zsvwo8JJF83uBvSOMJklaYhJOSUmSxoCFIUlqYmFIkppYGMvYue9g1xEkaaxYGMu49uZDXUeQpLFiYUiSmlgYkqQmFoYkqUmnX9wbFzv3HRx6zWJ2x57HzG/fMscVW88eVSxJGiup2pgPde31etXv91f8+2d37OG+ay5cw0SSNP6S7K+qoW9A9ZSUJKmJhSFJamJhSJKaWBjL2L5lrusIkjRWLIxleDeUJD2WhSFJamJhSJKaWBiSpCYb9ot7SeaBL61iE5uBr65RnC5tlP0A92VcbZR92Sj7Aavbl2dX1cywFRu2MFYrSX+5bztOko2yH+C+jKuNsi8bZT9g/fbFU1KSpCYWhiSpiYWxvF1dB1gjG2U/wH0ZVxtlXzbKfsA67YvXMCRJTTzCkCQ1sTAkSU0sjGUkeWOSzye5I8nHkzyr60wrleQvk3xxsD8fSfK0rjOtVJJXJrkryfeSTNwtkEkuSHJPksNJdnSdZzWSXJfkwSR3dp1lNZKcmeSTSQ4M/m1t7zrTSiV5UpL/SPK5wb78+Zpu32sYwyX54ar6xmD6dcA5VfWajmOtSJJfBT5RVceSvBmgql7fcawVSfIc4HvA3wB/XFUrf63iiCXZBBwEtgJHgNuAS6rq7k6DrVCSXwYeAd5TVT/bdZ6VSnIacFpV3Z7kh4D9wMsm8e8lSYCnVNUjSU4FPgNsr6pb12L7HmEs49GyGHgKMLHNWlUfr6pjg9lbgTO6zLMaVXWgqu7pOscKnQscrqp7q+o7wA3Ato4zrVhV3QI83HWO1aqqB6rq9sH0fwMHgNO7TbUyteCRweypg8+a/eyyMI4jyZuS3A/8JvCGrvOskd8FPtZ1iCl1OnD/ovkjTOgPpo0qySzwQuCz3SZZuSSbktwBPAjsq6o125epLowkNyW5c8hnG0BVXVlVZwLvBS7vNu3xnWhfBmOuBI6xsD9jq2VfJlSGLJvYI9eNJslTgQ8Bf7jkDMNEqarvVtULWDiTcG6SNTtdeMpabWgSVdX5jUP/EdgDXLWOcVblRPuS5FLgpcCWGvMLVyfx9zJpjgBnLpo/AzjaURYtMjjf/yHgvVX14a7zrIWq+lqSTwEXAGtyY8JUH2EcT5LF72i9CPhiV1lWK8kFwOuBi6rqm13nmWK3AXNJzkryBOBiYHfHmabe4ELxu4ADVfVXXedZjSQzj94FmeTJwPms4c8u75JaRpIPAT/Nwh05XwJeU1Vf7jbVyiQ5DDwReGiw6NYJvuPr5cBbgRnga8AdVfVr3aZql+QlwF8Dm4DrqupNHUdasSTvA85j4VHaXwGuqqp3dRpqBZL8EvBp4Ass/H8H+NOq2ttdqpVJ8jzgehb+ff0A8IGqunrNtm9hSJJaeEpKktTEwpAkNbEwJElNLAxJUhMLQ5LUxMKQJDWxMCRJTSwMaYQG713YOpj+iyRv6TqT1GqqnyUldeAq4Ookz2DhqagXdZxHauY3vaURS/JvwFOB8wbvX5AmgqekpBFK8lzgNODbloUmjYUhjcjgVaDvZeEte/+TZGIemiiBhSGNRJIfBD4M/FFVHQDeCPxZp6Gkk+Q1DElSE48wJElNLAxJUhMLQ5LUxMKQJDWxMCRJTSwMSVITC0OS1OT/AE4cDzAccAchAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Define training set\n", "X = np.array([-3, -1, 0, 1, 3]).reshape(-1,1) # 5x1 vector, N=5, D=1\n", "y = np.array([-1.2, -0.7, 0.14, 0.67, 1.67]).reshape(-1,1) # 5x1 vector\n", "\n", "# Plot the training set\n", "plt.figure()\n", "plt.plot(X, y, '+', markersize=10)\n", "plt.xlabel(\"$x$\")\n", "plt.ylabel(\"$y$\");" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Maximum Likelihood\n", "We will start with maximum likelihood estimation of the parameters $\\boldsymbol\\theta$. In maximum likelihood estimation, we find the parameters $\\boldsymbol\\theta^{\\mathrm{ML}}$ that maximize the likelihood\n", "$$\n", "p(\\mathcal Y | \\mathcal X, \\boldsymbol\\theta) = \\prod_{n=1}^N p(y_n | \\boldsymbol x_n, \\boldsymbol\\theta)\\,.\n", "$$\n", "From the lecture we know that the maximum likelihood estimator is given by\n", "$$\n", "\\boldsymbol\\theta^{\\text{ML}} = (\\boldsymbol X^T\\boldsymbol X)^{-1}\\boldsymbol X^T\\boldsymbol y\\in\\mathbb{R}^D\\,,\n", "$$\n", "where\n", "$$\n", "\\boldsymbol X = [\\boldsymbol x_1, \\ldots, \\boldsymbol x_N]^T\\in\\mathbb{R}^{N\\times D}\\,,\\quad \\boldsymbol y = [y_1, \\ldots, y_N]^T \\in\\mathbb{R}^N\\,.\n", "$$" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let us compute the maximum likelihood estimate for the training set defined above." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "## EDIT THIS FUNCTION\n", "def max_lik_estimate(X, y):\n", "\n", "    # X: N x D matrix of training inputs\n", "    # y: N x 1 vector of training targets/observations\n", "    # returns: maximum likelihood parameters (D x 1)\n", "\n", "    N, D = X.shape\n", "    # solve (X^T X) theta = X^T y instead of forming the inverse explicitly\n", "    theta_ml = np.linalg.solve(X.T @ X, X.T @ y) ## <-- SOLUTION\n", "    return theta_ml" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# get maximum likelihood estimate\n", "theta_ml = max_lik_estimate(X, y)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now, make a prediction using the maximum likelihood estimate that we just found." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "## EDIT THIS FUNCTION\n", "def predict_with_estimate(Xtest, theta):\n", "\n", "    # Xtest: K x D matrix of test inputs\n", "    # theta: D x 1 vector of parameters\n", "    # returns: prediction of f(Xtest); K x 1 vector\n", "\n", "    prediction = Xtest @ theta ## <-- SOLUTION\n", "\n", "    return prediction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's see whether we got something useful:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# define a test set\n", "Xtest = np.linspace(-5,5,100).reshape(-1,1) # 100 x 1 vector of test inputs\n", "\n", "# predict the function values at the test points using the maximum likelihood estimator\n", "ml_prediction = predict_with_estimate(Xtest, theta_ml)\n", "\n", "# plot\n", "plt.figure()\n", "plt.plot(X, y, '+', markersize=10)\n", "plt.plot(Xtest, ml_prediction)\n", "plt.xlabel(\"$x$\")\n", "plt.ylabel(\"$y$\");" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Questions\n", "1. Does the solution above look reasonable?\n", "2. Play around with different values of $\\theta$. How do the corresponding functions change?\n", "3. Modify the training targets $\\mathcal Y$ and re-run your computation. What changes?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us now look at a different training set, in which we add 2.0 to every $y$-value, and compute the maximum likelihood estimate." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYQAAAEGCAYAAABlxeIAAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAP1UlEQVR4nO3dfYxldX3H8feny1ap2JJ0p3ELq9OkS9NqBezNqrFpiAstioE21QTT+NQ2ROPDSGzqRhOoWBNME7coiXQbqNBQqxFqqAuNI2qVP0DvbhYEF5eNwbBKyggF3GppVr/9Yw5xdvbOPszOOWfu3Pcruck59/zm3s/J7M4n5zlVhSRJv9B3AEnS6mAhSJIAC0GS1LAQJEmAhSBJapzSd4Dl2rBhQ01PT/cdQ5LGyq5du35YVVOjlo1tIUxPTzMcDvuOIUljJcn3llrmLiNJEmAhSJIaFoIkCbAQJGnsbJ/d18rnWgiSNGauufOhVj7XQpAkARaCJKlhIUiSgDG+ME2SJsH22X0jjxlMb9t52PzM1s1cfsFZJ/VdGdcH5AwGg/JKZUmTaHrbTh6++qJl/WySXVU1GLXMXUaSJMBCkCQ1LARJEmAhSNLYmdm6uZXPtRAkacyc7NlES7EQJEmAhSBJarReCEmem+QbSe5N8kCSD40Y89Ykc0n2NK+/bDuXJOlwXVyp/Azw6qo6mGQ9cFeSO6rq7kXjPlNV7+ogjyRphNYLoeYvhT7YzK5vXuN5ebQkrWGdHENIsi7JHuAxYLaq7hkx7E+T3Jfkc0k2LfE5lyUZJhnOzc21mlmSJk0nhVBVP62qc4AzgS1JXrJoyL8D01X1UuBLwI1LfM6OqhpU1WBqaqrd0JI0YTo9y6iqngS+Cly46P3Hq+qZZvYfgd/rMpckqZuzjKaSnN5MnwqcDzy4aMzGBbMXA3vbziVJOlwXZxltBG5Mso75AvpsVX0hyVXAsKpuA96T5GLgEPAE8NYOckmSFvB5CJI0QXwegiTpmCwESRJgIUiSGhaCJAmwECRJDQtBkgRYCJKkhoUgSQIsBElSw0KQJAEWgiSpYSFIkgALQZLUsBAkSYCFIElqWAiSJMBCkCQ1LARJEmAhSJIaFoIkCbAQJEkNC0GSBFgIkqSGhSBJAiwESVKj9UJI8twk30hyb5IHknxoxJjnJPlMkv1J7kky3XYuSdLhuthCeAZ4dVWdDZwDXJjkFYvG/AXw31X1m8B24KMd5JIkLdB6IdS8g83s+uZVi4ZdAtzYTH8O2JokbWeTJP1cJ8cQkqxLsgd4DJitqnsWDTkDeASgqg4BTwG/OuJzLksyTDKcm5trO7YkTZROCqGqflpV5wBnAluSvGTRkFFbA4u3IqiqHVU1qKrB1NRUG1ElaWJ1epZRVT0JfBW4cNGiA8AmgCSnAL8CPNFlNkmadF2cZTSV5PRm+lTgfODBRcNuA97STL8e+HJVHbGFIElqzykdfMdG4MYk65gvoM9W1ReSXAUMq+o24Hrgn5PsZ37L4NIOckmSFmi9EKrqPuDcEe9fsWD6f4E3tJ1FkrQ0r1SWJAEWgiSpYSFIkgALQZLUsBAkSYCFIElqWAiSJMBCkCQ1LARJEmAhSJIaFoIkCbAQJEkNC0HSkrbP7us7gjpkIUha0jV3PtR3BHXIQpAkARaCJKlhIUiSgG4eoSlpDGyf3TfymMH0tp2Hzc9s3czlF5zVVSx1KOP6LPvBYFDD4bDvGNKaNr1tJw9ffVHfMbSCkuyqqsGoZe4ykiQBFoIkqWEhSJIAC0HSUcxs3dx3BHXIQpC0JM8mmiwWgiQJsBAkSY3WCyHJpiRfSbI3yQNJZkaMOS/JU0n2NK8r2s4lSTpcF1cqHwLeV1W7kzwf2JVktqq+vWjc16vqdR3kkSSN0PoWQlU9WlW7m+kfAXuBM9r+XknSien0GEKSaeBc4J4Ri1+Z5N4kdyR58RI/f1mSYZLh3Nxci0klafJ
0VghJTgNuAd5bVU8vWrwbeFFVnQ18Avj8qM+oqh1VNaiqwdTUVLuBJWnCdFIISdYzXwY3V9Wti5dX1dNVdbCZvh1Yn2RDF9kkSfO6OMsowPXA3qr62BJjXtCMI8mWJtfjbWeTJP1cF2cZvQp4E/CtJHua9z4AvBCgqq4DXg+8I8kh4CfApTWu9+WWpDHVeiFU1V1AjjHmWuDatrNIkpbmlcqSJMBCkCQ1LARJEmAhSJIaFoIkCbAQJEkNC0GSBFgIkqSGhSBJAiwESVLDQpAkAcdRCEm+lOTsLsJIkvpzPFsIfw1sT/JPSTa2HUiS1I9jFkJV7a6qVwNfAP4jyZVJTm0/miSpS8d1DKF5eM13gE8C7wYeSvKmNoNJ42z77L6+I0gn7HiOIdwFfB/YDpwBvBU4D9iSZEeb4aRxdc2dD/UdQTphx/OAnLcDD4x4gtm7k+xtIZMkqQfHLISquv8oiy9awSySpB6d1HUIVfXdlQoiSepX689Ulta67bP7Rh4zmN6287D5ma2bufyCs7qKJZ2wHHloYDwMBoMaDod9x5BGmt62k4evdo+qVp8ku6pqMGqZt66QJAEWgiSpYSFIkgALQWrFzNbNfUeQTljrhZBkU5KvJNmb5IEkMyPGJMnHk+xPcl+Sl7WdS2qTZxNpHHVx2ukh4H1VtTvJ84FdSWar6tsLxrwG2Ny8Xs78PZNe3kE2SVKj9S2Eqnq0qnY30z8C9jJ/T6SFLgFuqnl3A6d7q21J6lanxxCSTAPnAvcsWnQG8MiC+QMcWRokuSzJMMlwbm6urZiSNJE6K4QkpwG3AO+tqqcXLx7xI0dcMVdVO6pqUFWDqampNmJK0sTqpBCSrGe+DG6uqltHDDkAbFowfybwgy6ySZLmdXGWUYDrgb1V9bElht0GvLk52+gVwFNV9Wjb2SRJP9fFWUavAt4EfCvJnua9DwAvBKiq64DbgdcC+4EfA2/rIJckaYHWC6Gq7mL0MYKFYwp4Z9tZJElL80plSRJgIUiSGhaCJAmwECRJDQtBkgRYCJKkhoUgSQIsBElSw0KQJAEWgiSpYSFIkgALQZLUsBAkSYCFIElqWAiSJMBCkCQ1LARJEmAhSJIaFoIkCbAQJEkNC0GSBFgIkqSGhSBJAiwESVLDQpAkARaCJKnReiEkuSHJY0nuX2L5eUmeSrKneV3RdiZJ0pFO6eA7PgVcC9x0lDFfr6rXdZBFkrSE1rcQquprwBNtf48k6eSslmMIr0xyb5I7krx4qUFJLksyTDKcm5vrMp8krXmroRB2Ay+qqrOBTwCfX2pgVe2oqkFVDaampjoLKEmToPdCqKqnq+pgM307sD7Jhp5jSdLE6b0QkrwgSZrpLcxnerzfVJI0eVo/yyjJp4HzgA1JDgBXAusBquo64PXAO5IcAn4CXFpV1XYuSdLhWi+EqnrjMZZfy/xpqZKkHvW+y0iStDpYCJIkwELQKrJ9dl/fEaSJZiFo1bjmzof6jiBNNAtBkgRYCJKkhoUgSQK6uf21dITts/tGHjOY3rbzsPmZrZu5/IKzuoolTbSM60XBg8GghsNh3zG0gqa37eThqy/qO4a0piXZVVWDUcvcZSRJAiwESVLDQpAkARaCVpGZrZv7jiBNNAtBq4ZnE0n9shAkSYCFIElqWAiSJMBCkCQ1LARJEmAhSJIaFoIkCbAQJEkNC0GSBFgIkqSGhSBJAjoohCQ3JHksyf1LLE+SjyfZn+S+JC9rO5Mk6UhdbCF8CrjwKMtfA2xuXpcBn+wgkyRpkdYLoaq+BjxxlCGXADfVvLuB05NsbDuXJOlwq+EYwhnAIwvmDzTvHSHJZUmGSYZzc3OdhJOkSbEaCiEj3qtRA6tqR1UNqmowNTXVcixJmiyroRAOAJsWzJ8J/KCnLJI0sVZDIdwGvLk52+gVwFNV9WjfoSRp0pzS9hck+TRwHrAhyQHgSmA9QFVdB9wOvBbYD/wYeFvbmSRJR2q9EKrqjcdYXsA7284
hSTq61bDLSJK0ClgIkiTAQpAkNSayELbP7us7giStOhNZCNfc+VDfESRp1ZnIQpAkHclCkCQBFoIkqdH6hWl92z67b+Qxg+ltOw+bn9m6mcsvOKurWJK06mT+QuHxMxgMajgcLutnp7ft5OGrL1rhRJK0+iXZVVWDUcvcZSRJAiwESVLDQpAkARNaCDNbN/cdQZJWnYksBM8mkqQjTWQhSJKOZCFIkgALQZLUGNsL05LMAd9b5o9vAH64gnH65LqsTmtlXdbKeoDr8qwXVdXUqAVjWwgnI8lwqSv1xo3rsjqtlXVZK+sBrsvxcJeRJAmwECRJjUkthB19B1hBrsvqtFbWZa2sB7guxzSRxxAkSUea1C0ESdIiFoIkCZjgQkjy4ST3JdmT5ItJfr3vTMuV5O+SPNisz78lOb3vTMuV5A1JHkjysyRjd4pgkguTfCfJ/iTb+s6zXEluSPJYkvv7znKykmxK8pUke5t/WzN9Z1qOJM9N8o0k9zbr8aEV/45JPYaQ5Jer6ulm+j3A71TV23uOtSxJ/hD4clUdSvJRgKp6f8+xliXJbwM/A/4B+KuqWt5j8XqQZB2wD7gAOAB8E3hjVX2712DLkOQPgIPATVX1kr7znIwkG4GNVbU7yfOBXcAfj9vvJUmA51XVwSTrgbuAmaq6e6W+Y2K3EJ4tg8bzgLFtxqr6YlUdambvBs7sM8/JqKq9VfWdvnMs0xZgf1V9t6r+D/hX4JKeMy1LVX0NeKLvHCuhqh6tqt3N9I+AvcAZ/aY6cTXvYDO7vnmt6N+tiS0EgCQfSfII8GfAFX3nWSF/DtzRd4gJdQbwyIL5A4zhH561LMk0cC5wT79JlifJuiR7gMeA2apa0fVY04WQ5EtJ7h/xugSgqj5YVZuAm4F39Zv26I61Ls2YDwKHmF+fVet41mVMZcR7Y7vludYkOQ24BXjvoj0EY6OqflpV5zC/F2BLkhXdnXfKSn7YalNV5x/n0H8BdgJXthjnpBxrXZK8BXgdsLVW+YGhE/i9jJsDwKYF82cCP+gpixZo9rnfAtxcVbf2nedkVdWTSb4KXAis2IH/Nb2FcDRJFj5H82Lgwb6ynKwkFwLvBy6uqh/3nWeCfRPYnOQ3kvwicClwW8+ZJl5zMPZ6YG9VfazvPMuVZOrZMwiTnAqczwr/3Zrks4xuAX6L+TNavge8vaq+32+q5UmyH3gO8Hjz1t1jfMbUnwCfAKaAJ4E9VfVH/aY6fkleC/w9sA64oao+0nOkZUnyaeA85m+z/F/AlVV1fa+hlinJ7wNfB77F/P93gA9U1e39pTpxSV4K3Mj8v61fAD5bVVet6HdMaiFIkg43sbuMJEmHsxAkSYCFIElqWAiSJMBCkCQ1LARJEmAhSJIaFoK0Qpp77l/QTP9tko/3nUk6EWv6XkZSx64Erkrya8zfUfPinvNIJ8QrlaUVlOQ/gdOA85p770tjw11G0gpJ8rvARuAZy0DjyEKQVkDzmMabmX9C2v8kGZsb8knPshCkk5Tkl4BbgfdV1V7gw8Df9BpKWgaPIUiSALcQJEkNC0GSBFgIkqSGhSBJAiwESVLDQpAkARaCJKnx/8YDRQrCx2s4AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ynew = y + 2.0\n", "\n", "plt.figure()\n", "plt.plot(X, ynew, '+', markersize=10)\n", "plt.xlabel(\"$x$\")\n", "plt.ylabel(\"$y$\");" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0.499]]\n" ] } ], "source": [ "# get maximum likelihood estimate\n", "theta_ml = max_lik_estimate(X, ynew)\n", "print(theta_ml)\n", "\n", "# define a test set\n", "Xtest = np.linspace(-5,5,100).reshape(-1,1) # 100 x 1 vector of test inputs\n", "\n", "# predict the function values at the test points using the maximum likelihood estimator\n", "ml_prediction = predict_with_estimate(Xtest, theta_ml)\n", "\n", "# plot\n", "plt.figure()\n", "plt.plot(X, ynew, '+', markersize=10)\n", "plt.plot(Xtest, ml_prediction)\n", "plt.xlabel(\"$x$\")\n", "plt.ylabel(\"$y$\");" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Question:\n", "1. This maximum likelihood estimate doesn't look good: the orange line lies well below the observations, even though we only shifted them up by 2. Why is this the case?\n", "2. How can we fix this problem?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us now define a linear regression model that is slightly more flexible:\n", "$$\n", "y = \\theta_0 + \\boldsymbol x^T \\boldsymbol\\theta_1 + \\epsilon\\,,\\quad \\epsilon\\sim\\mathcal N(0,\\sigma^2)\n", "$$\n", "Here, we added an offset (bias) parameter $\\theta_0$ to our original model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Question:\n", "1. What is the effect of this bias parameter, i.e., what additional flexibility does it offer?" 
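] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick numerical sketch of this question, the next cell refits the shifted data from above with and without an offset. The targets are `y + 2.0` for the training set defined earlier, written out explicitly so that the cell is self-contained. It uses `np.linalg.lstsq` rather than this notebook's `max_lik_estimate`, and its variable names (`x_demo`, `y_shifted`, ...) are hypothetical choices that leave `X`, `y` and `ynew` untouched. Because the training inputs sum to zero, both fits should give the same slope (about 0.499), while the offset model additionally picks up $\\theta_0 \\approx 2.116$, the mean of the shifted targets. Without $\\theta_0$, every candidate line is forced through the origin." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"# Training inputs from above and the shifted targets ynew = y + 2.0\n",
"x_demo = np.array([-3.0, -1.0, 0.0, 1.0, 3.0]).reshape(-1, 1)\n",
"y_shifted = np.array([0.8, 1.3, 2.14, 2.67, 3.67]).reshape(-1, 1)\n",
"\n",
"# Without offset: y = x * theta_1, so f(0) = 0 for every theta_1\n",
"theta_no_offset, *_ = np.linalg.lstsq(x_demo, y_shifted, rcond=None)\n",
"\n",
"# With offset: y = theta_0 + x * theta_1, using a leading column of ones\n",
"x_demo_aug = np.hstack([np.ones_like(x_demo), x_demo])\n",
"theta_offset, *_ = np.linalg.lstsq(x_demo_aug, y_shifted, rcond=None)\n",
"\n",
"print(theta_no_offset.ravel())  # approximately [0.499]\n",
"print(theta_offset.ravel())     # approximately [2.116, 0.499]\n",
"\n",
"# Sum of squared residuals: the offset model should fit the shifted data far better\n",
"print(np.sum((y_shifted - x_demo @ theta_no_offset)**2))\n",
"print(np.sum((y_shifted - x_demo_aug @ theta_offset)**2))"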
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we now define the inputs to be the augmented vector $\\boldsymbol x_{\\text{aug}} = \\begin{bmatrix}1\\\\\\boldsymbol x\\end{bmatrix}$, we can write the new linear regression model as \n", "$$\n", "y = \\boldsymbol x_{\\text{aug}}^T\\boldsymbol\\theta_{\\text{aug}} + \\epsilon\\,,\\quad \\boldsymbol\\theta_{\\text{aug}} = \\begin{bmatrix}\n", "\\theta_0\\\\\n", "\\boldsymbol\\theta_1\n", "\\end{bmatrix}\\,.\n", "$$" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "N, D = X.shape\n", "X_aug = np.hstack([np.ones((N,1)), X]) # augmented training inputs of size N x (D+1)\n", "theta_aug = np.zeros((D+1, 1)) # new theta vector of size (D+1) x 1" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let us now compute the maximum likelihood estimator for this setting, using the shifted targets `ynew`.\n", "_Hint:_ If possible, re-use code that you have already written." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "## EDIT THIS FUNCTION\n", "def max_lik_estimate_aug(X_aug, y):\n", "\n", "    theta_aug_ml = max_lik_estimate(X_aug, y) ## <-- SOLUTION\n", "\n", "    return theta_aug_ml" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# fit the augmented model to the shifted targets\n", "theta_aug_ml = max_lik_estimate_aug(X_aug, ynew)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now, we can make predictions again:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# define a test set (we also need to augment the test inputs with ones)\n", "Xtest_aug = np.hstack([np.ones((Xtest.shape[0],1)), Xtest]) # 100 x (D + 1) matrix of test inputs\n", "\n", "# predict the function values at the test points using the maximum likelihood estimator\n", "ml_prediction = predict_with_estimate(Xtest_aug, theta_aug_ml)\n", "\n", "# plot\n", "plt.figure()\n", "plt.plot(X, ynew, '+', markersize=10)\n", "plt.plot(Xtest, ml_prediction)\n", "plt.xlabel(\"$x$\")\n", "plt.ylabel(\"$y$\");" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "It seems this has solved our problem!\n", "\n", "#### Question:\n", "1. Play around with the first parameter of $\\boldsymbol\\theta_{\\text{aug}}$ and see how the fit of the function changes.\n", "2. Play around with the second parameter of $\\boldsymbol\\theta_{\\text{aug}}$ and see how the fit of the function changes." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Nonlinear Features\n", "So far, we have looked at linear regression with linear features, which allowed us to fit straight lines. However, linear regression also allows us to fit functions that are nonlinear in the inputs $\\boldsymbol x$, as long as the parameters $\\boldsymbol\\theta$ appear linearly. 
This means, we can learn functions of the form\n", "$$\n", "f(\\boldsymbol x, \\boldsymbol\\theta) = \\sum_{k = 1}^K \\theta_k \\phi_k(\\boldsymbol x)\\,,\n", "$$\n", "where the features $\\phi_k(\\boldsymbol x)$ are (possibly nonlinear) transformations of the inputs $\\boldsymbol x$.\n", "\n", "Let us have a look at an example where the observations clearly do not lie on a straight line:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAEGCAYAAABsLkJ6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAOAElEQVR4nO3de4yld13H8feHXRDKJWA6JrXtOphsiQ1e0Em9YEjDslpp02qisUQJqMmGP4ClwcgoiVWQpEbDUo0hbqCIoUK0rZGwiB0KFUmkMlurtCzdbWpLtyBd1IaLxlr79Y85yO50S2fn9psz3/cr2XTOM2fO832y2/Oe53LOSVUhSernKaMHkCSNYQAkqSkDIElNGQBJasoASFJTO0cPcCbOPvvsmp2dHT2GJE2Vw4cPf7mqZpYvn6oAzM7Osri4OHoMSZoqSe4/3XIPAUlSUwZAkpoyAJLUlAGQpKYMgCQ11SYABxaOjh5BkraUDQ9AkuuSPJTkzpOWfXuShSTHJv993kbPce0txzZ6FZI0VTZjD+BPgEuWLZsHbqmq3cAtk9uSpE204S8Eq6pPJJldtvgK4OLJ1+8FbgXetN7rPrBw9JTf/GfnDwGwf89urtp7wXqvTpKmSjbjA2EmAfhQVb1wcvvhqnruSd//j6o67WGgJPuAfQC7du36ofvvP+0L2p7U7Pwh7rvm0lX9rCRNsySHq2pu+fItfxK4qg5W1VxVzc3MPO6tLCRJqzQqAF9Kcg7A5L8PbfQK9+/ZvdGrkKSpMioAHwReNfn6VcBfbfQKPeYvSafajMtA3w/8PfCCJMeT/ApwDbA3yTFg7+S2JGkTbcZVQK94gm/t2eh1S5Ke2JY/CSxJ2hgGQJKaMgCS1JQBkKSmDIAkNWUAJKkpAyBJTRkASWrKAEhSUwZAkpoyAJLUlAGQpKYMgCQ1ZQAkqSkDIElNGQBJasoASFJTBkCSmjIAktSUAZCkpgyAJDVlACSpKQMgSU0ZAElqygBIUlMGQJKaGhqAJFcluSvJnUnen+TpI+eRpK3owMLRDXncYQFIci7wemCuql4I7ACuHDWPJG1V195ybEMed/QhoJ3AM5LsBM4CvjB4HklqY+eoFVfVg0l+H/g88F/AzVV18/L7JdkH7APYtWvX5g4pSYMcWDh6ym/+s/OHANi/ZzdX7b1gXdaRqlqXBzrjFSfPA24Efh54GPgL4Iaqet8T/czc3FwtLi5u0oSStDXMzh/ivmsuXfXPJzlcVXPLl488BPQy4F+q6kRV/Q9wE/BjA+eRpFZGBuDzwI8kOStJgD3AkYHzSNKWtH/P7g153GEBqKrbgBuA24HPTGY5OGoeSdqq1uuY/3LDTgIDVNXVwNUjZ5CkrkZfBipJGsQASFJTBkCSmjIAktSUAZCkpgyAJDVlACSpKQMgSU0ZAElqygBIUlMGQJKaMgCS1JQBkKSmDIAkNWUAJKkpAyBJTRkASWrKAEhSU
wZAkpoyAJLUlAGQpKYMgCQ1ZQAkqSkDIElNGQBJasoASFJTQwOQ5LlJbkjyuSRHkvzoyHkkqZOdg9d/LfCRqvrZJE8Dzho8jyS1MSwASZ4DvAR4NUBVPQI8MmoeSepm5CGg7wZOAO9J8o9J3pXkmcvvlGRfksUkiydOnNj8KSVpmxoZgJ3ADwLvrKoXAV8H5pffqaoOVtVcVc3NzMxs9oyStG2NDMBx4HhV3Ta5fQNLQZAkbYJhAaiqfwUeSPKCyaI9wGdHzSNJ3Yy+Cuh1wPWTK4DuBX5p8DyS1MbQAFTVHcDcyBkkqStfCSxJTRkASWrKAEhSUwZAkpoyAJLUlAGQpKYMgCQ1ZQAkqSkDIElNGQBJasoASFJTBkCSmjIAktSUAZCkpgyAJDVlACSpKQMgSU0ZAElqygBIUlMGQJKaMgCS1NSTBiDJR5N8/2YMI0naPCvZA/g14ECS9yQ5Z6MHkiRtjicNQFXdXlUvBT4EfCTJ1UmesfGjSZI20orOASQJcDfwTuB1wLEkr9zIwSRJG2sl5wA+CTwIHADOBV4NXAxclOTgRg4nSdo4O1dwn9cAd1VVLVv+uiRH1jpAkh3AIvBgVV221seTJK3MSs4B3HmaJ/9vuHQdZtgPrDkkkqQzs6bXAVTVvWv5+STnsRSRd63lcSRJZ270C8HewdJlpo890R2S7EuymGTxxIkTmzeZJG1zwwKQ5DLgoao6/K3uV1UHq2ququZmZmY2aTpJ2v5G7gG8GLg8yX3AB4CXJnnfwHkkqZVhAaiqX6+q86pqFrgS+FhV/eKoeSSpm9HnACRJg6zkdQAbrqpuBW4dPIYkteIegCQ1ZQAkqSkDIElNGQBJasoASFJTBkCSmjIAktSUAZCkpgyAJDVlACSpKQMgSU0ZAElqygBIUlMGQJKaMgCS1JQBkKSmDIAkNWUAJKkpAyBJTRkASWrKAEhSUwZAkpoyAJLUlAGQpKYMgCQ1ZQAkqalhAUhyfpKPJzmS5K4k+0fNIkkd7Ry47keBN1bV7UmeDRxOslBVnx04kyS1MWwPoKq+WFW3T77+KnAEOHfUPJLUzZY4B5BkFngRcNtpvrcvyWKSxRMnTmz2aJK0bQ0PQJJnATcCb6iqryz/flUdrKq5qpqbmZnZ/AElaZsaGoAkT2Xpyf/6qrpp5CyS1M3Iq4ACvBs4UlVvHzWHJHU1cg/gxcArgZcmuWPy5+UD55GkVoZdBlpVnwQyav2S1N3wk8Dq68DC0dEjSK0ZAA1z7S3HRo8gtWYAJKmpkW8FoYYOLBw95Tf/2flDAOzfs5ur9l4waiyppVTV6BlWbG5urhYXF0ePoXUyO3+I+665dPQY0raX5HBVzS1f7iEgSWrKAGiY/Xt2jx5Bas0AaBiP+UtjGQBJasoASFJTBkCSmjIAktSUAZCkpgyAJDVlACSpKQMgSU0ZAElqygBIUlMGQJKaMgCS1JQBkKSmDIAkNWUAJKkpAyBJTRkASWrKAEhSU0MDkOSSJHcnuSfJ/MhZJC05sHB09AjaJMMCkGQH8EfATwEXAq9IcuGoeSQtufaWY6NH0CYZuQdwEXBPVd1bVY8AHwCuGDiPJLWyc+C6zwUeOOn2ceCHl98pyT5gH8CuXbs2ZzKpmQMLR0/5zX92/hAA+/fs5qq9F4waSxtsZABymmX1uAVVB4GDAHNzc4/7vqS1u2rvBf//RD87f4j7rrl08ETaDCMPAR0Hzj/p9nnAFwbNIkntjAzAp4HdSZ6f5GnAlcAHB84jiaXDPuph2CGgqno0yWuBvwF2ANdV1V2j5pG0xGP+fYw8B0BVfRj48MgZJKkrXwksSU0ZAElqygBIUlMGQJKaMgCS1JQBkKSmDIAkNWUAJKkpAyBJTRkASWrKAEhSUwZAkpoyAJLUlAGQ1sGBhaOjR5DOmAGQ1sHJn6crTQsDIElNDf1AGGmaHVg4espv/rPzh4Clj1T0U7U0DVJVo2dYsbm5uVpcXBw9hvQ4s/OHu
O+aS0ePIZ1WksNVNbd8uYeAJKkpAyCtg/17do8eQTpjBkBaBx7z1zQyAJLUlAGQpKYMgCQ1ZQAkqSkDIElNTdULwZKcAO5f5Y+fDXx5HccZabtsy3bZDnBbtqrtsi1r3Y7vqqqZ5QunKgBrkWTxdK+Em0bbZVu2y3aA27JVbZdt2ajt8BCQJDVlACSpqU4BODh6gHW0XbZlu2wHuC1b1XbZlg3ZjjbnACRJp+q0ByBJOokBkKSmWgUgyVuT/HOSO5LcnOQ7R8+0Wkl+L8nnJtvzl0meO3qm1Ujyc0nuSvJYkqm8XC/JJUnuTnJPkvnR86xWkuuSPJTkztGzrEWS85N8PMmRyb+t/aNnWq0kT0/yD0n+abItv72uj9/pHECS51TVVyZfvx64sKpeM3isVUnyE8DHqurRJL8LUFVvGjzWGUvyPcBjwB8Dv1pVU/WRb0l2AEeBvcBx4NPAK6rqs0MHW4UkLwG+BvxpVb1w9DyrleQc4Jyquj3Js4HDwE9P6d9JgGdW1deSPBX4JLC/qj61Ho/fag/gG0/+E88EprZ+VXVzVT06ufkp4LyR86xWVR2pqrtHz7EGFwH3VNW9VfUI8AHgisEzrUpVfQL499FzrFVVfbGqbp98/VXgCHDu2KlWp5Z8bXLzqZM/6/a81SoAAEneluQB4BeA3xw9zzr5ZeCvRw/R1LnAAyfdPs6UPtlsR0lmgRcBt42dZPWS7EhyB/AQsFBV67Yt2y4AST6a5M7T/LkCoKreXFXnA9cDrx077bf2ZNsyuc+bgUdZ2p4taSXbMcVymmVTu2e5nSR5FnAj8IZle/9Tpar+t6p+gKW9/IuSrNvhuZ3r9UBbRVW9bIV3/TPgEHD1Bo6zJk+2LUleBVwG7KktfDLnDP5OptFx4PyTbp8HfGHQLJqYHC+/Ebi+qm4aPc96qKqHk9wKXAKsy4n6bbcH8K0kOfmTuy8HPjdqlrVKcgnwJuDyqvrP0fM09mlgd5LnJ3kacCXwwcEztTY5cfpu4EhVvX30PGuRZOYbV/gleQbwMtbxeavbVUA3Ai9g6aqT+4HXVNWDY6danST3AN8G/Ntk0aem8YqmJD8D/CEwAzwM3FFVPzl2qjOT5OXAO4AdwHVV9bbBI61KkvcDF7P01sNfAq6uqncPHWoVkvw48HfAZ1j6fx3gN6rqw+OmWp0k3we8l6V/W08B/ryq3rJuj98pAJKkb2p1CEiS9E0GQJKaMgCS1JQBkKSmDIAkNWUAJKkpAyBJTRkAaQ0m7zu/d/L17yT5g9EzSSu17d4LSNpkVwNvSfIdLL3r5OWD55FWzFcCS2uU5G+BZwEXT95/XpoKHgKS1iDJ9wLnAP/tk7+mjQGQVmny0YPXs/QJYF9PMlVvYicZAGkVkpwF3AS8saqOAG8FfmvoUNIZ8hyAJDXlHoAkNWUAJKkpAyBJTRkASWrKAEhSUwZAkpoyAJLU1P8Bb+BW42QZKkMAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y = np.array([10.05, 1.5, -1.234, 0.02, 8.03]).reshape(-1,1)\n", "plt.figure()\n", "plt.plot(X, y, '+')\n", "plt.xlabel(\"$x$\")\n", "plt.ylabel(\"$y$\");" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Polynomial Regression\n", "One class of functions that is covered by linear regression is the family of polynomials, because we can write a polynomial of degree $K$ as\n", "$$\n", "\\sum_{k=0}^K \\theta_k x^k = \\boldsymbol \\phi(x)^T\\boldsymbol\\theta\\,,\\quad\n", "\\boldsymbol\\phi(x)= \n", "\\begin{bmatrix}\n", "x^0\\\\\n", "x^1\\\\\n", "\\vdots\\\\\n", "x^K\n", "\\end{bmatrix}\\in\\mathbb{R}^{K+1}\\,.\n", "$$\n", "Here, $\\boldsymbol\\phi(x)$ is a nonlinear feature transformation of the input $x\\in\\mathbb{R}$.\n", "\n", "As in the earlier case, we can define a matrix that collects the feature transformations of all training inputs:\n", "$$\n", "\\boldsymbol\\Phi = \\begin{bmatrix}\n", "\\boldsymbol\\phi(x_1) & \\boldsymbol\\phi(x_2) & \\cdots & \\boldsymbol\\phi(x_N)\n", "\\end{bmatrix}^T \\in\\mathbb{R}^{N\\times (K+1)}\n", "$$" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let us start by computing the feature matrix $\\boldsymbol \\Phi$." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "## EDIT THIS FUNCTION\n", "def poly_features(X, K):\n", "\n", "    # X: inputs of size N x 1\n", "    # K: degree of the polynomial\n", "    # computes the feature matrix Phi (N x (K+1))\n", "\n", "    X = X.flatten()\n", "    N = X.shape[0]\n", "\n", "    # initialize Phi\n", "    Phi = np.zeros((N, K+1))\n", "\n", "    # fill the feature matrix column by column: column k holds x^k\n", "    for k in range(K+1):\n", "        Phi[:,k] = X**k ## <-- SOLUTION\n", "    return Phi" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "With this feature matrix we get the maximum likelihood estimator as\n", "$$\n", "\\boldsymbol \\theta^\\text{ML} = (\\boldsymbol\\Phi^T\\boldsymbol\\Phi)^{-1}\\boldsymbol\\Phi^T\\boldsymbol y\n", "$$\n", "For reasons of numerical stability, we often add a small diagonal \"jitter\" $\\kappa>0$ to $\\boldsymbol\\Phi^T\\boldsymbol\\Phi$. This makes the matrix positive definite, so that it can be factorized and inverted reliably even when $\\boldsymbol\\Phi^T\\boldsymbol\\Phi$ is (nearly) singular, and the (approximate) maximum likelihood estimate becomes\n", "$$\n", "\\boldsymbol \\theta^\\text{ML} = (\\boldsymbol\\Phi^T\\boldsymbol\\Phi + \\kappa\\boldsymbol I)^{-1}\\boldsymbol\\Phi^T\\boldsymbol y\n", "$$" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "## EDIT THIS FUNCTION\n", "def nonlinear_features_maximum_likelihood(Phi, y):\n", "    # Phi: features matrix for training inputs. Size of N x D\n", "    # y: training targets. Size of N x 1\n", "    # returns: maximum likelihood estimator theta_ml. Size of D x 1\n", "\n", "    kappa = 1e-08 # 'jitter' term; good for numerical stability\n", "\n", "    D = Phi.shape[1]\n", "\n", "    # assemble the (jittered) normal equations\n", "    Pt = Phi.T @ y # Phi^T*y\n", "    PP = Phi.T @ Phi + kappa*np.eye(D) # Phi^T*Phi + kappa*I\n", "\n", "    # maximum likelihood estimate: solve PP theta = Pt via a Cholesky factorization\n", "    C = scipy.linalg.cho_factor(PP)\n", "    theta_ml = scipy.linalg.cho_solve(C, Pt) # inv(Phi^T*Phi + kappa*I)*Phi^T*y\n", "\n", "    return theta_ml" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We now have all the ingredients together: the computation of the feature matrix and the computation of the maximum likelihood estimator for polynomial regression. Let's see how this works.\n", "\n", "To make predictions at $N_{\\text{test}}$ test inputs $\\boldsymbol X_{\\text{test}}\\in\\mathbb{R}^{N_{\\text{test}}}$, we compute their features (nonlinear transformations) $\\boldsymbol\\Phi_{\\text{test}}= \\boldsymbol\\phi(\\boldsymbol X_{\\text{test}})\\in\\mathbb{R}^{N_{\\text{test}}\\times (K+1)}$, which give us the predicted mean\n", "$$\n", "\\mathbb{E}[\\boldsymbol y_{\\text{test}}] = \\boldsymbol \\Phi_{\\text{test}}\\boldsymbol\\theta^{\\text{ML}}\n", "$$" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "K = 5 # Define the degree of the polynomial we wish to fit\n", "Phi = poly_features(X, K) # N x (K+1) feature matrix\n", "\n", "theta_ml = nonlinear_features_maximum_likelihood(Phi, y) # maximum likelihood estimator\n", "\n", "# test inputs\n", "Xtest = np.linspace(-4,4,100).reshape(-1,1)\n", "\n", "# feature matrix for test inputs\n", "Phi_test = poly_features(Xtest, K)\n", "\n", "y_pred = Phi_test @ theta_ml # predicted y-values\n", "\n", "plt.figure()\n", "plt.plot(X, y, '+')\n", "plt.plot(Xtest, y_pred)\n", "plt.xlabel(\"$x$\")\n", "plt.ylabel(\"$y$\");" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Experiment with different polynomial degrees in the code above.\n", "#### Questions:\n", "1. What do you observe?\n", "2. What is a good fit?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluating the Quality of the Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us have a look at a more interesting data set:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYwAAAEGCAYAAAB2EqL0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAQ7klEQVR4nO3df4wc513H8c8HJyGFAgXsENd2uCDOqFZ/pNXJSpV/ojgHTlPZtBDJEZQEWlmVmmIsULlgqYEUJJdKHKmIWpk2NEBJWpVWMbmAe3FSIgQpWYcktev6bKwUXx3ItYHyI5DIzZc/ds5d782dn/Pt7jMz+35JJ9/sTPY+SuL77MwzzzOOCAEAcD7fkzsAAKAeKAwAQBIKAwCQhMIAACShMAAASS7KHaBfVq9eHSMjI7ljAECtHDp06JsRsaZsX2MLY2RkRK1WK3cMAKgV219fbB+XpAAASSgMAEASCgMAkITCAAAkoTAAAEkoDKAPJqdnckcAeo7CAPrgroPHc0cAeo7CAAAkaezEPWDQJqdnzjmzGJmYkiTt2jKq3eMbc8UCesZNfYDS2NhYMNMbuYxMTOnZvTfmjgEsm+1DETFWto9LUgCAJBQG0Ae7tozmjgD0HIUB9AFjFmgiCgMAkITCAAAkoTAAAEkoDABAEgoDAJCEwgAAJKEwAABJKAwAQBIKAwCQpBKFYfse28/bPrzIftv+qO0Ttp+x/ZZBZwSAYVeJwpD0KUlbl9h/g6TR4munpI8NIBMAoEMlCiMiHpP0whKHbJf0p9H2uKTX2F47mHQAAKkihZFgnaRTHduzxWvnsL3Tdst2a25ubmDhAGAY1KUwXPLagic/RcS+iBiLiLE1a9YMIBYADI+6FMaspA0d2+slnc6UBQCGUl0KY7+kXyrulrpa0rcj4rncoQBgmFyUO4Ak2b5P0rWSVtuelXSHpIslKSI+LukhSW+TdELSi5J+OU9SABhelSiMiLj5PPtD0vsGFAcAUKIul6QAAJlRGACAJBQGACAJhQEASEJhAACSUBgAgCQUBgAgCYUBAEhCYQAVNDk9kzsCsACFAVTQXQeP544ALEBhAACSVGItKQDty1CdZxYjE1OSpF1bRrV7fGOuWMBZbq/r1zxjY2PRarVyx0BNTU7PZP0lPTIxpWf33njB/3zu/Kgv24ciYqxsH5ekgBJ1H0Ooe35UE4UBVNCuLaO5IwALcEkKKHSPIcyryxhC3fOjGpa6JEVhACVWOoaQW93zIx/GMAAAK0ZhACXqPoZQ9/yoJi5JAQDO4pIUAGDFKAwAQBIKAwCQhMIAACShMAAASSgMAEASCgMAkITCAAAkoTAAAEkoDABAEgoDAJCEwgAAJKEwAABJKAwAQJJKFIbtrbaP2T5he6Jk/62252w/VXy9J0dOABhm2QvD9ipJd0u6QdImSTfb3lRy6Gci4qri6xMDDYnamZyeyR0BaJzshSFps6QTEXEyIl6WdL+k7ZkzoebuOng8dwSgcapQGOsknerYni1e6/Zztp+x/TnbG8reyPZO2y3brbm5uX5kBYChdVHuAJJc8lr3c2P/StJ9EfGS7fdKulfSdQv+oYh9kvZJ7Ue09jooqm1yeuacM4uRiSlJ7edb7x7fmCsW0BhVKIxZSZ1nDOslne48ICK+1bH5x5I+PIBcqJnd4xvPFsPIxJSe3Xtj5kRAs1ThktQTkkZtX2n7Ekk7JO3vPMD22o7NbZKODjAfAEAVOMOIiDO2b5N0QNIqSfdExBHbd0pqRcR+Sb9qe5ukM5JekHRrtsCohV1bRnNHABrHEc281D82NhatVit3DACoFduHImKsbF8VLkkBAGqAwgAAJKEwAABJKAwAQBIKAwCQhMJAJbF4IFA9FAYqicUDgeqhMAAASbLP9AbmsXggUG3M9EYlsXggkAczvQEMFDctNBOFgUpi8cB646aFZqIwUEmMWQDVw6A3gJ7gpoXmY9AbQM9x00J
9MegNAFgxCgNAz3HTQjNRGAB6jjGLZqIwACzAPAqUoTAALMA8CpShMAAASZiHAUAS8yhwfszDALAA8yiGF/MwAAArRmEAWIB5FChDYQBYgDELlKEwAABJKAwAQBIKAwCQhMIAACShMAAASSgMAECS8xaG7Ydtv2kQYQAA1ZVyhvEBSZO2/8T22n6EsL3V9jHbJ2xPlOz/XtufKfZ/2fZIP3IAABZ33sKIiCcj4jpJD0r6G9t32H5VrwLYXiXpbkk3SNok6Wbbm7oOe7ekf4+In5Q0KenDvfr5AIA0SWMYti3pmKSPSXq/pOO239WjDJslnYiIkxHxsqT7JW3vOma7pHuL7z8naUuRCQAwICljGH8n6Rtqf7JfJ+lWSddK2mx7Xw8yrJN0qmN7tnit9JiIOCPp25J+tAc/GwCQKOV5GO+VdCQWroP+fttHe5Ch7Eyh+2elHCPbOyXtlKQrrrhi5ckAAGeljGEcLimLeb1YMH9W0oaO7fWSTi92jO2LJP2QpBe63ygi9kXEWESMrVmzpgfRAADzVjQPIyJO9iDDE5JGbV9p+xJJOyTt7zpmv6Rbiu9/XtIjS5QYAKAPsj+iNSLO2L5N0gFJqyTdExFHbN8pqRUR+yV9UtKf2T6h9pnFjnyJAWA4ZS8MSYqIhyQ91PXaBzu+/z9JNw06FwDgu1gaBACQhMIAACShMAAASSgMAEASCgMAkITCAAAkoTAAAEkoDABAEgoDAJCEwgAAJKEwAABJKAwAQBIKAwCQhMJAX0xOz+SOAKDHKAz0xV0Hj+eOAKDHKAwAQJJKPEAJzTA5PXPOmcXIxJQkadeWUe0e35grFoAecVMfjT02NhatVit3jKE1MjGlZ/femDsGgGWyfSgixsr2cUkKAJCEwkBf7NoymjsCgB6jMNAXjFkAzUNhAACSUBgAgCQUBgAgCYUBAA3Tr6V5KAwAaJh+Lc1DYQAAkrA0CAA0wCCW5mFpkEVMTs8wlwBALa1kaR6WBrkALM8NAOeiMACgYfq1NA+XpDp0XwOcx/LcAIbFUpekKIxFsDw3gGHEGAYAYMWyFobtH7E9bft48ecPL3Lcd2w/VXztH0Q2lucGgHPlPsOYkHQwIkYlHSy2y/xvRFxVfG0bRDDGLADgXLkLY7uke4vv75X0sxmzAKiIfq2FhJXJXRg/FhHPSVLx52WLHHep7Zbtx20vWiq2dxbHtebm5vqRF8AAMA+qmvq+NIjthyVdXrJrzzLe5oqIOG37JyQ9YvsrEfHP3QdFxD5J+6T2XVIXFBgAUKrvhRER1y+2z/a/2V4bEc/ZXivp+UXe43Tx50nbX5L0ZkkLCgNAfQ1iLSSsTO7FB/dLukXS3uLPB7oPKO6cejEiXrK9WtI1kn5/oCkB9N3u8Y1ni4F5UNWUewxjr6Rx28cljRfbsj1m+xPFMa+T1LL9tKRHJe2NiK9mSQsAQyzrGUZEfEvSlpLXW5LeU3z/95LeMOBoADJiHlQ15T7DAIAFGLOoJgoDAJCEwgAAJKEw+oSZqgCahsLoE2aqAmgaCgMAkCT3xL1GYaYqgCbjiXt9wkxVAHXEE/cAACtGYfQJM1UBNA2F0SeMWQBoGgoDAJCEwgAAJKEwAABJKAwAjcPSPP1BYQBoHJbm6Q8KA6X4hAagG4WBUnxCQ91MTs9oZGLq7JI8899fyIcfPjCVYy0pAI2we3zj2flPK12a566Dx5lLVYLCwFksnghgKSw+iFIsnog6m5yeWfaHnO4PTPOG7QPTUosPcoYBoHEu5Bd8Ly9pNRWD3ijF4okAulEYDbXSuzyG6RQc6MYHpnIURkNxWyxw4fjAVI7CAAAkYdC7QbgtFkA/cVttQ3GXB4ALwTO9AaBGqro0CYXRUNzlAdRXVW9aoTAaijELAL3GoDcAVEAdblph0BsAKibnTSsMegMAVixrYdi+yfYR26/YLm204ritto/ZPmF7YpAZAWDQqnrTSu4zjMOS3inpscUOsL1K0t2SbpC
0SdLNtjcNJh4ADF5Vxiy6ZR30joijkmR7qcM2SzoRESeLY++XtF3SV/seEABwVu4zjBTrJJ3q2J4tXlvA9k7bLdutubm5gYQDgGHR9zMM2w9Lurxk156IeCDlLUpeK721KyL2Sdonte+SSg4JADivvhdGRFy/wreYlbShY3u9pNMrfE8A6JsLeURsHdThktQTkkZtX2n7Ekk7JO3PnAkAFlXVpT1WKvdtte+wPSvprZKmbB8oXn+t7YckKSLOSLpN0gFJRyV9NiKO5MoMAMOKmd4A0APdS3vMq9LSHimWmulNYQBAj9X5eTQsDVJDVV0PH8DwojAqqqmDZsAwqOrSHitFYQBAj9VpzGI5eB5GhdRhPXwAw4tB74qq86AZgPpi0BsAsGIURkU1ddAMQH1RGBXFmAWAqqEwAABJKAwAQBIKAwCQhMIAACShMAAASRo7cc/2nKSvr+AtVkv6Zo/i9BK5lodcy0Ou5Wlirh+PiDVlOxpbGCtlu7XYbMecyLU85Foeci3PsOXikhQAIAmFAQBIQmEsbl/uAIsg1/KQa3nItTxDlYsxDABAEs4wAABJKAwAQBIK4zxs/4btsL06d5Z5tj9k+xnbT9n+ou3XViDTR2x/rcj1BduvyZ1pnu2bbB+x/YrtrLdA2t5q+5jtE7YncmbpZPse28/bPpw7SyfbG2w/avto8d9wV+5MkmT7Utv/aPvpItfv5M40z/Yq2/9k+8FevzeFsQTbGySNS/qX3Fm6fCQi3hgRV0l6UNIHcweSNC3p9RHxRkkzkm7PnKfTYUnvlPRYzhC2V0m6W9INkjZJutn2ppyZOnxK0tbcIUqckfTrEfE6SVdLel9F/p29JOm6iHiTpKskbbV9deZM83ZJOtqPN6YwljYp6QOSKnVnQET8Z8fm96sC+SLiixFxpth8XNL6nHk6RcTRiDiWO4ekzZJORMTJiHhZ0v2StmfOJEmKiMckvZA7R7eIeC4iniy+/y+1fxGuy5tKirb/LjYvLr6y/z20vV7SjZI+0Y/3pzAWYXubpG9ExNO5s5Sx/Xu2T0n6BVXjDKPTr0j669whKmidpFMd27OqwC+/urA9IunNkr6cN0lbcennKUnPS5qOiCrk+kO1P+S+0o83v6gfb1oXth+WdHnJrj2SfkvSTw820XctlS0iHoiIPZL22L5d0m2S7sidqThmj9qXET7d7zzLzVYBLnkt+6fSOrD9akl/KenXus6ws4mI70i6qhiv+4Lt10dEtjEg22+X9HxEHLJ9bT9+xlAXRkRcX/a67TdIulLS07al9uWVJ21vjoh/zZmtxF9ImtIACuN8mWzfIuntkrbEgCf4LOPfV06zkjZ0bK+XdDpTltqwfbHaZfHpiPh87jzdIuI/bH9J7TGgnDcNXCNpm+23SbpU0g/a/vOI+MVe/QAuSZWIiK9ExGURMRIRI2r/RX/LoMrifGyPdmxuk/S1XFnm2d4q6TclbYuIF3PnqagnJI3avtL2JZJ2SNqfOVOluf2J7ZOSjkbEH+TOM8/2mvk7AW2/StL1yvz3MCJuj4j1xe+sHZIe6WVZSBRGXe21fdj2M2pfNqvCrYZ/JOkHJE0Xt/t+PHegebbfYXtW0lslTdk+kCNHcVPAbZIOqD14+9mIOJIjSzfb90n6B0k/ZXvW9rtzZypcI+ldkq4r/r96qvgEndtaSY8WfwefUHsMo+e3sVYNS4MAAJJwhgEASEJhAACSUBgAgCQUBgAgCYUBAEhCYQAAklAYAIAkFAYwQMWzHcaL73/X9kdzZwJSDfVaUkAGd0i60/Zlaq+8ui1zHiAZM72BAbP9t5JeLena4hkPQC1wSQoYoGIl5LWSXqIsUDcUBjAgtteq/ZyQ7ZL+x/bPZI4ELAuFAQyA7e+T9Hm1n099VNKHJP121lDAMjGGAQBIwhkGACAJhQEASEJhAACSUBgAgCQUBgAgCYUBAEhCYQAAkvw/OflzCMK/364AAAAASUVORK5CYII=\n"
, "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def f(x):\n", "    # cos(x) corrupted by Gaussian noise with standard deviation 0.2\n", "    return np.cos(x) + 0.2*np.random.normal(size=x.shape)\n", "\n", "X = np.linspace(-4,4,20).reshape(-1,1)\n", "y = f(X)\n", "\n", "plt.figure()\n", "plt.plot(X, y, '+')\n", "plt.xlabel(\"$x$\")\n", "plt.ylabel(\"$y$\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let us use the work from above and fit polynomials to this dataset." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "## EDIT THIS CELL\n", "K = 2 # Define the degree of the polynomial we wish to fit\n", "\n", "Phi = poly_features(X, K) # N x (K+1) feature matrix\n", "\n", "theta_ml = nonlinear_features_maximum_likelihood(Phi, y) # maximum likelihood estimator\n", "\n", "# test inputs\n", "Xtest = np.linspace(-5,5,100).reshape(-1,1)\n", "ytest = f(Xtest) # noisy ground-truth observations at the test inputs\n", "\n", "# feature matrix for test inputs\n", "Phi_test = poly_features(Xtest, K)\n", "\n", "y_pred = Phi_test @ theta_ml # predicted y-values\n", "\n", "# plot\n", "plt.figure()\n", "plt.plot(X, y, '+')\n", "plt.plot(Xtest, y_pred)\n", "plt.plot(Xtest, ytest)\n", "plt.legend([\"data\", \"prediction\", \"ground truth observations\"])\n", "plt.xlabel(\"$x$\")\n", "plt.ylabel(\"$y$\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Questions:\n", "1. Try out different polynomial degrees.\n", "2. Based on visual inspection, what looks like the best fit?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us now look at a more systematic way to assess the quality of the polynomial that we are trying to fit. For this, we compute the root-mean-square error (RMSE) between the $y$-values predicted by our polynomial and the ground-truth $y$-values. The RMSE is defined as\n", "$$\n", "\\text{RMSE} = \\sqrt{\\frac{1}{N}\\sum_{n=1}^N(y_n - y_n^\\text{pred})^2}\n", "$$\n", "Write a function that computes the RMSE." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "## EDIT THIS FUNCTION\n", "def RMSE(y, ypred):\n", "    rmse = np.sqrt(np.mean((y-ypred)**2)) ## SOLUTION\n", "    return rmse" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now compute the RMSE for different degrees of the polynomial we want to fit." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "## EDIT THIS CELL\n", "K_max = 20\n", "rmse_train = np.zeros((K_max+1,))\n", "\n", "for k in range(K_max+1):\n", "\n", "    # feature matrix\n", "    Phi = poly_features(X, k)\n", "\n", "    # maximum likelihood estimate\n", "    theta_ml = nonlinear_features_maximum_likelihood(Phi, y)\n", "\n", "    # predict y-values of training set\n", "    ypred_train = Phi @ theta_ml\n", "\n", "    # RMSE on training set\n", "    rmse_train[k] = RMSE(y, ypred_train)\n", "\n", "plt.figure()\n", "plt.plot(rmse_train)\n", "plt.xlabel(\"degree of polynomial\")\n", "plt.ylabel(\"RMSE\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Questions:\n", "1. What do you observe?\n", "2. What is the best polynomial fit according to this plot?\n", "3. Write some code that plots the fit obtained with the best polynomial degree (use the test inputs for this plot). What do you observe now?" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# WRITE THE PLOTTING CODE HERE\n", "k = 5 # polynomial degree to plot\n", "\n", "plt.figure()\n", "plt.plot(X, y, '+')\n", "\n", "# feature matrix\n", "Phi = poly_features(X, k)\n", "\n", "# maximum likelihood estimate\n", "theta_ml = nonlinear_features_maximum_likelihood(Phi, y)\n", "\n", "# feature matrix for test inputs\n", "Phi_test = poly_features(Xtest, k)\n", "\n", "ypred_test = Phi_test @ theta_ml\n", "\n", "plt.plot(Xtest, ypred_test)\n", "plt.xlabel(\"$x$\")\n", "plt.ylabel(\"$y$\")\n", "plt.legend([\"data\", \"maximum likelihood fit\"]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The RMSE on the training data is somewhat misleading: it tends to keep decreasing as the polynomial degree grows, whereas we are interested in the generalization performance of the model. Therefore, we are going to compute the RMSE on the test set and use this to choose a good polynomial degree." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYoAAAEGCAYAAAB7DNKzAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXxU9bn48c+TyQYJCavKFoiKyCI7KIoWFxQUpFVUrPZiXdCqdblXf9XbVlu9vfXWVq2iIirFlcUdFJW6IC4om4IssqMEkFUIAbJM5vn9cU7CECYbmZkzOXner9e85qzf8+QkmWe+33PO9yuqijHGGFOZJK8DMMYYk9gsURhjjKmSJQpjjDFVskRhjDGmSpYojDHGVCnZ6wBioWXLltqxY0evwzDGmHpj4cKFO1S1VaR1vkwUHTt2ZMGCBV6HYYwx9YaIfF/ZOmt6MsYYUyVLFMYYY6rkq0QhIiNEZMKePXu8DsUYY3zDV9coVHUGMKNfv37XVVxXUlJCXl4ehYWFHkTW8KSnp9OuXTtSUlK8DsUYU0e+ShRVycvLo0mTJnTs2BER8TocX1NVdu7cSV5eHrm5uV6HY4ypI181PVWlsLCQFi1aWJKIAxGhRYsWVnszxid8lSiqu0ZhSSJ+7Fwb4x++anqq6hqFMcbUO6EQlOx3XsX73Pf9ULLPfd8PxQUHl5UcgLP+CFH+ouarRJHIdu/ezcsvv8yNN95Y633PP/98Xn75ZZo2bVrpNvfccw9nnHEG55xzTl3CrLU333yTE044ga5du8b1uMb40py/w8LnDiaC4IHa7S8BOONOSGkU1bAsUcTJ7t27eeKJJyImitLSUgKBQKX7zpw5s9ry77vvvjrFd6TefPNNhg8fbonCmLrauxU++Rsc3RXaDIHUxpCS4b43htSMCu8R1gdSo16bAJw7VPz26tu3r1a0fPnyw5bF02WXXabp6enas2dPveOOO/Tjjz/WwYMH6+WXX65dunRRVdWRI0dqnz59tGvXrvrUU0+V79uhQwfdvn27rl+/Xk888US99tprtWvXrjpkyBDdv3+/qqqOGTNGX3nllfLt77nnHu3du7d2795dV6xYoaqq27Zt03POOUd79+6tY8eO1ZycHN2+ffshcQaDQR0zZox269ZNu3fvrg899JCqqq5Zs0bPO+887dOnjw4aNEhXrFihn3/+uTZr1kw7duyoPXv21DVr1hxSltfn3Jh6ZdYfVf/UVHXHmuq3jQFggVbymZrwNQoRSQLuB7JwfpDn6lrmn2csY/nm/DrHFq5rmyzuHdGt0vUPPPAAS5cu5ZtvvgFg9uzZzJs3j6VLl5bfQjpx4kSaN2/OgQMH6N+/PxdffDEtWrQ4pJzVq1czefJknn76aS699FJee+01rrzyysOO17JlSxYtWsQTTzzB3//+d5555hn+/Oc/c9ZZZ3H33Xfz3nvvMWHChMP2++abb9i0aRNLly4FnJoQwNixYxk/fjydOnXiq6++4sYbb+Sjjz7iwgsvZPjw4YwaNerITpwxBg7shvkTodsvoMVxXkdzGE/uehKRiSKyTUSWVlg+VERWisgaEbnLXTwSaAuUAHnxjjWWBgwYcMhzBo8++ig9e/bklFNOYePGjaxevfqwfXJzc+nVqxcAffv2ZcOGDRHLvuiiiw7b5rPPPmP06NEADB06lGbNmh2237HHHsu6dev47W9/y3vvvUdWVhYFBQV88cUXXHLJJfTq1Yvrr7+eLVu21OVHN8aEm/80FO+FQbd7HUlEXtUoJgHjgOfLFohIAHgcGIKTEOaLyHSgMzBXVZ8SkVeBD+t68Kq++cdTRkZG+fTs2bP54IMPmDt3Lo0bN2bw4MERn0NIS0srnw4EAhw4EPliV9l2gUCAYDAIOM2M1WnWrBmLFy/m/fff5/HHH2fatGk88sgjNG3atLw2ZIyJouL98OWT0OlcOOYkr6OJyJMaharOAXZVWDwAWKO
q61S1GJiCU5vIA35ytymtrEwRGSsiC0Rkwfbt22MRdp00adKEvXv3Vrp+z549NGvWjMaNG/Pdd9/x5ZdfRj2GQYMGMW3aNABmzZrFTz/9dNg2O3bsIBQKcfHFF3P//fezaNEisrKyyM3N5ZVXXgGchLN48eIa/VzGmGp8/QLs3wmD/tPrSCqVSA/ctQU2hs3nucteB84TkceAOZXtrKoTgD8Di1JTU2MZ5xFp0aIFp512Gt27d+fOO+88bP3QoUMJBoP06NGDP/7xj5xyyilRj+Hee+9l1qxZ9OnTh3fffZfWrVvTpEmTQ7bZtGkTgwcPplevXlx11VX89a9/BeCll17i2WefpWfPnnTr1o233noLgNGjR/Pggw/Su3dv1q5dG/WYjfG1YDF8/ijkDIQOA72OplJSk+aImBxYpCPwtqp2d+cvAc5T1Wvd+V8BA1T1t7Utu1+/flpx4KIVK1bQpUuXuoZdrxUVFREIBEhOTmbu3Ln85je/iWlzkp1zY6rx9Uvw1o1wxavQaYinoYjIQlXtF2ldIt31lAe0D5tvB2yuTQEiMgIYcfzxx0czLt/44YcfuPTSSwmFQqSmpvL00097HZIxDVcoBJ8/4lyXOD6+D8rWViIlivlAJxHJBTYBo4FfehuSv3Tq1Imvv/7a6zCMMQDfvQ07VsGoibF5SC6KvLo9djIwF+gsInkico2qBoGbgfeBFcA0VV1Wm3JVdYaqjs3Ozo5+0MYYEy2q8NlD0PxY6Ppzr6Oplic1ClW9vJLlM4Hq+6uohDU9GWPqhXUfw+avYcQ/Iany7nsSRSLd9VRnVqMwxtQLnz4ETVpDz4jfmROOrxKFjZltjEl4eQtgw6cw8GZITqt++wTgq0SRyDWKst5jj9QjjzzC/v376xzH7Nmz+eKLL+pcjjHmCH36EDRqBn2v8jqSGvNVokhkliiMMWxbASvfgQHXQ1qm19HUmK8SRSI3Pd11112sXbuWXr16lT+Z/eCDD9K/f3969OjBvffeC8C+ffu44IIL6NmzJ927d2fq1Kk8+uijbN68mTPPPJMzzzwzYtldu3alR48e3HHHHQBs376diy++mP79+9O/f38+//xzNmzYwPjx43n44Yfp1asXn376afxOgDEGPnvEGUPi5Ou9jqRWEuk5ijrTmg6F+u5d8OO30T34MSfBsAcqXV2xm/FZs2axevVq5s2bh6py4YUXMmfOHLZv306bNm145513AKcPqOzsbB566CE+/vhjWrZseUi5u3bt4o033uC7775DRMq7Bb/11lu5/fbbGTRoED/88APnnXceK1as4IYbbiAzM7M8oRhj4uSnDfDtK3DyDdC4udfR1IqvEkV9MmvWLGbNmkXv3r0BKCgoYPXq1Zx++unccccd/O53v2P48OGcfvrpVZaTlZVFeno61157LRdccAHDhw8H4IMPPmD58uXl2+Xn51vnfcZ46YvHQJJg4E1eR1JrvkoUNX6Ooopv/vGiqtx9991cf/3hVdCFCxcyc+ZM7r77bs4991zuueeeSstJTk5m3rx5fPjhh0yZMoVx48bx0UcfEQqFmDt3Lo0aRXfsXGPMESjYBl+/CD1HQ3Zbr6OpNV9do0jku54qdsd93nnnMXHiRAoKCgCn19Zt27axefNmGjduzJVXXskdd9zBokWLIu5fpqCggD179nD++efzyCOPlDdtnXvuuYwbN658u7Ll1i24MR748gkoLU7YgYmq46saRSIL72Z82LBhPPjgg6xYsYKBA52uhTMzM3nxxRdZs2YNd955J0lJSaSkpPDkk08CzlCkw4YNo3Xr1nz88cfl5e7du5eRI0dSWFiIqvLwww8Dzmh5N910Ez169CAYDHLGGWcwfvx4RowYwahRo3jrrbd47LHHqm3aMsbUUeEemP8sdB2ZkMOc1oRn3YzHknUznhjsnBsDfPoP+PA+uH4OtO7pdTSVqqqbcV81PRljTEIp3g9zn3C6EU/gJFEdXyWKRH6Owhj
TAH39IuzfkdDDnNaErxJFdRez/djMlqjsXJsGr7QEvngU2p8MHU71Opo68VWiqEp6ejo7d+60D7A4UFV27txJenq616EY451vX4U9G+H0/0r4gYmq02DuemrXrh15eXls377d61AahPT0dNq1a+d1GMZ4IxSCzx6Go7tDp3O9jqbOEj5RiMhg4H5gGTBFVWcfSTkpKSnk5uZGMTJjjKnEypmwYyVc/Gy9r02Ad0OhThSRbSKytMLyoSKyUkTWiMhd7mIFCoB0IC/esRpjTK2oOrfENutYL4Y5rQmvrlFMAoaGLxCRAPA4MAzoClwuIl2BT1V1GPA74M9xjtMYY2pn/SeweRGcdisEEr7RpkY8SRSqOgfYVWHxAGCNqq5T1WJgCjBSVUPu+p+A+jEclDGm4frsYcg8Gnr+0utIoiaR0l1bYGPYfB5wsohcBJwHNAXGRdoRQETGAmMBcnJyYhimMcZUYtd6WDcbzvoDpPjnrr9EShSRrvioqr4OvF7dzqo6QUS2ACNSU1P7Rj06Y4ypzpJpgECP0V5HElWJ9BxFHtA+bL4dsNmjWIwxpnZUYfFk6DgImravfvt6JJESxXygk4jkikgqMBqYXpsCErmbcWOMz+XNh5/WO2NO+IxXt8dOBuYCnUUkT0SuUdUgcDPwPrACmKaqy2pZrvX1ZIzxxuLJkNwIulzodSRR12C6GTfGmJgJFsHfT3B6iR31rNfRHJEG08241SiMMZ5YPQsKd0PPy72OJCZ8lSjsGoUxxhOLp0DGUXDsYK8jiQlfJQqrURhj4m7/Llj1PvS41DdPYlfkq0RhNQpjTNwtfQ1CJdDjMq8jiRlfJQqrURhj4m7JVDiqKxxzkteRxIyvEoXVKIwxcbVzrfP8RM/RvuhOvDK+ShTGGBNXi6cAAidd4nUkMeWrRGFNT8aYuAmFYMkU506nrDZeRxNTvkoU1vRkjImbjV/C7h982WVHRb5KFMYYEzeLp0BKBpw43OtIYs4ShTHG1FZJISx7E7qMgLRMr6OJOUsUxhhTW6vehaI90NO/z06E81WisIvZxpi4WDwVmrSG3J95HUlc+CpR2MVsY0zM7dsBa/7t3BKbFPA6mrjwVaIwxpiYW/oahIK+7Sk2EksUxhhTG4snO911HN3V60jixhKFMcbU1PZVsPnrBlWbgHqSKEQkQ0QWioj/b1g2xiSuJVNAkqD7KK8jiSuvxsyeKCLbRGRpheVDRWSliKwRkbvCVv0OmBbfKI0xJkwoBEumwXFnQZOjvY4mrryqUUwChoYvEJEA8DgwDOgKXC4iXUXkHGA5sDXeQRpjTLnvP4c9GxtcsxOAJ8MxqeocEelYYfEAYI2qrgMQkSnASCATyMBJHgdEZKaqhiqWKSJjgbEAOTk5sQveGNMwLZ4CqU2g8/leRxJ3iTRuX1tgY9h8HnCyqt4MICJXATsiJQkAVZ0gIluAEampqX1jHawxpgEp3g/L34KuIyG1sdfRxF0iXcyONOqHlk+oTlLVt6sqwB64M8bExMqZULy3wXTZUVEiJYo8oH3YfDtgc20KsC48jDExsXgKZLWDDoO8jsQTiZQo5gOdRCRXRFKB0cB0j2MyxjR0e7fC2o+gx6WQlEgfmfHj1e2xk4G5QGcRyRORa1Q1CNwMvA+sAKap6rLalGtNT8aYqFv6KmhpgxigqDJe3fUU8f4yVZ0JzDzSckVkBDDi+OOPP9IijDHmUIunQJve0Kqz15F4xlf1KKtRGGOiauty+HEJ9Gi4tQnwWaKwi9nGmKhaMgWSkqH7xV5H4ilfJQqrURhjoiZUCktegePPgcxWXkfjKV8lCmOMiZr1c2Dv5gZ9EbuMrxKFNT0ZY6JmyVRIy4YThnkdied8lSis6ckYExXF+2D5dOg2ElLSvY7Gc75KFMYYExUr3oaSfQ2yp9hIfJUorOnJGBMVS6ZA0xxof4rXkSQEXyUKa3oyxtTJT9/Du7+DdbOdZycaaJcdFSVSN+PGGOONzd/
AF4/CsjedoU57jIZTb/Y6qoRhicIY0zCpwtoP4fNHYf0nzqBEA2+Ek38D2W29ji6hWKIwxjQspSWw9DX44jHYuhSatIYh90HfqyDdmq0j8VWisE4BjTGVKsyHRc/Bl09C/iZo1QVGPgEnXQLJqV5Hl9BEVavfqp7p16+fLliwwOswjDGJIH8LfPUkLJgERXug4+lw6i3QaQhIpIE1GyYRWaiq/SKt81WNwhhjym37zmleWjLVGU+iy4Vw2i3Qtq/XkdU7liiMMf4RKoXVs2DeBGdUuuRGzrWHgTdB81yvo6u3Ej5RiEgX4FagJfChqj7pcUjGmESzfxcseh4WPAu7f4AmbeDM30O/ayCjhdfR1XueJAoRmQgMB7apavew5UOBfwIB4BlVfUBVVwA3iEgS8LQX8RpjEtTmb2De085wpcFC6DAIhtwPJ14AgRSvo/MNr2oUk4BxwPNlC0QkADwODAHygPkiMl1Vl4vIhcBd7j7GmIYsWAzL33Kal/LmQUpjp0+mAdfB0d28js6XvBoze46IdKyweACwRlXXAYjIFGAksFxVpwPTReQd4OV4xmqMSRD5m2HBv2DhJNi3DZofB0MfcJJEo6ZeR+driXSNoi2wMWw+DzhZRAYDFwFpwMzKdhaRscBYgJycnNhFaYyJH1X4/nOneWnFDNAQnHCeU3s49izriylOEilRRLqhWVV1NjC7up1VdYKIbAFGpKam2v1vxtRHxfthxyrYvhK2fwer3odtyyC9qdO9Rr9r7O4lD1SZKETkLFX9yJ3OVdX1YesuUtXXoxhLHtA+bL4dsDmK5RtjEkVh/qEJoex99w+A+xBwUjIc0wMufAy6j4LUxp6G3JBV+WS2iCxS1T4VpyPN1/rAzjWKt8vuehKRZGAVcDawCZgP/FJVl9W2bHsy25gEUVIIW745NBlsX+l0oVEmkAotT4BWnaHViQffmx9rdy7FUV2ezJZKpiPN1yagycBgoKWI5AH3quqzInIz8D7O7bETa5skrK8nYxKIKrzwC/jhC2c+pbGTEDqeDq1OcJPCidC0AwQSqRXcVFTdb0crmY40X2OqGnF8QVWdSRUXrI0x9UjefCdJnH4H9PkPyG5vF5/rqeoSxbEiMh2n9lA2jTufcFeUVHUGMKNfv37XeR2LMQ3eV09BWjYMuh3SMr2OxtRBdYliZNj03yusqzjvOWt6MiZB7P0Rlr8JA8ZakvCBKhOFqn4SPi8iKUB3YJOqbotlYEfCahTGJIiFz0EoCP2v9ToSEwVVNhiKyHgR6eZOZwOLcbrd+FpEIl5nMMY0cKUlsGAiHD8EWhzndTQmCqq7snR62J1HvwZWqepJQF/g/8U0siMgIiNEZMKePXu8DsWYhmvFdCj40Wl2Mr5QXaIoDpseArwJoKo/xiyiOlDVGao6Njvbxr01xjPznoZmHeH4c7yOxERJdYlit4gMF5HewGnAe1D+cFyjWAdnjKlntiyBH+ZC/+vsVlgfqe6up+uBR4FjgNvCahJnA+/EMrAjYXc9GeOx+U87D9b1vsLrSEwUVdmFR31lXXgY44H9u+ChrtDzMhjxT6+jMbV0xF14iMijVa1X1VvqEpgxxke+fhGCB5xmJ+Mr1TU93QAsBabh9OR6xP07GWN8LFQK85+BDqfBMd2r397UK9UlitbAJcBlQBCYCrymqj/FOjBjTD2y+t+w+3sYcp/XkZgYqPK2BFXdqarjVfVM4CqgKbBMRH4Vj+Bqy56jMMYj856CJm3gxAu8jsTEQI3uXxORPsBtwJXAu8DCWAZ1pOw5CmM8sGM1rP0I+l1t40f4VHUXs/8MDAdWAFOAu1U1GI/AjDH1xPxnICkF+o7xOhITI9Vdo/gjsA7o6b7+V0TAuaitqtojtuEZYxJa0V745mXo9gvIPMrraEyMVJcoEm7MCWNMAlk8BYry4eTrvY7ExFB13Yx/H2m5iASA0UDE9dEkIj8HLgCOAh5X1VmxPqYxpgZUnX6d2vSGtn2
9jsbEUHXdjGeJyN0iMk5EzhXHb3Gaoy490oOKyEQR2SYiSyssHyoiK0VkjYjcBaCqb6rqdTh3XV12pMc0xkTZ+jmwY6XTS6zYI1Z+Vt1dTy8AnYFvgWuBWcAoYKSqjqxqx2pMAoaGL3BrKY8Dw4CuwOUi0jVskz+4640xiWDeBGjcArpd5HUkJsaqHTPbHX8CEXkG2AHkqOreuhxUVeeISMcKiwcAa1R1nXu8KcBIEVkBPAC8q6qLKitTRMYCYwFycnLqEp4xpjq7f4CVM+G02yAl3etoTIxVlyhKyiZUtVRE1tc1SVShLbAxbD4POBn4LXAOkC0ix6vq+Eg7q+oEEdkCjEhNTbUGU2NiacFE573f1d7GYeKiukTRU0Ty3WkBGrnzZbfHZkUxlkiNnKqqj+J0dV4tGzPbmDgoKXTGxO58PjRt73U0Jg6qu+spEK9AcGoQ4X917XA6IqwxG4/CmDhY9joc2GVDnTYgiTQE1Xygk4jkikgqzu230z2OyRgTThW+egpanQi5Z3gdjYkTTxKFiEwG5gKdRSRPRK5xuwa5GXgfp8uQaaq6rDblWl9PxsRY3gLY8g0MuM5uiW1AqrtGEROqenkly2cCM4+0XGt6MibG5k2AtCzoMdrrSEwcJVLTU51ZjcKYGCrYBsvegF6/hLRMr6MxceSrRGHjURgTQwsnQajEhjptgHyVKKxGYUyMlJY4z04cdza0tKbdhsZXicJqFMbEyHdvw94tdktsA+WrRGE1CmNiZN7T0LQDdBridSTGA75KFMaYGPhxKXz/OfS/FpLi+QyuSRSWKIwxVZs3AZIbQe8rvY7EeMRXicKuURgTRbvWw1s3w9cvQo9LoHFzryMyHvFVorBrFMZEwa518OZN8FhfWDLNaXIacr/XURkPefJktjEmAe1cC5/+wxkHO5DidNNx2m2Q1drryIzHLFEY09DtXAtz/g5LpjoJ4uTr4bRbockxXkdmEoSvEoX19WRMLexYA3MehG+nQSANTr7BTRBHex2ZSTC+ShQ2cJExNbBjtZsgXnESxCk3wqm3WIIwlfJVojDGVGH7KidBLH0VktNh4E1Ogsg8yuvITIKzRGGMn4VCzsNyC/8FS1+HlEYw8GY3QbTyOjpTT1iiMMaPti53Lk5/+wrkb4LUTDjtFidBZLT0OjpTzyR8ohCRY4HfA9mqOsrreIxJWPmb4dtXnWcftn4LEoDjz4Eh90Hn8yG1sdcRmnrKk0QhIhOB4cA2Ve0etnwo8E8gADyjqg+o6jrgGhF51YtYjUlohfmwYoZTe1g/B1Bo2w+GPQjdfmHNSyYqvKpRTALGAc+XLRCRAPA4MATIA+aLyHRVXe5JhMYkqmAxrP3QSQ4r34VgITTLhZ/9DnpcCi2O8zpC4zNejZk9R0Q6Vlg8AFjj1iAQkSnASKBGiUJExgJjAXJycqIWqzEJQRXy5jvJYenrcGAXNGoOvX8FPS6Ddv1AxOsojU8l0jWKtsDGsPk84GQRaQH8BegtIner6l8j7ayqE4AJAP369dNYB2tM3OzbCW/dBKvedW5r7Xy+kxyOP9t5ktqYGEukRBHp65Cq6k7ghhoVYE9mG79Z+zG8cYNTgxhyP/S9CtKzvI7KNDCJlCjygPZh8+2AzR7FYoy3gsXw0f3wxaPQsjNc+Socc5LXUZkGKpG6GZ8PdBKRXBFJBUYD02tTgHUzbnxhxxp4doiTJPpdDWNnW5IwnvIkUYjIZGAu0FlE8kTkGlUNAjcD7wMrgGmquqyW5drARab+UnUGCXrqDNj9PVz2Igx/2J5/MJ7z6q6nyytZPhOYGedwjPHegd3w9u2w7HXoeDr84inIbut1VMYAidX0VGfW9GTqpe/nwvhBsGI6nH0v/MdbliRMQkmki9l1Znc9mXqlNOj05jrnb9A0B66eBe36eh2VMYexGoUxXtj9A0y6AD55wHkm4vpPLUmYhOWrGoUx9cLS12DG7aAhuOgZ6HGJ1xEZUyVfJQprejI
JqzQIBVvh4/+Fb16Edv3hoqehea7XkRlTLV8lChsK1cSVKhTuhoJtThIof996+LJ9OwAFBM640+nAz7rfMPWErxKFMTEVLIYP/ww/zD2YAEqLD98ukAaZRztDjDbt4NQeyubb9YPWPeMfuzF14KtEYU1PJmaK9sLUX8G6jyH3DOdZh8yjDiaAzKMPTqdnW0+uxld8lSis6cnERMF2eGkU/PgtjHwCel/hdUTGxJWvEoUxUbdrPbx4EeRvgcsnwwnneR2RMXFnicKYymxZDC+OglAJjJkO7Qd4HZExnvDVA3fGRM26T+BfF0AgFa5+35KEadB8lSis91gTFUtfhxcvhqbt4dp/Q6vOXkdkjKd8lSisCw9TZ189Ba9e7dzS+uuZkNXG64iM8ZxdozAGnIfnPrwPPnsIThwOFz8DKY28jsqYhGCJwpjSIMy41elao+9VcMFDkBTwOipjEoYlCtOwFe+HV38Nq96Dn90Fg++yh+WMqSDhE4WIZABPAMXAbFV9yeOQjF/s3wUvXwabFji1iP7XeB2RMQnJqzGzJ4rINhFZWmH5UBFZKSJrROQud/FFwKuqeh1wYdyDNf60eyNMHOo8K3HJc5YkjKmCVzWKScA44PmyBSISAB4HhgB5wHwRmQ60A751NyuNb5gmIeVvgcWTnemURpCc7r6nQXIjSEl33pPTKqxPd1471zi3vxbvg1+9AR1P8/bnMSbBeZIoVHWOiHSssHgAsEZV1wGIyBRgJE7SaAd8QxU1IBEZC4wFyMnJiX7QfqMKwSIo2Q+hUkCdgXQ05KzTUNiysPeKy5q2h7Qm8Yt5yVR49/9BYR2flck8Bq5+F47uFp3YjPGxRLpG0RbYGDafB5wMPAqME5ELgBmV7ayqE0RkCzAiNTU1sceULBvHYM8m2JMH+XlOx3PoEZYXgmAhlByAkkLnwz/ovpccOPQVDJs+0uOFS28KZ/0B+v4aAjH8c9q7Fd6+DVbOhPYnw8jHIbud+zO5P3uwyP35Cp33YFHk9arQ83InyRljqpVIiSLSrSaqqvuAX8c7mDopOQD5m2HPxkOTwZ68g/Ml+6J4QDnYvJLS2Gl6KZtOTodGzd1ljcOaYtzplEaQlOzc6SzOvmQAABQgSURBVCNJTlmSdHA+4jJxlmkIFj0PM++ABRNh6ANw7M+i+HPhfKh/+wrMvNP5wD/3L3DKbw7evmrPOhgTc4mUKPKA8K947YDNtSmgrt2Mf/qPX9Ko8EcABEXcb9xOBgubVxAp+zaulE02SSrkaN1Bo5KfDi884yjnG3CrE+C4s5zp7LaQ3R6y2jrjGNTHe/e7XwzfvQ3v/zc8fyF0uRDO/R9o1qHuZRdsg7dvd8pv1x9+/iS07FT3co0xtZJIiWI+0ElEcoFNwGjgl7UpoK4DFzVN2keGFHAwNYCWfXvGSRVly8tSSfj8j8EsPi/KYbO24EDj1hzT/jg6n9CF3t27kpWZeUQxJTwR6DICjj8HvhjnPNm8ehacegsMug1SM2pfpiosfc2pRRTvgyH3wcCb62ciNcYHRDUK7dS1PajIZGAw0BLYCtyrqs+KyPnAI0AAmKiqfzmS8vv166cLFiyIVri1snHXfuas3s6cVdv5fM1OCoqCBJKEPjlNOaNTK844oRUntc0mKcmnD3Xt2QT/vgeWvurUlIbc59Q6avoQW8F2eOd2WDED2vZ1ahHWKZ8xMSciC1W1X8R1XiSKWAmrUVy3evVqr8OhpDTE1z/s5pNV25izagffbnLu1Gmekcqg41vysxNacfoJLTmqSbrHkcbA93Odu5N+XAI5p8KwB6ofK3rp6871jqK9cOZ/w8DfxvYCuTGmXINJFGW8rFFUZUdBEZ+t3sGcVduZs3o7OwqKAejSOosBHZuRlnKwaeWQ798SPnlwpuxLenpygCtOyaFlZloMoz8CoVL4+gWns739u6DvGDjrj5DR8tDt9u2Ad/4Llr8JbXo7tYijungTszENVINJFIlWo6hKKKQ
s35LPJ6ucZqqlm/YQcn8VGnbbaviv55DfVNhMcWmIDi0a8/zVA+jQ4giuCcTagd3wyf/BvAnONYvBd0P/ayGQAsvfgrf/03ku4sy74dRbrRZhjAcaTKIok6g1ilhZ9MNPXDNpPoEk4V9XDeCkdgk6Hsf2lfDeXbD2I2h1onMH04oZ0LqXU4s4uqvXERrTYFWVKHw1cFFDHeGuT04zXv3NqaQlB7hswlw+WbXd65Aia9UZrnwdRr/sPBOx8j048w9w7QeWJIxJYFaj8JGt+YVc9a/5rN66l7+N6sFFfdp5HVLlgsVw4CdocrTXkRhjaEA1iobu6Kx0pl5/CgNym/Of0xYz/pO1JOwXgeRUSxLG1BOWKHwmKz2Ff/26P8N7tOaBd7/jvreXEwolaLIwxtQLvrq9pK5PZvtFWnKAR0f35qgm6Uz8fD3b9hbx0KU9SUu2J5uNMbXnqxqFqs5Q1bHZ2Ql6108cJSUJfxzehf8+/0TeWbKFMRPnkV9Y4nVYxph6yFeJwhxKRBh7xnE8clkvFmz4iUvHz2VrfqHXYRlj6hlLFA3Az3u3ZeJV/dm4az8XPfEFa7YVeB2SMaYe8VWiaKjPUdTEGSe0Yur1AykKljJq/Bcs/D5CV+jGGBOBrxKFXaOoWve22bz2m1Np2iiFK575kg+Wb/U6JGNMPeCrRGGq16FFBq/+5lQ6H92EsS8sYMq8H7wOyRiT4OzJ7AZqX1GQG19axCerttO2aSOSA0JAhKSksPckCIgQSHJeSRGm05KTyEhLJjMtmYy0wMHp1ORDlmemJdM4LZnMVGc+OWDfUYxJJFU9me2r5yhMzWWkJfPMmH48OXstG3buIxRSgiElpEppSCkNUT4dUiVYqpSqUhwMUapavn1RMMS+oiAFRUH2FQWp6bN9aclJNElPpkl6Clll742SaZLmvh+yPMXdNpms9BSy0lPITE+mbOwnqemgSMaYI5LwiUJEjgV+D2Sr6iiv4/GTlEASt5wdvTGoVZ3EUZY0nPfSQxLJvmJnfl9RkL1FQfYWBsk/UMLewhJ+zC9kb2EJ+QeCHCgprXM8ZflDDlnmzAWShBS3dpQSSCI5ICQnlb2HTQeS3Hlnu7IaVXi5B/OURDzuwXlBBPfljCwiIiQJ5dMiznZJZdshJCW5e8uhZUqE41VMmhW3E8KOL2ExhW0rFeJJEuH4ozLp37E5rZok2JgnJi5imihEZCIwHNimqt3Dlg8F/okz5OkzqvpAZWWo6jrgGhF5NZaxmroTEdJTAqSnBOo8iFJJaYi9hcHyxLG3sIT8whLy3cSyr6gURasYr0MPW6Zh430EQ0ppqfNeUhqiNKSUlCrBUIhgSAmWhgi664OhECWlyoGSUoKlIUJ6cMyQ8jLLyy6bP7RqVTYbUmdPVSf2sumyMlVxX4cuc2pqWulx9JBYwn72sO21/NhOuSF1VpQft8I2kRzbKoMBHZszINd5tWvWOPKGxldiXaOYBIwDni9bICIB4HFgCJAHzBeR6ThJ468V9r9aVbfFOEaTgFICSTTPSKV5RqrXoTRYZcmsJBRi+eZ85q3fxfwNu5j57RamzN8IQNumjRiQ25z+bvI4rlWGNQX6UMwvZotIR+DtshqFiAwE/qSq57nzdwOoasUkUbGcV6tqehKRscBYgJycnL7ff/99VOI3xhwqFFJWbt3LvPW7mLd+F1+t38WOgiIAWmSklieNAbnN6dI6q7ypziS2RLuY3RbYGDafB5xc2cYi0gL4C9BbRO6uLKGo6gRgAjh3PUUvXGNMuKQkoUvrLLq0zmLMqR1RVTbs3M+89Tv5yk0e7y37EYAmacnktGhMciCJVPc6UEpyEinuNZ/w6eSAuyxQNu/sk5qcRFpygNTkJFIDSc67+0pzX6mBwGHLUwJJhFduIl2rOny5+zO6zajG4UWiiPT1otIPdlXdCdxQo4Kt91hj4k5EyG2ZQW7LDC7rnwPA5t0HmL/BqW1syy+kpNS5FlRSGuLAgVJ
K3GtAJaUhSkIhSoIH15ddKyop9fb73jFZ6XRrk0W3Nll0bZNNtzZZtGvWqEE2rXmRKPKA9mHz7YDNHsRhjImRNk0bMbJXW0b2anvEZagqxW7iKCoppbg0RHHQeRUFQxSXhigqCR2yvLi09OD6YKiScsOmw76jhi8vKQ2xelsByzfn8/HKbeW3fWc3SqFr67LkkUW3Ntkc1yrD988FeZEo5gOdRCQX2ASMBn7pQRzGmAQmIqQlB0hLhsw07+7kP1Bcync/5rNss/NavnkPL3z5PUVuIkpLTuLEY5qU1zq6tckip3ljmjZOjdv1GVUl/0CQXfuLyW2ZEfXyY3oxW0QmA4OBlsBW4F5VfVZEzgcewbnTaaKq/iWax7Uns40xsRQsDbFuxz6Wbd7Dsk1lSWQP+YXB8m2ShPI791pkpNEiM5WWmWm0yEilRWbZ/MF1mWnJ5c1aqsr+4lJ2FhSzc19R+fuOgmJ27StmZ0ERO/cVu/NF7NpXXN5Ut+p/hpGaXPsaTlUXs33VhUfYNYrrVq9e7XU4xpgGRFXJ++kAy7fk8+OeQnYWFLGj7EO9oNj9YC9ib1gyCZcaSKJFZipJIuzcV0RhSeSms4zUQHmiaRGWhFpkptEyM5Vh3VtboqgJq1EYYxJVUbDUrRU4yaMskezYV8SOvcWoKi2bpLk1EbcWknmwZtIoNTZ3YyXa7bExY3c9GWMSXVpygNbZjWid3cjrUGrMV5fqbTwKY4yJPl8lCmOMMdHnq0RhQ6EaY0z0+SpRWNOTMcZEn68ShTHGmOjzVaKwpidjjIk+XyUKa3oyxpjo81WiMMYYE32+fDJbRLYDRzpyUUtgRxTDiRaLq3YsrtqxuGrHj3F1UNVWkVb4MlHUhYgsqOwxdi9ZXLVjcdWOxVU7DS0ua3oyxhhTJUsUxhhjqmSJ4nATvA6gEhZX7VhctWNx1U6DisuuURhjjKmS1SiMMcZUyRKFMcaYKjXYRCEiQ0VkpYisEZG7IqxPE5Gp7vqvRKRjHGJqLyIfi8gKEVkmIrdG2GawiOwRkW/c1z2xjss97gYR+dY95mHDB4rjUfd8LRGRPnGIqXPYefhGRPJF5LYK28TlfInIRBHZJiJLw5Y1F5F/i8hq971ZJfuOcbdZLSJj4hDXgyLynft7ekNEmlayb5W/8xjE9ScR2RT2uzq/kn2r/N+NQVxTw2LaICLfVLJvLM9XxM+GuP2NqWqDewEBYC1wLJAKLAa6VtjmRmC8Oz0amBqHuFoDfdzpJsCqCHENBt724JxtAFpWsf584F1AgFOArzz4nf6I89BQ3M8XcAbQB1gatuxvwF3u9F3A/0XYrzmwzn1v5k43i3Fc5wLJ7vT/RYqrJr/zGMT1J+COGvyeq/zfjXZcFdb/A7jHg/MV8bMhXn9jDbVGMQBYo6rrVLUYmAKMrLDNSOA5d/pV4GwRkVgGpapbVHWRO70XWAG0jeUxo2gk8Lw6vgSaikjrOB7/bGCtqh7pE/l1oqpzgF0VFof/DT0H/DzCrucB/1bVXar6E/BvYGgs41LVWaoadGe/BNpF63h1iauGavK/G5O43P//S4HJ0TpeTVXx2RCXv7GGmijaAhvD5vM4/AO5fBv3n2oP0CIu0QFuU1dv4KsIqweKyGIReVdEusUpJAVmichCERkbYX1Nzmksjabyf2AvzhfA0aq6BZx/dOCoCNt4fd6uxqkJRlLd7zwWbnabxCZW0ozi5fk6HdiqqqsrWR+X81XhsyEuf2MNNVFEqhlUvE+4JtvEhIhkAq8Bt6lqfoXVi3CaV3oCjwFvxiMm4DRV7QMMA24SkTMqrPfyfKUCFwKvRFjt1fmqKS/P2++BIPBSJZtU9zuPtieB44BewBacZp6KPDtfwOVUXZuI+fmq5rOh0t0iLKvVOWuoiSIPaB823w7YXNk2IpIMZHNkVeVaEZEUnD+El1T19YrrVTVfVQvc6ZlAioi0jHVcqrrZfd8GvIHTBBCuJuc0VoYBi1R1a8U
VXp0v19ay5jf3fVuEbTw5b+4FzeHAFeo2ZFdUg995VKnqVlUtVdUQ8HQlx/PqfCUDFwFTK9sm1uerks+GuPyNNdREMR/oJCK57rfR0cD0CttMB8ruDhgFfFTZP1S0uG2gzwIrVPWhSrY5puxaiYgMwPkd7oxxXBki0qRsGudi6NIKm00H/kMcpwB7yqrEcVDpNz0vzleY8L+hMcBbEbZ5HzhXRJq5TS3nustiRkSGAr8DLlTV/ZVsU5PfebTjCr+m9YtKjleT/91YOAf4TlXzIq2M9fmq4rMhPn9jsbhCXx9eOHfprMK5g+L37rL7cP55ANJxmjLWAPOAY+MQ0yCcKuES4Bv3dT5wA3CDu83NwDKcuz2+BE6NQ1zHusdb7B677HyFxyXA4+75/BboF6ffY2OcD/7ssGVxP184iWoLUILzDe4anGtaHwKr3ffm7rb9gGfC9r3a/TtbA/w6DnGtwWmzLvsbK7u7rw0ws6rfeYzjesH921mC8wHYumJc7vxh/7uxjMtdPqnsbyps23ier8o+G+LyN2ZdeBhjjKlSQ216MsYYU0OWKIwxxlTJEoUxxpgqWaIwxhhTJUsUxhhjqmSJwtQbbu+id3gdR1VE5ES399CvReS4OpY1SURGRSu2Wh57plTSq2zYNhvi+PCi8ZAlCtPgiEgghsX/HHhLVXur6toYHiemVPV8Vd3tdRwmMViiMAlNRH7vjj3wAdA5bPlxIvKe2wHbpyJyYtjyL0VkvojcJyIF7vLBbn/+L+M81IWIXCki89wawFNlCUREzhWRuSKySERecfvXqRhXL/c4ZWM6NBNn/ITbgGtF5OMI+xSIyD/ccj8UkVaVlVVhv7NF5I2w+SEi8npYmX8Rp9PDL0XkaHd5B/cYS9z3HHf5JBF50j0X60TkZ+J0wLdCRCaFHaO8tiAib7rneZnEr3NAk0ii+fSgvewVzRfQF+dDvTGQhfNU6R3uug+BTu70yThdrAC8DVzuTt8AFLjTg4F9QK473wWYAaS4808A/wG0BOYAGe7y3xFh/AGcJ2R/5k7fBzziTv+JSsZUwHmy9gp3+h5gXDVlTcLpPkaA74BW7vKXgRFhZZZN/w34gzs9AxjjTl8NvBlW5hS3zJFAPnASzpfGhUAvd7sNuGMrcPBp30Y43VK0qLiNvfz9shqFSWSnA2+o6n51esqcDuU9aJ4KvCLOaGNP4QzsAjCQg73IvlyhvHmqut6dPhsnEc13yzgbpxuGU3AGhPncXT4G6BBeiIhkA01V9RN30XM4A95UJ8TBTuVeBAbVpCx1PpVfAK50rxsM5GDX4MU4yRGcD/qO7vTAsJ//BZwuIMrMcMv8Fqfb7G/V6YhvWdj+4W4RkbIuUNoDnWrwsxofSfY6AGOqEamPmSRgt6r2qmVZ+8KmBXhOVe8O30BERuAM8nJ5Lcs+ErXpP+dfOLWEQuAVPTjwUIn7oQ9QSuX/0+HHKnLfQ2HTZfOH7C8ig3E6xBuoqvtFZDZOP2imAbEahUlkc4BfiEgjt2fOEeB0HQ6sF5FLoHy87p7uPl8CF7vTo6so+0NglIgc5ZbRXEQ6uPufJiLHu8sbi8gJ4Tuq6h7gJxE53V30K+ATqpeE05QE8Evgs5qWpU4X1puBP+A0H1XnCw7+/FcAn9Vgn0iygZ/cJHEiTo3LNDBWozAJS1UXichUnJ4yvwc+DVt9BfCkiPwBSMFpd1+MczH5RRH5L+AdnJEJI5W93N13logk4fQWepOqfikiVwGTRSTN3fwPOL2VhhsDjBeRxjhjEP+6Bj/SPqCbiCx047qslmW9hHOdYnkNjnULMFFE7gS21zC+SN4DbhCRJcBKnERqGhjrPdb4ivthe0BVVURG41zYjtqYynUhIgWqetgdVLXYfxzwtao+G8WwjKmW1SiM3/QFxomIALtx7vip99xayD7gv7yOxTQ8VqMwxhhTJbuYbYwxpkqWKIwxxlTJEoUxxpgqWaIwxhhTJUsUxhhjqvT/AdI
KGDLO18n+AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "## EDIT THIS CELL\n", "K_max = 20\n", "rmse_train = np.zeros((K_max+1,))\n", "rmse_test = np.zeros((K_max+1,))\n", "\n", "for k in range(K_max+1):\n", "\n", "    # feature matrix\n", "    Phi = poly_features(X, k)\n", "\n", "    # maximum likelihood estimate\n", "    theta_ml = nonlinear_features_maximum_likelihood(Phi, y)\n", "\n", "    # predict y-values of training set\n", "    ypred_train = Phi @ theta_ml\n", "\n", "    # RMSE on training set\n", "    rmse_train[k] = RMSE(y, ypred_train)\n", "\n", "    # feature matrix for test inputs\n", "    Phi_test = poly_features(Xtest, k)\n", "\n", "    # predict y-values at the test inputs\n", "    ypred_test = Phi_test @ theta_ml\n", "\n", "    # RMSE on test set\n", "    rmse_test[k] = RMSE(ytest, ypred_test)\n", "\n", "plt.figure()\n", "plt.semilogy(rmse_train) # training RMSE on a logarithmic y-axis\n", "plt.semilogy(rmse_test) # test RMSE on a logarithmic y-axis\n", "plt.xlabel(\"degree of polynomial\")\n", "plt.ylabel(\"RMSE\")\n", "plt.legend([\"training set\", \"test set\"]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Questions:\n", "1. What do you observe now?\n", "2. Why does the RMSE for the test set not always go down?\n", "3. Which polynomial degree would you choose now?\n", "4. Plot the fit for the \"best\" polynomial degree." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# WRITE THE PLOTTING CODE HERE\n", "plt.figure()\n", "plt.plot(X, y, '+')\n", "k = 5 # chosen polynomial degree\n", "# feature matrix\n", "Phi = poly_features(X, k)\n", "\n", "# maximum likelihood estimate\n", "theta_ml = nonlinear_features_maximum_likelihood(Phi, y)\n", "\n", "# feature matrix for test inputs\n", "Phi_test = poly_features(Xtest, k)\n", "\n", "ypred_test = Phi_test @ theta_ml\n", "\n", "plt.plot(Xtest, ypred_test)\n", "plt.xlabel(\"$x$\")\n", "plt.ylabel(\"$y$\")\n", "plt.legend([\"data\", \"maximum likelihood fit\"]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Question:\n", "If you did not have a designated test set, what could you do to estimate the generalization error (purely using the training set)?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Maximum A Posteriori Estimation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are still considering the model\n", "$$\n", "y = \\boldsymbol\\phi(\\boldsymbol x)^T\\boldsymbol\\theta + \\epsilon\\,,\\quad \\epsilon\\sim\\mathcal N(0,\\sigma^2)\\,.\n", "$$\n", "We assume that the noise variance $\\sigma^2$ is known." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Instead of maximizing the likelihood, we can look for the maximum of the posterior distribution over the parameters $\\boldsymbol\\theta$, which is given as\n", "$$\n", "p(\\boldsymbol\\theta|\\mathcal X, \\mathcal Y) = \\frac{\\overbrace{p(\\mathcal Y|\\mathcal X, \\boldsymbol\\theta)}^{\\text{likelihood}}\\overbrace{p(\\boldsymbol\\theta)}^{\\text{prior}}}{\\underbrace{p(\\mathcal Y|\\mathcal X)}_{\\text{evidence}}}\n", "$$\n", "The purpose of the parameter prior $p(\\boldsymbol\\theta)$ is to discourage the parameters from attaining extreme values, which are often a sign that the model overfits. The prior allows us to specify a \"reasonable\" range of parameter values. 
Typically, we choose a Gaussian prior $\\mathcal N(\\boldsymbol 0, \\alpha^2\\boldsymbol I)$, centered at $\\boldsymbol 0$ with variance $\\alpha^2$ along each parameter dimension." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The MAP estimate of the parameters is\n", "$$\n", "\\boldsymbol\\theta^{\\text{MAP}} = (\\boldsymbol\\Phi^T\\boldsymbol\\Phi + \\frac{\\sigma^2}{\\alpha^2}\\boldsymbol I)^{-1}\\boldsymbol\\Phi^T\\boldsymbol y\n", "$$\n", "where $\\sigma^2$ is the variance of the noise." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "## EDIT THIS FUNCTION\n", "def map_estimate_poly(Phi, y, sigma, alpha):\n", " # Phi: training inputs, Size of N x D\n", " # y: training targets, Size of D x 1\n", " # sigma: standard deviation of the noise \n", " # alpha: standard deviation of the prior on the parameters\n", " # returns: MAP estimate theta_map, Size of D x 1\n", " \n", " D = Phi.shape[1] \n", " \n", " # SOLUTION\n", " PP = Phi.T @ Phi + (sigma/alpha)**2 * np.eye(D)\n", " theta_map = scipy.linalg.solve(PP, Phi.T @ y)\n", " \n", " return theta_map" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# define the function we wish to estimate later\n", "def g(x, sigma):\n", " p = np.hstack([x**0, x**1, np.sin(x)])\n", " w = np.array([-1.0, 0.1, 1.0]).reshape(-1,1)\n", " return p @ w + sigma*np.random.normal(size=x.shape) " ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/png": 
"\n", "text/plain": [ "
6H/jLJJcBDL5/acQ5zg6+nwZ+D/jODX7e0aqar6r5HTt2bPn7kbYTB5xqM1rfknoAuH3w+nbg/vUNkrwmySsHry8BrgP+pLcKpW3MAafajNaBcRewL8lTwL7BNknmk3xs0OZNwFKSzwC/C9xVVQaGJPWs6TiMcetzHIY0LdYPOH2eA04FG4/DMDCkGeaAU603yQP3JElTwsCQZpgDTrUZBoY0w+yz0GYYGJKkTgwMSVInBoYkqRMDQ5LUiYEhSerEwJAkdWJgSJI6MTCkDpwGXDIwpE6cBlwyMCRJHY19iVZpWq2fBnxu4RjgNOAtHFlc9nc+AQwMaYT37LvqhT9STgPe1ocfecrAmADekpIkddI0MJK8I8kTSb6eZOiCHYN2+5M8meRUkoU+a5TAacBbOLK4zNzCsRduBT7/2ifW2mm64l6SNwFfBz4KvLeqXrI8XpKLgGVW1/w+AzwG3NZlXW9X3JO2B28J9mejFfea9mFU1UmAJBs12wOcqqrTg7b3AgeA8waGJGnrTEMfxuXAM2u2zwz2DZXkYJKlJEsrKytjL07S+HlLcDKM/QojyXHg0iGHDlfV/V1OMWTfyPtoVXUUOAqrt6Q6FSlpovmE1GQYe2BU1fUXeIozwK412zuBsxd4TknSJk3DLanHgN1JrkzyCuBW4IHGNUnSzGn9WO0PJjkDfC9wLMlDg/2vT/IgQFWdA+4AHgJOAp+oqida1SxJs6ppYFTVfVW1s6peWVWvq6q3Dfafraob17R7sKquqqp/WlUfbFexJE2+cY1VmYZbUpKkTRjX7MoGhiSpEycflKQptXYW3z5mVzYwJGlKrZ3Ft4/Zlb0lJUnqxCuMbcDFZaTZ0eXW07imUmk6W+24zcpstc7kKc2mcfy/v9Fstd6SkiR14i2pKeV605L6nsXXW1LbgLekJG0Vb0lJki6YgbENuLiMpD6cNzCSHE/yHX0Uo5fHPgtJfehyhfHTwJEkv5rksnEXJEmaTOcNjKr6dFV9P/BJ4LeTvD/JN42/NEmaDOOaLnzadOrDSBLgSeCXgJ8Ankryw+MsTJImxbimC582Xfow/hD4InAEuBx4J/AWYE+So+MsTtL4+KlZm9Vl4N67gSfqpQM2fiLJyQv54UneAfwc8CZgT1UNHTSR5Gngq8DXgHOjnhGW1N3amU71Ug6OfanzBkZVfW6Dwxc6WuxzwL8CPtqh7Vur6ssX+PMkqZM+pgufNhc0NUhVnb7Af38SYLWLRNK4+alZF2Ja5pIq4OEkBXy0qkb2nSQ5CBwEuOKKK3oqT5oOfmp+eRwcu2rsgZHkOHDpkEOHq+r+jqe5rqrOJnktsJjk81X1B8MaDsLkKKzOJfWyipakNbz6WjX2wKiq67fgHGcH37+U5D5gDzA0MCR146dmbdbEzyWV5FVJXv38a+AGVjvLJV0APzVrs5oGRpIfTHIG+F7gWJKHBvtfn+TBQbPXAX+Y5DPA/wSOVdVvj7s2n1GXpBdr2uldVfcB9w3Zfxa4cfD6NND75Ic+oy5JLzbxt6QkSZNhWh6r7YXPqEvSaC7ROoLPqEuaRS7RKkm6YAbGCD6jLkkvZmCMYJ+FJL2YgSFNKMcCadIYGNKEcpU3TRoDQy/wE62kjTgOQy9wdHt7jgXSJDMwZsCRxWX/2EwJ16vQJDMwZsBGVw5+opXUlYEx4/xEO7kcC6RJY2BsU145TD//O2nSGBjb1Mu5cvATraSN+FitXuAnWkkbab3i3oeSfD7JZ5Pcl+RbR7Tbn+TJJKeSLPRd57TzykHSVmh9hbEIXFNV1wLLwPvWN0hyEfAR4O3A1cBtSa7utcop55WDpK3QNDCq6uGqOjfYfBTYOaTZHuBUVZ2uqueAe4EDfdUoSVrV+gpjrXcBnxqy/3LgmTXbZwb7hkpyMMlSkqWVlZUtLlGSZtfYn5JKchy4dMihw1V1/6DNYeAccM+wUwzZN3KZwKo6ChyF1RX3Nl2wJGmosQdGVV2/0fEktwM
3A3tr+HqxZ4Bda7Z3Ame3rkJJUhetn5LaD/wMcEtV/e2IZo8Bu5NcmeQVwK3AA33VKEla1boP4xeBVwOLSR5P8ssASV6f5EGAQaf4HcBDwEngE1X1RKuCtfWcVl2aDk1HelfVG0fsPwvcuGb7QeDBvupSv5xWXZoOra8wJElTwrmk1ISTI0rTJ8MfTNoe5ufna2lpqXUZOg+nVZcmR5ITVTU/7Ji3pCRJnRgYas7JEaXpYGCoOfsspOlgYEiSOjEwJEmdGBiSpE4MDElSJwaGJKkTA0OS1ImBIUnqxMCQJHViYEiSOjEwJEmdGBiSpE6aroeR5EPADwDPAX8K/EhV/e8h7Z4Gvgp8DTg3aupdSdL4tL7CWASuqaprgWXgfRu0fWtVvdmwkKQ2mgZGVT1cVecGm48CO1vWI0karfUVxlrvAj414lgBDyc5keTgRidJcjDJUpKllZWVLS9SkjZyZHG5dQljM/bASHI8yeeGfB1Y0+YwcA64Z8Rprquq7wLeDvx4kn8x6udV1dGqmq+q+R07dmzpe5Gk81m7Vv12M/ZO76q6fqPjSW4Hbgb21ogFxqvq7OD7l5LcB+wB/mCra5Ukjdb6Kan9wM8A31dVfzuizauAb6iqrw5e3wDc2WOZkrShI4vLL7qymFs4BqwuP7ydVpTMiA/1/fzw5BTwSuCvBrserap3J3k98LGqujHJG4D7BscvBn6jqj7Y5fzz8/O1tLS05XVL0ihzC8d4+q6bWpfxsiU5Mepp1KZXGFX1xhH7zwI3Dl6fBr6jz7okSS81SU9JSdLUO7R3d+sSxsbAkKQttJ36LNYzMCRJnRgYkqRODAxJUicGhiSpEwNDktSJgSFJ6sTAkCR1YmBIkjoxMCRJnRgYkqRODAxJUicGhiSpEwNDktSJgSFJ6qR5YCT5QJLPJnk8ycOD1faGtbs9yVODr9v7rlOSZl3zwAA+VFXXVtWbgU8CP7u+QZJvA94PfA+wB3h/ktf0W6YkzbbmgVFVf7Nm81XAsEXG3wYsVtWzVfUVYBHY30d9kqRVTdf0fl6SDwL/Hvhr4K1DmlwOPLNm+8xgnySpJ71cYSQ5nuRzQ74OAFTV4araBdwD3DHsFEP2DbsSIcnBJEtJllZWVrbuTUjSjOslMKrq+qq6ZsjX/eua/gbwQ0NOcQbYtWZ7J3B2xM86WlXzVTW/Y8eOrXkDmkpHFpdblyBtK837MJLsXrN5C/D5Ic0eAm5I8ppBZ/cNg33SSB9+5KnWJUjbyiT0YdyV5J8BXwe+ALwbIMk88O6q+rGqejbJB4DHBv/mzqp6tk25kjSbUjW0K2BbmJ+fr6WlpdZlqEdHFpeHXlkc2rub9+y7qkFF0nRJcqKq5oceMzC0Xc0tHOPpu25qXYY0VTYKjOZ9GJKk6WBgaNs6tHf3+RtJ6szA0LZln4W0tQwMSVInBoYkqRMDQ5LUiYGh3jllhzSdDAz1zik7pOlkYEiSOpmEuaQ0A9ZP2TG3cAxwyg5pmjg1iHrnlB3S5HJqEEnSBTMw1Dun7JCmk4Gh3tlnIU0nA0OS1ImBIUnqxMCQJHViYEiSOjEwJEmdbOuBe0lWgC+0rmMDlwBfbl1EQ77/2X7/4O9gEt//P6mqHcMObOvAmHRJlkaNqJwFvv/Zfv/g72Da3r+3pCRJnRgYkqRODIy2jrYuoDHfv2b9dzBV798+DElSJ15hSJI6MTAkSZ0YGBMiyXuTVJJLWtfSpyQfSvL5JJ9Ncl+Sb21dUx+S7E/yZJJTSRZa19OnJLuS/G6Sk0meSHKodU0tJLkoyf9K8snWtXRlYEyAJLuAfcCfta6lgUXgmqq6FlgG3te4nrFLchHwEeDtwNXAbUmubltVr84B/6Gq3gT8c+DHZ+z9P+8QcLJ1EZthYEyGI8BPAzP3BEJVPVxV5wabjwI7W9bTkz3Aqao6XVXPAfcCBxrX1Juq+vOq+vTg9VdZ/aN5eduq+pVkJ3AT8LHWtWyGgdFYkluAL1b
VZ1rXMgHeBXyqdRE9uBx4Zs32GWbsD+bzkswB3wn8UdtKevefWf2Q+PXWhWzGxa0LmAVJjgOXDjl0GPiPwA39VtSvjd5/Vd0/aHOY1VsV9/RZWyMZsm/mri6TfAvw34Gfqqq/aV1PX5LcDHypqk4keUvrejbDwOhBVV0/bH+SbweuBD6TBFZvx3w6yZ6q+oseSxyrUe//eUluB24G9tZsDAw6A+xas70TONuoliaS/ANWw+Keqvqt1vX07DrgliQ3At8I/MMkv15V/65xXeflwL0JkuRpYL6qJm32yrFJsh/4T8D3VdVK63r6kORiVjv49wJfBB4D/m1VPdG0sJ5k9dPR3cCzVfVTretpaXCF8d6qurl1LV3Yh6HWfhF4NbCY5PEkv9y6oHEbdPLfATzEaofvJ2YlLAauA34Y+P7Bf/PHB5+2NeG8wpAkdeIVhiSpEwNDktSJgSFJ6sTAkCR1YmBIkjoxMCRJnRgYkqRODAypR4N1IPYNXv98kv/SuiapK+eSkvr1fuDOJK9ldZbWWxrXI3XmSG+pZ0l+H/gW4C2D9SCkqeAtKalHgxmKLwP+3rDQtDEwpJ4kuYzV9T4OAP8nydsalyRtioEh9SDJNwO/xepa1ieBDwA/17QoaZPsw5AkdeIVhiSpEwNDktSJgSFJ6sTAkCR1YmBIkjoxMCRJnRgYkqRO/h9nQU2Hc524lwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Generate some data\n", "sigma = 1.0 # noise standard deviation\n", "alpha = 1.0 # standard deviation of the parameter prior\n", "N = 20\n", "\n", "np.random.seed(42)\n", "\n", "X = (np.random.rand(N)*10.0 - 5.0).reshape(-1,1)\n", "y = g(X, sigma) # training targets\n", "\n", "plt.figure()\n", "plt.plot(X, y, '+')\n", "plt.xlabel(\"$x$\")\n", "plt.ylabel(\"$y$\");" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# get the MAP estimate\n", "K = 8 # polynomial degree\n", "\n", "# feature matrix\n", "Phi = poly_features(X, K)\n", "\n", "# MAP estimate\n", "theta_map = map_estimate_poly(Phi, y, sigma, alpha)\n", "\n", "# maximum likelihood estimate\n", "theta_ml = nonlinear_features_maximum_likelihood(Phi, y)\n", "\n", "# test inputs and noisy test targets\n", "Xtest = np.linspace(-5,5,100).reshape(-1,1)\n", "ytest = g(Xtest, sigma)\n", "\n", "# predictions at the test inputs\n", "Phi_test = poly_features(Xtest, K)\n", "y_pred_map = Phi_test @ theta_map\n", "y_pred_mle = Phi_test @ theta_ml\n", "\n", "plt.figure()\n", "plt.plot(X, y, '+')\n", "plt.plot(Xtest, y_pred_map)\n", "plt.plot(Xtest, g(Xtest, 0)) # noise-free ground truth\n", "plt.plot(Xtest, y_pred_mle)\n", "\n", "plt.legend([\"data\", \"MAP prediction\", \"ground truth function\", \"maximum likelihood\"]);" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-1.49712990e+00 -1.08154986e+00]\n", " [ 8.56868912e-01 6.09177023e-01]\n", " [-1.28335730e-01 -3.62071208e-01]\n", " [-7.75319509e-02 -3.70531732e-03]\n", " [ 3.56425467e-02 7.43090617e-02]\n", " [-4.11626749e-03 -1.03278646e-02]\n", " [-2.48817783e-03 -4.89363010e-03]\n", " [ 2.70146690e-04 4.24148554e-04]\n", " [ 5.35996050e-05 1.03384719e-04]]\n" ] } ], "source": [ "# columns: maximum likelihood estimate, MAP estimate\n", "print(np.hstack([theta_ml, theta_map]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let us compute the RMSE for different polynomial degrees and see whether the MAP estimate addresses the overfitting issue we encountered with the maximum likelihood estimate."
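, "\n", "\n", "As a reminder of what is being compared, the root-mean-square error between $M$ targets $y_m$ and predictions $\\hat y_m$ is\n", "$$\n", "\\text{RMSE} = \\sqrt{\\frac{1}{M}\\sum_{m=1}^{M} (y_m - \\hat y_m)^2}\\,.\n", "$$\n", "We assume here that the `RMSE` helper used in this notebook implements this standard definition. The RMSE is measured in the same units as $y$ and equals zero only for a perfect fit."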
] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Applications/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:13: LinAlgWarning: Ill-conditioned matrix (rcond=1.82839e-17): result may not be accurate.\n", " del sys.path[0]\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "## EDIT THIS CELL\n", "\n", "K_max = 12 # this is the maximum degree of polynomial we will consider\n", "assert K_max < N # for degree k >= N, Phi^T Phi has rank at most N < k+1 and is singular\n", "\n", "rmse_mle = np.zeros((K_max+1,))\n", "rmse_map = np.zeros((K_max+1,))\n", "\n", "for k in range(K_max+1):\n", " \n", " # feature matrix\n", " Phi = poly_features(X, k)\n", " \n", " # maximum likelihood estimate\n", " theta_ml = nonlinear_features_maximum_likelihood(Phi, y)\n", " \n", " # predict the function values at the test input locations (maximum likelihood)\n", " ####################### SOLUTION\n", " # feature matrix for test inputs\n", " Phi_test = poly_features(Xtest, k)\n", " \n", " # prediction\n", " ypred_test_mle = Phi_test @ theta_ml\n", " #######################\n", " \n", " # RMSE on test set (maximum likelihood)\n", " rmse_mle[k] = RMSE(ytest, ypred_test_mle)\n", " \n", " # MAP estimate\n", " theta_map = map_estimate_poly(Phi, y, sigma, alpha)\n", " \n", " # predict the function values at the test input locations (MAP), reusing Phi_test\n", " ypred_test_map = Phi_test @ theta_map\n", " \n", " # RMSE on test set (MAP)\n", " rmse_map[k] = RMSE(ytest, ypred_test_map)\n", "\n", "plt.figure()\n", "plt.semilogy(rmse_mle) # this plots the RMSE on a logarithmic scale\n", "plt.semilogy(rmse_map) # this plots the RMSE on a logarithmic scale\n", "plt.xlabel(\"degree of polynomial\")\n", "plt.ylabel(\"RMSE\")\n", "plt.legend([\"Maximum likelihood\", \"MAP\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Questions:\n", "1. What do you observe?\n", "2. What is the influence of the prior variance $\\alpha^2$ on the parameters? Change the value of `alpha` and describe what happens."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Bayesian Linear Regression" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "# Test inputs\n", "Ntest = 200\n", "Xtest = np.linspace(-5, 5, Ntest).reshape(-1,1) # test inputs\n", "\n", "prior_var = 2.0 # variance of the parameter prior (alpha^2). We assume this is known.\n", "noise_var = 1.0 # noise variance (sigma^2). We assume this is known.\n", "\n", "pol_deg = 3 # degree of the polynomial we consider at the moment" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Assume a parameter prior $p(\\boldsymbol\\theta) = \\mathcal N (\\boldsymbol 0, \\alpha^2\\boldsymbol I)$. For every test input $\\boldsymbol x_*$, the function value $f(\\boldsymbol x_*) = \\boldsymbol\\phi(\\boldsymbol x_*)^\\top\\boldsymbol\\theta$ has the prior mean\n", "$$\n", "E[f(\\boldsymbol x_*)] = \\boldsymbol\\phi(\\boldsymbol x_*)^\\top E[\\boldsymbol\\theta] = 0\n", "$$\n", "and the prior (marginal) variance (ignoring the noise contribution)\n", "$$\n", "V[f(\\boldsymbol x_*)] = \\boldsymbol\\phi(\\boldsymbol x_*)^\\top(\\alpha^2\\boldsymbol I)\\boldsymbol\\phi(\\boldsymbol x_*) = \\alpha^2\\boldsymbol\\phi(\\boldsymbol x_*)^\\top\\boldsymbol\\phi(\\boldsymbol x_*)\\,,\n", "$$\n", "where $\\boldsymbol\\phi(\\cdot)$ is the (column-vector-valued) feature map. If we stack the row vectors $\\boldsymbol\\phi(\\boldsymbol x_*)^\\top$ of all test inputs into a feature matrix $\\boldsymbol\\Phi_*$, these marginal variances are the diagonal entries of $\\alpha^2\\boldsymbol\\Phi_*\\boldsymbol\\Phi_*^\\top$." ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "## EDIT THIS CELL\n", "\n", "# compute the feature matrix for the test inputs\n", "Phi_test = poly_features(Xtest, pol_deg) # Ntest x (pol_deg+1) feature matrix SOLUTION\n", "\n", "# compute the (marginal) prior at the test input locations\n", "# prior mean\n", "prior_mean = np.zeros((Ntest,1)) # prior mean <-- SOLUTION\n", "\n", "# prior variance\n", "full_covariance = Phi_test @ Phi_test.T * prior_var # Ntest x Ntest covariance matrix of all function values\n", "prior_marginal_var = np.diag(full_covariance)\n", "\n", "# Let us visualize the prior over functions\n", "plt.figure()\n", "plt.plot(Xtest, prior_mean, color=\"k\")\n", "\n", "# shaded bands: 1 and 2 standard deviations of f, and 2 standard deviations of the noisy observations y\n", "conf_bound1 = np.sqrt(prior_marginal_var).flatten()\n", "conf_bound2 = 2.0*np.sqrt(prior_marginal_var).flatten()\n", "conf_bound3 = 2.0*np.sqrt(prior_marginal_var + noise_var).flatten()\n", "plt.fill_between(Xtest.flatten(), prior_mean.flatten() + conf_bound1,\n", " prior_mean.flatten() - conf_bound1, alpha = 0.1, color=\"k\")\n", "plt.fill_between(Xtest.flatten(), prior_mean.flatten() + conf_bound2,\n", " prior_mean.flatten() - conf_bound2, alpha = 0.1, color=\"k\")\n", "plt.fill_between(Xtest.flatten(), prior_mean.flatten() + conf_bound3,\n", " prior_mean.flatten() - conf_bound3, alpha = 0.1, color=\"k\")\n", "\n", "plt.xlabel('$x$')\n", "plt.ylabel('$y$')\n", "plt.title(\"Prior over functions\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we will use this prior distribution to sample functions." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Every sampled function is a polynomial of degree 3\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "## EDIT THIS CELL\n", "\n", "# samples from the prior\n", "num_samples = 10\n", "\n", "# We first need to generate random weights theta_i, which we sample from the parameter prior\n", "random_weights = np.random.normal(size=(pol_deg+1,num_samples), scale=np.sqrt(prior_var))\n", "\n", "# Now, we compute the induced random functions, evaluated at the test input locations\n", "# Every function sample is given as f_i = Phi_test @ theta_i,\n", "# where theta_i is a sample from the parameter prior\n", "\n", "sample_function = Phi_test @ random_weights # <-- SOLUTION\n", "\n", "plt.figure()\n", "plt.plot(Xtest, sample_function, color=\"r\")\n", "plt.title(\"Plausible functions under the prior\")\n", "print(\"Every sampled function is a polynomial of degree \"+str(pol_deg));" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we are given some training inputs $\\boldsymbol x_1, \\dotsc, \\boldsymbol x_N$, which we collect in a matrix $\\boldsymbol X = [\\boldsymbol x_1, \\dotsc, \\boldsymbol x_N]^\\top\\in\\mathbb{R}^{N\\times D}$." ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "N = 10\n", "X = np.random.uniform(high=5, low=-5, size=(N,1)) # training inputs, size Nx1\n", "y = g(X, np.sqrt(noise_var)) # training targets, size Nx1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let us compute the parameter posterior, together with the maximum likelihood and MAP estimates." ] },
{ "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "## EDIT THIS FUNCTION\n", "\n", "def polyfit(X, y, K, prior_var, noise_var):\n", " # X: training inputs, size N x D\n", " # y: training targets, size N x 1\n", " # K: degree of polynomial we consider\n", " # prior_var: prior variance of the parameter distribution (alpha^2)\n", " # noise_var: noise variance (sigma^2)\n", " # returns: ML estimate, MAP estimate, posterior mean mN and posterior covariance SN\n", " \n", " jitter = 1e-08 # increases numerical stability\n", " \n", " Phi = poly_features(X, K) # N x (K+1) feature matrix\n", " \n", " Pt = Phi.T @ y # Phi^T*y, size (K+1,1)\n", " PP = Phi.T @ Phi + jitter*np.eye(K+1) # Phi^T*Phi (plus jitter), size (K+1, K+1)\n", " \n", " # maximum likelihood estimate, using a Cholesky factorization of PP\n", " C = scipy.linalg.cho_factor(PP)\n", " theta_ml = scipy.linalg.cho_solve(C, Pt) # inv(Phi^T*Phi)*Phi^T*y, size (K+1,1)\n", " \n", " # MAP estimate\n", " theta_map = scipy.linalg.solve(PP + noise_var/prior_var*np.eye(K+1), Pt)\n", " \n", " # parameter posterior\n", " iSN = (np.eye(K+1)/prior_var + PP/noise_var) # posterior precision\n", " SN = scipy.linalg.pinv(noise_var*np.eye(K+1)/prior_var + PP)*noise_var # posterior covariance, equal to inv(iSN)\n", " mN = scipy.linalg.solve(iSN, Pt/noise_var) # posterior mean\n", " \n", " return (theta_ml, theta_map, mN, SN)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "theta_ml, theta_map, theta_mean, theta_var = polyfit(X, y, pol_deg, prior_var, noise_var)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's make predictions (ignoring the measurement noise). We obtain three predictors:\n", "\\begin{align}\n", "&\\text{Maximum likelihood: }E[f(\\boldsymbol X_{\\text{test}})] = \\boldsymbol\\phi(\\boldsymbol X_{\\text{test}})\\boldsymbol\\theta_{\\text{ML}}\\\\\n", "&\\text{Maximum a posteriori: } E[f(\\boldsymbol X_{\\text{test}})] = \\boldsymbol\\phi(\\boldsymbol X_{\\text{test}})\\boldsymbol\\theta_{\\text{MAP}}\\\\\n", "&\\text{Bayesian: } p(f(\\boldsymbol X_{\\text{test}})) = \\mathcal N(f(\\boldsymbol X_{\\text{test}}) \\,|\\, \\boldsymbol\\phi(\\boldsymbol X_{\\text{test}}) \\boldsymbol\\theta_{\\text{mean}},\\, \\boldsymbol\\phi(\\boldsymbol X_{\\text{test}}) \\boldsymbol\\theta_{\\text{var}} \\boldsymbol\\phi(\\boldsymbol X_{\\text{test}})^\\top)\n", "\\end{align}\n", "where $\\boldsymbol\\phi(\\boldsymbol X_{\\text{test}})$ is the feature matrix of the test inputs. We have already computed all required quantities. Write some code that implements all three predictors."
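, "\n", "\n", "Here, $\\boldsymbol\\theta_{\\text{mean}}$ and $\\boldsymbol\\theta_{\\text{var}}$ are the mean $\\boldsymbol m_N$ and covariance $\\boldsymbol S_N$ of the Gaussian parameter posterior $p(\\boldsymbol\\theta|\\mathcal X, \\mathcal Y) = \\mathcal N(\\boldsymbol m_N, \\boldsymbol S_N)$, which `polyfit` computes (up to the small jitter term) as\n", "$$\n", "\\boldsymbol S_N = \\Big(\\frac{1}{\\alpha^2}\\boldsymbol I + \\frac{1}{\\sigma^2}\\boldsymbol\\Phi^\\top\\boldsymbol\\Phi\\Big)^{-1}\\,,\\qquad \\boldsymbol m_N = \\frac{1}{\\sigma^2}\\boldsymbol S_N\\boldsymbol\\Phi^\\top\\boldsymbol y\\,.\n", "$$\n", "To predict noisy observations $y_*$ instead of function values, add the noise variance $\\sigma^2$ to the diagonal of the predictive covariance."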
] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "## EDIT THIS CELL\n", "\n", "# predictions (ignoring the measurement noise)\n", "\n", "Phi_test = poly_features(Xtest, pol_deg) # Ntest x (pol_deg+1)\n", "\n", "# maximum likelihood predictions (just the mean)\n", "m_mle_test = Phi_test @ theta_ml\n", "\n", "# MAP predictions (just the mean)\n", "m_map_test = Phi_test @ theta_map\n", "\n", "# predictive distribution (Bayesian linear regression)\n", "# mean prediction\n", "mean_blr = Phi_test @ theta_mean\n", "# covariance of the predicted function values, size Ntest x Ntest\n", "cov_blr = Phi_test @ theta_var @ Phi_test.T" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# plot the posterior\n", "plt.figure()\n", "plt.plot(X, y, \"+\")\n", "plt.plot(Xtest, m_mle_test)\n", "plt.plot(Xtest, m_map_test)\n", "plt.plot(Xtest, mean_blr, color=\"k\") # BLR predictive mean\n", "\n", "# shaded bands: 1 and 2 standard deviations of f, and 2 standard deviations of the noisy observations y\n", "var_blr = np.diag(cov_blr)\n", "conf_bound1 = np.sqrt(var_blr).flatten()\n", "conf_bound2 = 2.0*np.sqrt(var_blr).flatten()\n", "conf_bound3 = 2.0*np.sqrt(var_blr + noise_var).flatten() # add the noise variance, not the standard deviation\n", "\n", "plt.fill_between(Xtest.flatten(), mean_blr.flatten() + conf_bound1,\n", " mean_blr.flatten() - conf_bound1, alpha = 0.1, color=\"k\")\n", "plt.fill_between(Xtest.flatten(), mean_blr.flatten() + conf_bound2,\n", " mean_blr.flatten() - conf_bound2, alpha = 0.1, color=\"k\")\n", "plt.fill_between(Xtest.flatten(), mean_blr.flatten() + conf_bound3,\n", " mean_blr.flatten() - conf_bound3, alpha = 0.1, color=\"k\")\n", "plt.legend([\"Training data\", \"MLE\", \"MAP\", \"BLR\"])\n", "plt.xlabel('$x$');\n", "plt.ylabel('$y$');" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" } }, "nbformat": 4, "nbformat_minor": 2 }