{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 3-a: Bayesian Linear Regression\n", "\n", "In this session, we work with Bayesian Linear Regression models using different basis functions (linear, polynomial and Gaussian). The datasets are 1D toy regression problems, ranging from a linear dataset to more complex non-linear ones such as an increasing sinusoidal curve.\n", "\n", "**Goal**: Get hands-on experience with simple Bayesian models, understand how they work, and gain finer insight into the predictive distribution.\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Useful function: plot results\n", "def plot_results(X_train, y_train, X_test, y_test, y_pred, std_pred,\n", "                 xmin=-2, xmax=2, ymin=-2, ymax=1, stdmin=0.30, stdmax=0.45):\n", "    \"\"\"Given a dataset and predictions on the test set, this function draws 2 subplots:\n", "    - left plot compares train set, ground-truth (test set) and predictions\n", "    - right plot represents the predictive variance over the input range\n", "\n", "    Args:\n", "      X_train: (array) train inputs, sized [N,]\n", "      y_train: (array) train labels, sized [N,]\n", "      X_test: (array) test inputs, sized [M,]\n", "      y_test: (array) test labels, sized [M,]\n", "      y_pred: (array) mean prediction, sized [M,]\n", "      std_pred: (array) std prediction, sized [M,]\n", "      xmin: (float) min value for x-axis on left and right plot\n", "      xmax: (float) max value for x-axis on left and right plot\n", "      ymin: (float) min value for y-axis on left plot\n", "      ymax: (float) max value for y-axis on left plot\n", "      stdmin: (float) min value for y-axis on right plot\n", "      stdmax: (float) max value for y-axis on right plot\n", "\n", "    Returns:\n", "      None\n", "    \"\"\"\n", "    plt.figure(figsize=(15,5))\n", "    plt.subplot(121)\n",
"    plt.xlim(xmin = xmin, xmax = xmax)\n", "    plt.ylim(ymin = ymin, ymax = ymax)\n", "    plt.plot(X_test, y_test, color='green', linewidth=2,\n", "             label=\"Ground Truth\")\n", "    plt.plot(X_train, y_train, 'o', color='blue', label='Training points')\n", "    plt.plot(X_test, y_pred, color='red', label=\"BLR Poly\")\n", "    plt.fill_between(X_test, y_pred-std_pred, y_pred+std_pred, color='indianred', label='1 std. int.')\n", "    plt.fill_between(X_test, y_pred-std_pred*2, y_pred-std_pred, color='lightcoral')\n", "    plt.fill_between(X_test, y_pred+std_pred*1, y_pred+std_pred*2, color='lightcoral', label='2 std. int.')\n", "    plt.fill_between(X_test, y_pred-std_pred*3, y_pred-std_pred*2, color='mistyrose')\n", "    plt.fill_between(X_test, y_pred+std_pred*2, y_pred+std_pred*3, color='mistyrose', label='3 std. int.')\n", "    plt.legend()\n", "\n", "    plt.subplot(122)\n", "    plt.title(\"Predictive variance along x-axis\")\n", "    plt.xlim(xmin = xmin, xmax = xmax)\n", "    plt.ylim(ymin = stdmin, ymax = stdmax)\n", "    plt.plot(X_test, std_pred**2, color='red', label=\"\\u03C3² {}\".format(\"Pred\"))\n", "\n", "    # Get training domain: split the sorted inputs wherever the gap exceeds 1\n", "    training_domain = []\n", "    current_min = sorted(X_train)[0]\n", "    for i, elem in enumerate(sorted(X_train)):\n", "        if elem-sorted(X_train)[i-1]>1:\n", "            training_domain.append([current_min,sorted(X_train)[i-1]])\n", "            current_min = elem\n", "    training_domain.append([current_min, sorted(X_train)[-1]])\n", "\n", "    # Plot domain\n", "    for j, (min_domain, max_domain) in enumerate(training_domain):\n", "        plt.axvspan(min_domain, max_domain, alpha=0.5, color='gray', label=\"Training area\" if j==0 else '')\n", "    plt.axvline(X_train.mean(), linestyle='--', label=\"Training barycentre\")\n", "\n", "    plt.legend()\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part I: Linear Basis function model\n", "\n", "We start with a linear dataset, on which we analyze the behavior of linear basis functions in the framework of Bayesian Linear
Regression." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Generate linear toy dataset\n", "def f_linear(x, noise_amount, sigma):\n", " y = -0.3 + 0.5*x\n", " noise = np.random.normal(0, sigma, len(x))\n", " return y + noise_amount*noise\n", "\n", "# Create training and test points\n", "sigma = 0.2\n", "nbpts=25\n", "dataset_linear = {}\n", "dataset_linear['X_train'] = np.random.uniform(0, 2, nbpts)\n", "dataset_linear['y_train'] = f_linear(dataset_linear['X_train'], noise_amount=1, sigma=sigma)\n", "dataset_linear['X_test'] = np.linspace(-10,10, 10*nbpts)\n", "dataset_linear['y_test'] = f_linear(dataset_linear['X_test'], noise_amount=0, sigma=sigma)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbwAAAEzCAYAAABKVrbSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAA3mklEQVR4nO3dd3wUdf7H8dc3BULoTUEgBKQKUkNvof5sJ6eeFQs2REBOL/TQIRSTWAE59KwXldM7OT0bvYpCQKqAAtKR3hNI+/7+2IABE0jZZHaz7+fjwWN3ZmdnPpMN+87MfOf7NdZaRERECjs/pwsQEREpCAo8ERHxCQo8ERHxCQo8ERHxCQo8ERHxCQo8ERHxCW4JPGPM28aYw8aYTVm8Hm6MOWWMWZf+b7Q7tisiIpJdAW5az7vANOD9qyyzzFp7h5u2JyIikiNuOcKz1i4FjrtjXSIiIvmhIK/htTHGrDfGfG2MaVCA2xUREXHbKc1rWQtUt9aeNcbcBswBal+5kDGmD9AHoHjx4s3r1atXQOWJiIg3WLNmzVFrbcXcvNe4qy9NY0wo8D9rbcNsLLsLCLPWHs1qmbCwMBsfH++W2kREpHAwxqyx1obl5r0FckrTGFPJGGPSn7dM3+6xgti2iIgIuOmUpjHmIyAcqGCM2QeMAQIBrLUzgb8AzxpjUoBE4AGrYRpERKQAuSXwrLUPXuP1abhuWxAREXFEQTVacYvk5GT27dvH+fPnnS5FsikoKIiqVasSGBjodCki4uO8KvD27dtHyZIlCQ0NJf2SoHgway3Hjh1j37591KhRw+lyRMTHeVVfmufPn6d8+fIKOy9hjKF8+fI6IhcRj+BVgQco7LyMPi8R8RReF3hOO3ToEA899BA1a9akefPmtGnThs8++6xAa9i1axcNG15+u+PGjRtp0qQJTZo0oVy5ctSoUYMmTZrQrVu3bK/zww8/vDT97rvvMmDAALfWLSLiJAVeDlhr+fOf/0zHjh3ZuXMna9as4eOPP2bfvn1/WDYlJaVAa7v55ptZt24d69at48477yQ6Opp169Yxf/78bNV0ZeCJiBQ2CrwcWLhwIUWKFKFv376X5lWvXp3nnnsOcB0V3XnnnXTp0oWuXbty/Phx/vznP9OoUSNat27Nhg0bABg7diwxMTGX1tGwYUN27drFrl27qF+/P
k8//TQNGjSgR48eJCYmArBmzRoaN25M48aNmT59erZrDg8P5/nnnycsLIxXX32V3r178+mnn156vUSJEgAMGzaMZcuW0aRJE15++WUADhw4wC233ELt2rUZMmRILn9qIiKeQYGXA5s3b6ZZs2ZXXWbt2rV8+umnLFmyhDFjxtC0aVM2bNjApEmTePTRR6+5jV9++YX+/fuzefNmypQpw7///W8AHn/8cV5//XXWr1+f47qTkpKIj48nIiIiy2WmTJlChw4dWLduHS+88AIA69atY/bs2WzcuJHZs2ezd+/eHG9bRMRTeNVtCRmZcfnTGMKOyX4HMP3792f58uUUKVKE1atXA9C9e3fKlSsHwPLlyy8FVpcuXTh27BinT5++6jovXnsDaN68Obt27eLkyZOcPHmSjh07AvDII4/w9ddfZ7vO+++/P9vLZtS1a1dKly4NwE033cTu3bupVq1artYlIuI0HeHlQIMGDVi7du2l6enTp7NgwQKOHDlyaV7x4sWvuZ6AgADS0tIuTWdstl+0aNFLz/39/d1yLTBjTRm3nZaWRlJSUpbvy49aRESc4rVHeDk5EnOXLl26MGLECN544w2effZZABISErJcvkOHDsTFxTFq1CgWL15MhQoVKFWqFKGhofzvf/8DXKdAf/3116tut0yZMpQpU4bly5fTvn174uLicr0PoaGhrFmzhvvuu4/PP/+c5ORkAEqWLMmZM2dyvV4REU+nI7wcMMYwZ84clixZQo0aNWjZsiWPPfYYU6dOzXT5sWPHsmbNGho1asSwYcN47733ALjnnns4fvw4DRo0YNq0adSpU+ea237nnXfo378/TZo0IS/9bj/99NMsWbKExo0bs3LlyktHf40aNcLf35/GjRtfarQiIlKYuG08PHfLbDy8LVu2UL9+fYcqktzS5yYi7uLx4+GJiIg4TYEnIiI+QYEnIiI+QYEnIiI+QYEnIiI+QYEnIiI+QYGXA8eOHbs0BE+lSpWoUqXKpemr9VgCEB8fz8CBA6+5jbZt27qr3BzLzrZfeeWVq95sLyLiqXQfXi6NHTuWEiVKMGjQoEvzUlJSCAjw2s5rsiU0NJT4+HgqVKiQ7fd40ucmIt5N9+FlIS4OQkPBz8/1mIceubLUu3dv+vbtS6tWrRgyZAirVq2iTZs2NG3alLZt27Jt2zYAFi9ezB133AG4wvKJJ54gPDycmjVr8tprr11a38XhehYvXkx4eDh/+ctfqFevHr169brUw8pXX31FvXr1aN68OQMHDry03ozeffddevbsSXh4OLVr12bcuHGXXnvppZdo2LAhDRs25JVXXsn2tl977TUOHDhA586d6dy5M6mpqfTu3ZuGDRty8803q4cWEfFohfZwJC4O+vSBi2ffdu92TQP06uXebe3bt4/vvvsOf39/Tp8+zbJlywgICGD+/PmMGDHi0ogJGW3dupVFixZx5swZ6taty7PPPktgYOBly/z4449s3ryZG264gXbt2rFixQrCwsJ45plnWLp0KTVq1ODBBx/Msq5Vq1axadMmgoODadGiBbfffjvGGN555x1++OEHrLW0atWKTp060bRp02tue+DAgbz00kssWrSIChUqsGbNGvbv38+mTZsAOHnyZN5/mCIi+aTQHuFFRv4edhclJLjmu9u9996Lv78/AKdOneLee++lYcOGvPDCC2zevDnT99x+++0ULVqUChUqcN1113Ho0KE/LNOyZUuqVq2Kn58fTZo0YdeuXWzdupWaNWtSo0YNgKsGXvfu3SlfvjzFihXj7rvvZvny5Sxfvpy77rqL4sWLU6JECe6++26WLVuWrW1fqWbNmuzcuZPnnnuOb775hlKlSmXnxyUi4ohCG3h79uRsfl5kHH5n1KhRdO7cmU2bNvHFF19cNvRPRtkZeievw/MYY646fTXZ2XbZsmVZv3494eHhzJw5k6eeeipH9YmIFKRCG3ghITmb7y6nTp2iSpUqgOs6mrvVrVuXn
Tt3Xjrimj17dpbLzps3j+PHj5OYmMicOXNo164dHTp0YM6cOSQkJHDu3Dk+++wzOnTokO3tZxxG6OjRo6SlpXHPPfcwceLEy8YKFBHxNIU28KKiIDj48nnBwa75+WnIkCEMHz6cpk2b5suAqcWKFWPGjBnccsstNG/enJIlS14alfxKLVu25J577qFRo0bcc889hIWF0axZM3r37k3Lli1p1aoVTz311B+u311Nnz59uOWWW+jcuTP79+8nPDycJk2a8PDDDzN58mR37aaIyB98t/e7PL2/UN+WEBfnuma3Z4/ryC4qyv0NVpxw9uxZSpQogbWW/v37U7t2bV544YXLlnn33XeJj49n2rRpDlX5O92WICJ5seP4DoYtGManP30KY8n1bQmFtpUmuMKtMATcld58803ee+89kpKSaNq0Kc8884zTJYmIuN3xxONMXDqRaaumkZyWTLGAYiSSmOv1FeojPPEM+txEJCcupFxgxuoZTFg6gRPnT2AwPNbkMSZ0nkC10tV0hCciIt7NWsu/t/ybofOHsvPETgC61OhCTPcYmlbOfluDrHhd4Flrc9S8XpzlqWcQRMSzfL/veyLmRlxqmFK/Qn2iu0dzW+3b3Pad71WBFxQUxLFjxyhfvrxCzwtYazl27BhBQUFOlyIiHurXE78ybMEw/rX5XwBUDK7I+M7jearZUwT4uTeivCrwqlatyr59+zhy5IjTpUg2BQUFUbVqVafLEBEPcyLxBFHLonh91eskpSYRFBDE31r/jaHth1KqaP702uRVgRcYGHipSy0REfE+SalJvLH6DcYvHc/xxOMAPNLoEaK6RFGtdLV83bZXBZ6IiHgnay1zts5hyPwhbD++HYDw0HBie8TSrHKzAqlBgSciIvlq1f5VDJo7iGV7XB3V1y1fl+ju0dxR544CbY+hwBMRkXyx6+QuRiwYwUebPgKgQnAFxnYaS5/mfQj0D7zGu91PgSciIm518vxJJi+bzKs/vMqF1AsU9S/K862fZ3j74ZQOyrzv34KgwBMREbdITk3m72v+ztjFYzmWeAyAXjf3IqpLFNXLVHe4OgWeiIjkkbWWz7d9zpD5Q/j52M8AdAjpQGyPWFpUaeFwdb9T4ImISK7FH4hn0NxBLNm9BIDa5WrzYvcX6Vm3p8d1EKLAExGRHNtzag8jFowgbmMcAOWLlWdMpzH0DevrSIOU7FDgiYhItp2+cJopy6fw8vcvcz7lPEX8i/DXVn9lRIcRlAkq43R5V6XAExGRa0pJS+HNNW8yZvEYjiS4und8oOEDTO46mdAyoc4Wl00KPBERyZK1li9/+ZLB8waz9ehWANpVa0dsj1haVW3lcHU54+eOlRhj3jbGHDbGbMridWOMec0Ys90Ys8EYUzD9yIiIXCEuDkJDwc/P9RgX53RFnuvHgz/S9f2u/OmjP7H16FZuLHsjn977KcseX+Z1YQfuO8J7F5gGvJ/F67cCtdP/tQLeSH8UESkwcXHQpw8kJLimd+92TQP06uVcXZ5m76m9jFw0kg/Wf4DFUjaoLKM7jaZfi34U8S/idHm55pbAs9YuNcaEXmWRnsD71jUa6PfGmDLGmMrW2oPu2L6ISHZERv4edhclJLjmK/DgzIUzTF0xldiVsZxPOU+gXyDPtXyOkR1HUrZYWafLy7OCuoZXBdibYXpf+rzLAs8Y0wfoAxASElJApYmIr9izJ2fzfUVKWgr/WPsPRi8ezeFzhwG4r8F9TO46mZplazpcnft4VKMVa+0sYBZAWFiYdbgcESlkQkJcpzEzm++LrLV8vf1rBs8bzE9HfgKgTdU2xPaIpU21Ng5X535uabSSDfuBjCP7VU2fJyJSYKKiIDj48nnBwa75vmb9b+vp8c8e3P7h7fx05CdqlKnBv/7yL1Y8saJQhh0U3BHe58AAY8zHuBqrnNL1OxEpaBev00VGuk5jhoS4ws6Xrt/tP72fUYtG8e66d7FYygSVYVTHUfRv0Z+iAUWdL
i9fuSXwjDEfAeFABWPMPmAMEAhgrZ0JfAXcBmwHEoDH3bFdEZGc6tXLtwLuorNJZ4leEU3MyhgSkhMI9Aukf4v+jOw4kvLB5Z0ur0C4q5Xmg9d43QL93bEtERHJvtS0VN5Z9w6jFo3it7O/AXBP/XuY0m0KtcrVcri6guVRjVZERMR9vt3+LYPmDWLTYVefIC2rtCS2RyztQ9o7XJkzFHgiIoXMxkMbGTRvEHN3zAWgeunqTOk2hfsb3O9xQ/YUJAWeiEghcfDMQUYtGsU7694hzaZRumhpIjtE8lyr5wgKCHK6PMcp8EREvNy5pHPEfBdD9HfRnEs+R4BfAP1b9Gd0p9FUCK7gdHkeQ4EnIuKlUtNSeW/9e4xcOJKDZ113et1V7y6mdJtCnfJ1HK7O8yjwRES80Lwd8xg0bxAbDm0AIOyGMGJ7xNKxekeHK/NcCjwRES+y+fBmBs8bzNfbvwYgpHQIk7tO5oGGD+BnCqrzLO+kwBMR8QK/nf2NMYvG8NaPb5Fm0yhVtBQj2o9gYKuBFAss5nR5XkGBJyLiwRKSE3hp5UtMXTGVs0ln8Tf+9G/RnzGdxlCxeEWny/MqCjwREQ+UZtP4YP0HRC6MZP8ZV1/7d9a9k6ndplKvQj2Hq/NOOuErIj4pLg5CQ8HPz/UYF+d0Rb9b+OtCms9qTu//9mb/mf00q9yMhY8u5L8P/Fdhlwc6whMRnxMXB336/D76+e7drmlwtmPpLUe2MGT+EP738/8AqFqqKpO6TKJXo15qkOIGxtWvs+cJCwuz8fHxTpchIoVQaGjmA8FWrw67dhV0NXD43GHGLh7LrDWzSLWplChSguHth/NC6xfUIOUKxpg11tqw3LxXR3gi4nP27MnZ/PySmJzIK9+/wuTlkzmTdAY/40ff5n0ZGz6W60tcX7DF+AAFnoj4nJCQzI/wQkIKZvtpNo24DXFELoxk7+m9ANxe+3Ze7P4iN1W8qWCK8EEKPBHxOVFRl1/DAwgOds3Pb4t3LSZibgRrD64FoPH1jYntEUvXml3zf+M+ToEnIj7nYsOUyEjXacyQEFfY5WeDla1HtzJ0/lA+3/Y5AFVKViGqSxQPN3oYfz///NuwXKLAExGf1KtXwbTIPHLuCOOWjGNm/ExSbSrFA4szrP0w/tbmbwQHBud/AXKJAk9EJB+cTznPq9+/yqTlkzh94TR+xo8+zfowrvM4KpWo5HR5PkmBJyLiRmk2jY83fczwBcPZc8rV7PPWWrfyYvcXaXhdQ4er820KPBERN1m2exkRcyNYfWA1AI2ub0RM9xi639jd4coEFHgiInn287GfGTZ/GJ9t/QyAyiUqM7HLRB5r/JgapHgQBZ6ISC4dTTjKhCUTmBE/g5S0FIIDgxnSdggRbSMoUaSE0+XJFRR4IiI5dD7lPNNWTWPi0omcunAKg+HJpk8yvvN4bih5g9PlSRYUeCIi2WStZfbm2QxfMJxdJ3cB0L1md2J6xNDo+kbOFifXpMATEcmGFXtWEDE3gh/2/wBAg4oNiOkRwy21bnG4MskuBZ6IyFXsOL6DofOH8u8t/wagUolKTOg8gd5NehPgp69Qb6JPS0QkE8cTjzNhyQSmr55OcloyxQKKMbjtYAa3G6wGKV5KgSciksGFlAtMXz2dCUsncPL8SQyGx5s8zoTOE6hSqorT5UkeKPBERHA1SPn0p08ZtmAYO0/sBKBrja7E9IihSaUmzhYnbqHAExGft3LvSiLmRrBy30oA6leoT0yPGG6tdSvGGIerE3dR4ImIz9p5YifDFwznX5v/BcB1xa9jfPh4nmz2pBqkFEL6REXE55xIPEHUsiheX/U6SalJBAUEEdEmgiHthlCqaCmny5N8osATEZ+RlJrEG6vfYPzS8RxPPA7Ao40fZWLniVQrXc3h6iS/KfBEpNCz1vLZ1s8YOn8o249vByA8NJzYHrE0q9zM4eqkoCjwRKRQW7V/FRFzI1i+ZzkAdcvXJbp7NHfUuUMNU
nyMAk9ECqVdJ3cxfMFwPt70MQAVgiswLnwcTzd7mkD/QIerEyco8ESkUDl5/iSTlk3i1R9eJSk1iaL+RXmh9QsMaz+M0kGlnS5PHKTAE5FCITk1mZnxMxm3ZBzHEo/BhgcJXvoqiccq8FGIoWEU9OrldJXiJAWeiHg1ay3/3fZfhswbwi/HfwGg7oFx7P46koRE12jju3dDnz6u5RV6vsvP6QJERHIr/kA84e+Fc9fsu/jl+C/UKV+HOffPIfGbUZxPD7uLEhIgMtKZOsUz6AhPRLzO7pO7iVwYSdzGOADKFyvP2PCxPNP8GQL9A7lrb+bv27OnAIsUj6PAExGvcer8KaYsn8LL37/MhdQLFPEvwvOtnmd4h+GUCSpzabmQENdpzCuFhBRcreJ5dEpTRDxecmoyM1bPoPbrtZmyYgoXUi/wYMMH2TZgG1O7T70s7ACioiA4+PJ1BAe75l8pLg5CQ8HPz/UYF5dfeyFOU+CJyCWe9uVvreWLbV/QaGYj+n/VnyMJR2hXrR3fP/k9H97zIaFlQjN9X69eMGsWVK8OxrgeZ836Y4OVuDhXY5bdu8Ha3xu3OL3fkj+MtTbvKzHmFuBVwB94y1o75YrXewPRwP70WdOstW9dbZ1hYWE2Pj4+z7WJSPZc/PJPSPh9XnBw5kFRENYeXMuguYNYtGsRALXK1WJqt6ncVe8ut/WQEhqa+anP6tVh1y63bELczBizxloblpv35vkIzxjjD0wHbgVuAh40xtyUyaKzrbVN0v9dNexEJH9ldiQXGXl52IEzLRv3ntrLo589SvNZzVm0axHlipXjlf97hc39NnN3/bvd2h1YVo1Y1LilcHJHo5WWwHZr7U4AY8zHQE/gJzesW0Tc7MojuYun8a4Mu4sK6sv/zIUzTFk+hZe+f4nzKecp4l+E51o+R2SHSMoWK5sv21TjFt/ijmt4VYCMjYD3pc+70j3GmA3GmE+NMRqHQ8QhWR3J+ftnvnx+f/mnpKUwM34mtV6vxaTlkzifcp77GtzHlv5biOkRk29hBzlr3CLer6BuS/gC+Mhae8EY8wzwHtDlyoWMMX2APgAh+hNLJF9kdcSWmur6sr/yGl5+fflba/nql68YPG8wW45uAaBN1TbE9oilTbU2+bPRK1y8NhkZ6fq5hIS49le9sRRO7jjC2w9kPGKryu+NUwCw1h6z1l5In3wLaJ7Ziqy1s6y1YdbasIoVK7qhNBG5UlZ/S15syXitlo3usO63dXT/oDt3fHQHW45uoWbZmnxy7yeseGJFgYXdRb16uRqopKW5HhV2hZc7jvBWA7WNMTVwBd0DwEMZFzDGVLbWHkyfvBPY4obtikguREVl3hrz4pFNfn7h7z+9n5GLRvLeuvewWMoGlWVUx1H0a9GPogFF82/DIrgh8Ky1KcaYAcC3uG5LeNtau9kYMx6It9Z+Dgw0xtwJpADHgd553a6I5I4Tp/HOJp3lxRUvEvNdDIkpiQT6BdK/RX9GdRpFuWLl8m/DIhm45T68/KD78ES8X2paKm//+DajFo3i0LlDANxT/x6mdJtCrXK1HK5OvFFe7sNTX5oiki++2f4Ng+cNZtPhTQC0qtKK2B6xtAtp53Bl4qsUeCLiVhsObWDwvMHM3TEXgNAyoUzpOoX7Gtzn1pvGRXJKgScibnHgzAFGLRzFO+vewWIpXbQ0IzuOZEDLAQQFBDldnogCT0Ty5lzSOaK/iyb6u2gSkhMI8AugX1g/RnUaRYXgCk6XJ3KJRksQKUCeNhpBXlxskFL79dqMWzKOhOQE7qp3Fz/1+4lXb31VYSceR0d4IgUkqz4swftudp63Yx6D5g1iw6ENALS4oQWxPWLpUL2Dw5WJZE23JYgUkMIwFM2mw5sYPG8w32z/BoCQ0iFM7jqZBxo+gJ/RCSPJf7otQcQLePNQNL+d/Y3Ri0bzjx//QZpNo1TRUoxoP4K/tv6rGqSI19CfZCIFJKs+LLPbT7oT1/8SkhOYsGQCt
V6rxZtr38Rg6N+iP9uf287Q9kMVduJVFHgiBSSzoWiKFIGzZ68dYhev/+3eDdb+fv0vv0IvNS2Vd9e9S+3XazN68WjOJZ/jzrp3srnfZqbdNo2KxdW5u3gfXcMTKUAXRxbfswfKlYPTpyE5+ffXg4MzH6GgIK//Ldi5gEHzBrHut3UANKvcjNgesYSHhrt3QyK5kJdreAo8EYfkJMT8/FxHdlcyxjWsjTv8dOQnhswbwpe/fAlAtVLVmNR1Eg/d/JAapIjHyEvg6bdYxCE5acSS1+t/V3Po7CGe/d+zNHqjEV/+8iUli5RkUpdJbBuwjYcbPezxYVeY7m2U/KVWmiIOCQnJ/AgvsxC72hh2uZWYnMjL37/MlOVTOJN0Bn/jz7NhzzI2fCzXFb8u9ysuQIXp3kbJf579p5tIIZZZI5asQqxXL/eNRp5m0/hg/QfUmVaHyIWRnEk6wx117mDjsxuZcfsMrwk7cF0PzfhHALimIyOdqUc8m47wRByS04FY3TEa+eJdi4mYG8Hag2thw4MELo4h5URlNoYY1qZBfS87KvLmexul4OkIT8RBvXq5Gqikpbkeswq0vF6n2np0K3d+dCed3+vM2oNrKfNzf4p89T7Jx2/AWpPvtznkl/y8timFjwJPxMPl5R68I+eO0P/L/jSc0ZAvfv6C4oHFmdB5AiWXv0bS+ctP8HjjqcCcnBYW0W0JIh4uN/fgJSYn8uoPrzJp2STOJJ3Bz/jxVNOnGNd5HJVKVCqQ2xwKSsZ7G691Wli8n/rSFCnEcnKdKs2m8dHGjxixcAR7TrkWuLXWrUR3j6bBdQ0uLZeTFqKezh3XNsU36JSmiIfL7nWqpbuX0vqt1jz82cPsObWHRtc3Yu7Dc+nFV9zessFl1/90KlB8kQJPxMNdK5x+PvYzd82+i07vdmL1gdVULlGZt+98m7V91nL4++6ZXv8D993mIOItdEpTxMNldfvC/911lIFfj+eN+DdISUshODCYoe2GEtEmgjmfFOfGuzI/bXmxccrVWoWKFEYKPBEvkPE61fmU87z+w+vUei2KUxdOYTA82fRJJnSeQOWSlf/Q+0hmdJ+a+CIFnoiXsNYye/Nshi8Yzq6TuwDocWMPortH0+j6RpeWy6z3kSt5Y+MUkbxS4Il4geV7lhMxN4JV+1cB0PC6hsR0j+H/av3fH5a91tGbGqeIr1LgiXiw7ce3M3T+UP6z5T8AVCpRiQmdJ/B4k8fx9/PP9D1Z3XJwkRqniK9S4Il4oOOJxxm/ZDwzVs8gOS2ZYgHFGNx2MIPbDaZEkRJXfW9UFDzySOY3llevrrAT36XAE/EgF1IuMH31dCYsncDJ8ycxGB5v8jgTOk+gSqkq2VpHr16wYgXMnHl56OlUpvg6BZ6IB7DW8ulPnzJswTB2ntgJQLea3YjpHkPjSo1zvL4ZM6BdO3W5JZKRAk/EYSv3riRibgQr960E4KaKNxHdPZpba92KMSbX61WXWyKXU+CJOGTniZ0Mmz+MT376BIDril/H+PDxPNnsSQL89F9TxN30v0qkgJ1IPMHEpRN5fdXrJKclExQQRESbCIa2G0rJoiWdLk+k0FLgiRSQpNQkZqyewfgl4zlx/gQGw6ONHyWqSxRVS1V1ujyRQk+BJ5LPrLX8Z8t/GDp/KDtO7ACgc2hnYnrE0KxyM4erE/EdCjyRfPTDvh+ImBvBir0rAKhXoR7R3aO5vfbteWqQIiI5p8ATyQe/nviVEQtH8PGmjwGoGFyRceHjeLr502qQIuIQ/c8TcaOT508yadkkXv3hVZJSkwgKCOKF1i8wrP0wShUt5XR5Ij5NgSfiBsmpycyMn8m4JeM4lngMgIcbPUxUlyhCSmtoAhFPoBHPRfLAWsucrXNoMKMBA78ZyLHEY3Ss3pHVT6/mg7s+KPCwi4uD0FDw83M9xsUV6OZFPJqO8ERyafX+1QyaN4ilu5cCUKd8HV7s9iJ31r3TkQYpVw78unu3axrU44oIg
LGZdanuAcLCwmx8fLzTZYj8we6TuxmxcAQfbvwQgPLFyjM2fCzPNH+GQP9Ax+oKDc18WKDq1WHXroKuRiR/GGPWWGvDcvNeHeGJZNOp86eYvHwyr3z/ChdSL1DUvyh/bfVXRnQYQemg0k6Xl+XAr9caEFbEV+gansg1JKcmM33VdGq9XoupK6ZyIfUCDzZ8kK0DtjK1+1SPCDtwjYiQk/lO0/VGKWgKPJEsWGv5fNvn3PzGzQz4egBHE47SPqQ9Pzz1Ax/e8yGhZULdur28BkBUlGvMu4w8dQy8i9cbd+92jdl38XqjQk/yk1sCzxhzizFmmzFmuzFmWCavFzXGzE5//QdjTKg7tiuSX9YcWEOX97vQ8+OebDu2jVrlavGf+/7D0t5LaVmlpdu3544A6NULZs1yXbMzxvU4a5ZnNliJjPy9cc1FCQmu+SL5Jc+NVowx/sDPQHdgH7AaeNBa+1OGZfoBjay1fY0xDwB3WWvvv9p61Wil8IqL89yBSfee2kvkwkg+2PABAOWKlWNMpzH0DetLEf8i+bZdX2tw4ud3+WjsFxkDaWkFX494D6cbrbQEtltrd6YX8zHQE/gpwzI9gbHpzz8FphljjPXUJqKSbzy16fzpC6eZunwqL33/EudTzlPEvwgDWw5kRIcRlC1WNt+372sNTkJCMg94T73eKIWDO05pVgH2Zpjelz4v02WstSnAKaC8G7YtXsbTTmWlpKUwM34mtV+vzaTlkzifcp77G9zPlv5biO4RXSBhB97X4CSvvOl6oxQeHtVoxRjTxxgTb4yJP3LkiNPlSD7wlCMZay1f/vwljd5oxLNfPsvhc4dpW60tK59cycd/+ZiaZWsWaD2+FgDedL1RCg93BN5+oFqG6arp8zJdxhgTAJQGjl25ImvtLGttmLU2rGLFim4oTfJbTlsWesKRzLrf1tHtg27c8dEdbDm6hZpla/LJvZ+w/PHltK7auuAKycAXA6BXL9f1ybQ012Nh3lfxDO64hrcaqG2MqYEr2B4AHrpimc+Bx4CVwF+Ahbp+5/369YOZM39vfJCd63FRUZdfw4OCO5LZf3o/kQsjeX/9+1gsZYPKMqrjKPq16EfRgKL5X8A19OqlL32R/JTnwLPWphhjBgDfAv7A29bazcaY8UC8tfZz4B/AB8aY7cBxXKEoXiwu7vKwu+ji9bisvrgvzi/IVppnLpwh+rtoYr6LITElkUC/QAa0HMDIjiMpV6xc/m1YRDyK+tKUXMmqGT14TtPylLQU3vnxHUYtGsWhc4cA+MtNf2FK1yncWO5Gh6sTkdxw+rYE8UFXa2TidMtCay3fbP+GwfMGs/nIZgBaVWlFbI9Y2oW0c7Y4EXGMAk9yJav7qIxxtmXh+t/WM3jeYObtnAdAaJlQpnSdwn0N7nNkyB4R8RwedVuCeI/MmtEbA337OtPw4sCZAzz53ydp+vemzNs5j9JFSxPTPYat/bdyf8P7FXYioiM8yR0nGp9k5mzSWWK+iyH6u2gSkhMI8Augf4v+jOo4ivLB6ttARH6nwJNcc7IZfWpaKu+ue5dRi0Zx8OxBAO6ufzdTuk6hdvnazhQlIh5NgSdeZ+6OuQyaO4iNhzcC0OKGFsT2iKVD9Q4OVyYinkyBJ15j0+FNDJ43mG+2fwNA9dLVmdx1Mvc3vB8/o8vRInJ1CjzxeL+d/Y3Ri0bzjx//QZpNo1TRUkR2iGRgq4EEBQQ5XZ6IeAkFnnisc0nneGnlS0xdMZVzyefwN/4MaDGA0Z1GU7G4+loVkZxR4InHSU1L5YMNHxC5MJIDZw4A0LNuT6Z2m0rdCnUdrk5EvJUCTzzK/J3zGTR3EOsPrQegeeXmxPSIITw03NnCRMTrKfDEI2w+vJkh84fw1S9fAVCtVDUmdZ3EQzc/pAYpIuIWCjxx1KGzhxizeAxvrn2TNJtGySIlGd5+OM+3fp5igcWcLk9EChH96SyOSEhOIGppFLVer8Xf1/wdg6FfWD+2D9zO8
A7DPT7scjrwrYg4T0d4UqDSbBr/3PBPIhdGsu/0PgD+VOdPTO02lfoV6ztcXfbExV0+iG12Br4VEedpPDwpMIt+XUTE3Ah+/O1HAJpWakpMjxi61OjicGU5k9VYgNWrw65dBV2NiG/ReHji0bYe3cqQeUP44ucvAKhSsgqTuk7i4UYPe2WDlKzGArzaGIEi4jwFnuSbw+cOM27xOP6+5u+k2lRKFCnBsHbDeKHNCwQHBl97BR4qq7EAnR74VkSuToEnbpeYnMirP7zKpGWTOJN0Bj/jxzPNn2Fs+FgqlajkdHl5FhV1+TU8cI0N6OTAtyJybd53PkkcdbXWiWk2jbgNcdSbXo/hC4ZzJukMt9W+jQ19NzDzjpmFIuzA1TBl1izXNTtjXI+zZqnBioin0xGeZNvVWidWa7+UiLkRxB9wNTRqdH0jYnvE0q1mN4eqzV9OjgUoIrmjwJNsi4y8/DQeuKaffv4wiQM6AXBDyRuY2HkijzZ+FH8/fweqFBHJnAJPsi2rVoiJRytQPLA4Q9oNIaJNBMWLFC/YwkREskGBJ9mWVevEEhWP8/Nzv1C5ZOWCL0pEJJvUaEWyxVrL7X2/wwRefk4zqFgaM1+uoLATEY+nwJNrWr5nOa3/0ZoZF9ph//QUgeUOYIylenV4600/Nd4QEa+gU5qSpV+O/cKwBcP4z5b/AFCpRCUmRnal9yfX4+9nHK5ORCRnFHjyB8cSjjFh6QSmr55OSloKwYHBDG47mEFtB1GiSAmnyxMRyRUFnlxyIeUC01ZNY+KyiZw8fxKD4YkmTzC+83iqlKridHkiInmiwBOstXzy0ycMmz+MX0/+CkC3mt2I6R5D40qNHa5ORMQ91GjFx3239zvavt2W+z+9n19P/spNFW/iq4e+Yu7Dcz0q7DTgqojklY7wfNSO4zsYvmA4n/z0CQDXFb+OCZ0n8ETTJwjw86xfCw24KiLuoAFgfczxxONMXDqRaaumkZyWTLGAYkS0iWBIuyGULFrS6fIypQFXReQiDQAr15SUmsT0VdOZsHQCJ86fwGB4rPFjTOwykaqlqjpd3lVpwFURcQcFXiFnreXfW/7NsPnD2HFiBwBdanQhpnsMTSs3dbi67NGAqyLiDgq8Quz7fd8TMTeC7/Z+B0D9CvWJ7h7NbbVvwxjvuXFcA66KiDso8AqhX0/8yvAFw5m9eTYAFYMrMr7zeJ5q9pTHNUjJjosNUyIjXacxQ0JcYacGKyKSE9737SdZOnn+JFFLo3ht1WskpSYRFBDE31r/jaHth1KqaCmny8sTDbgqInmlwCsEklKTmBk/k3FLxnE88TgAjzR6hIldJhJSWhe6RERAgefVrLXM2TqHofOH8svxXwDoVL0TsT1iaX5Dc4erExHxLAo8L7V6/2oi5kawbM8yAOqUr0N092j+VOdPXtUgRUSkoCjwvMzuk7sZsXAEH278EIAKwRUY22ksfZr3IdA/0OHqREQ8l/rS9BKnzp9i6Lyh1J1Wlw83fkhR/6IMbTeU7c9tp3/L/gUedurbUkS8jY7wPFxyajJ/X/N3xi0Zx9GEowA8dPNDRHWJIrRMqCM1qW9LEfFG6kvTQ1lr+eLnLxgybwjbjm0DoENIB2J7xNKiSgtHa1PfliLiFPWlWcisObCGQfMGsXjXYgBql6vNi91fpGfdnh7RIEV9W4qIN8pT4BljygGzgVBgF3CftfZEJsulAhvTJ/dYa+/My3YLq72n9jJi4Qj+ueGfAJQrVo4xncbQN6wvRfyLOFzd79S3pYh4o7w2WhkGLLDW1gYWpE9nJtFa2yT9n8LuCqcvnGbEghHUmVaHf274J0X8izCozSB2DNzBwFYDPSrswNWtV3Dw5fPUt6WIeLq8ntLsCYSnP38PWAwMzeM6fUZKWgpvrnmTMYvHcCThCAD3N7ifyV0nU6NsDYery5r6thQRb5SnRivGmJPW2jLpzw1w4uL0FculAOuAFGCKtXZOFuvrA/QBCAkJab47s
/NmhYC1li9/+ZLB8waz9ehWANpWa0tsj1haV23tcHUiIp4rXxutGGPmA5UyeSky44S11hpjskrP6tba/caYmsBCY8xGa+2OKxey1s4CZoGrleY1q/dCPx78kUHzBrHw14UA3Fj2RqZ2m8rd9e/2iAYpIiKF1TUDz1rbLavXjDGHjDGVrbUHjTGVgcNZrGN/+uNOY8xioCnwh8ArzPad3sfIhSN5f/37WCxlg8oyutNo+rXo53HX6ERECqO8XsP7HHgMmJL++N8rFzDGlAUSrLUXjDEVgHbAi3ncrtc4c+EML654kdiVsSSmJBLoF8hzLZ8jsmMk5YqVc7o8ERGfkdfAmwL8yxjzJLAbuA/AGBMG9LXWPgXUB/5ujEnD1Sp0irX2pzxu1+OlpKXw9o9vM3rRaA6dOwTAvTfdy+Suk7mx3I0OVyci4nvyFHjW2mNA10zmxwNPpT//Drg5L9vxJtZavtn+DYPnDWbzkc0AtK7amtgesbSt1tbh6kREfJd6WnGj9b+tZ9C8QczfOR+AGmVqMKXbFO696V41SBERcZgCzw0OnDnAyIUjeXfdu1gsZYLKMLLDSAa0HEDRgKJOlyciIijw8uRs0lmiV0QTszKGhOQEAvwC6N+iP6M6jqJ8cHmnyxMRkQwUeLmQmpbKO+veYdSiUfx29jcA7q5/N1O6TqF2+doOVyciIplR4OXQt9u/ZdC8QWw6vAmAllVaEtsjlvYh7R2uTERErkaBl00bD21k8LzBfLvjWwCql67OlG5TuK/BffgZDRwvIuLpFHjXcPDMQUYvGs3b694mzaZRumhpIjtE8lyr5wgKCHK6PBERySYFXhbOJZ0jdmUsL654kXPJ5wjwC6BfWD/GhI+hQnAFp8sTEZEcUuBdITUtlffXv8/IRSM5cOYAAD3r9mRqt6nUrVDX4epERCS3FHgZzN85n0FzB7H+0HoAmlduTmyPWDqFdnK4MhERySsFHrD58GYGzxvM19u/BqBaqWpM7jqZB29+UA1SREQKCZ8OvENnDzF60Wje+vEt0mwaJYuUZESHEfy11V8pFljM6fJERMSNfDLwEpITeGnlS0xdMZWzSWfxN/6XGqRcV/w6p8sTEZF84FOBl2bT+GD9B0QujGT/mf0A/KnOn3ix+4vUq1DP4epERCQ/+UzgLfp1ERFzI/jxtx8BaFa5GTHdY+hco7PDlYmISEEo9IG35cgWhswfwv9+/h8AVUpWYVLXSTzc6GE1SBER8SGFNvAOnzvM2MVjmbVmFqk2lRJFSjCs3TBeaPMCwYHBTpcnIiIFrNAFXmJyIq98/wqTl0/mTNIZ/IwfzzR/hnHh47i+xPVOlyciIg4pNIGXZtP4cOOHjFgwgr2n9wJwW+3biO4ezU0Vb3K4OhERcVqhCLwlu5YQMTeCNQfXAND4+sbE9IihW81uDlcmIiKewqsDb9vRbQyZP4TPt30OwA0lbyCqSxSPNHoEfz9/h6sTERFP4pWBd+TcEcYvGc/MNTNJSUuheGBxhrYbyt/a/I3iRYo7XZ6IiHggrwq88ynnee2H14haFsXpC6fxM3483expxnceT6USlZwuT0REPJhXBF6aTWP2ptkMXzCc3ad2A3BLrVuI7h5Nw+saOlydiIh4A48PvGW7lxExN4LVB1YDcPN1NxPTI4YeN/ZwuDIREfEmHht4F1IucPfsu/ls62cAVCpRiYmdJ9K7SW81SBERkRzz2MDbdGQTm7ZuIjgwmMFtBzOo7SBKFCnhdFkiIuKlPDbwsPBEkyeY0GUCN5S8welqRETEyxlrrdM1ZKpB4wZ28/rNTpchIiIexBizxloblpv3euxwARpxXERE3MljA09ERMSdFHgiIuITFHgiIuITFHgiIuITFHgiIuITFHgiIuITFHgiIuITFHgiIuITFHgiIuITFHgiIuITFHgiIuITFHgiIuITFHgiIuITFHgiIuITFHgiIuIT8hR4xph7jTGbjTFpxpgsB+QzxtxijNlmj
NlujBmWl22KiIjkRl6P8DYBdwNLs1rAGOMPTAduBW4CHjTG3JTH7YqIiORIQF7ebK3dAmCMudpiLYHt1tqd6ct+DPQEfsrLtkVERHKiIK7hVQH2Zpjelz5PRESkwFzzCM8YMx+olMlLkdba/7qzGGNMH6BP+uQFY8wmd67fARWAo04XkUfaB8+gfXCet9cPhWMf6ub2jdcMPGttt9yuPN1+oFqG6arp8zLb1ixgFoAxJt5am2VDGG+gffAM2gfP4O374O31Q+HZh9y+tyBOaa4GahtjahhjigAPAJ8XwHZFREQuyettCXcZY/YBbYAvjTHfps+/wRjzFYC1NgUYAHwLbAH+Za3dnLeyRUREciavrTQ/Az7LZP4B4LYM018BX+Vw9bPyUpuH0D54Bu2DZ/D2ffD2+sHH98FYa91ZiIiIiEdS12IiIuITPCbwCkM3ZcaYcsaYecaYX9Ify2axXKoxZl36P49owHOtn6sxpqgxZnb66z8YY0IdKPOqsrEPvY0xRzL87J9yos6sGGPeNsYczup2HOPyWvr+bTDGNCvoGq8lG/sQbow5leEzGF3QNV6NMaaaMWaRMean9O+jv2ayjEd/DtncB0//HIKMMauMMevT92FcJsvk/DvJWusR/4D6uO6vWAyEZbGMP7ADqAkUAdYDNzlde4b6XgSGpT8fBkzNYrmzTtea058r0A+Ymf78AWC203XnYh96A9OcrvUq+9ARaAZsyuL124CvAQO0Bn5wuuZc7EM48D+n67xK/ZWBZunPSwI/Z/J75NGfQzb3wdM/BwOUSH8eCPwAtL5imRx/J3nMEZ61dou1dts1FrvUTZm1Ngm42E2Zp+gJvJf+/D3gz86VkiPZ+blm3LdPga7mGn3KFTBP/924JmvtUuD4VRbpCbxvXb4HyhhjKhdMddmTjX3waNbag9batenPz+BqWX5lz1Ae/Tlkcx88WvrP9mz6ZGD6vysbnOT4O8ljAi+bPL2bsuuttQfTn/8GXJ/FckHGmHhjzPfGmD8XTGlXlZ2f66VlrOtWk1NA+QKpLnuy+7txT/ppqE+NMdUyed2Tefrvf3a1ST9V9bUxpoHTxWQl/RRZU1xHFxl5zedwlX0AD/8cjDH+xph1wGFgnrU2y88hu99JebotIadMAXZTll+utg8ZJ6y11hiTVRPY6tba/caYmsBCY8xGa+0Od9cqf/AF8JG19oIx5hlcfx12cbgmX7MW1+//WWPMbcAcoLazJf2RMaYE8G/geWvtaafryY1r7IPHfw7W2lSgiTGmDPCZMaahtTZP3U0WaODZAuymLL9cbR+MMYeMMZWttQfTT3EczmId+9MfdxpjFuP6C8zJwMvOz/XiMvuMMQFAaeBYwZSXLdfcB2ttxnrfwnXN1Zs4/vufVxm/eK21XxljZhhjKlhrPaZ/R2NMIK6giLPW/ieTRTz+c7jWPnjD53CRtfakMWYRcAuuIekuyvF3kred0vT0bso+Bx5Lf/4Y8IejVmNMWWNM0fTnFYB2OD9UUnZ+rhn37S/AQpt+tdhDXHMfrrjOcieuaxve5HPg0fRWgq2BUxlOoXsFY0yli9dZjDEtcX0HecwfTum1/QPYYq19KYvFPPpzyM4+eMHnUDH9yA5jTDGgO7D1isVy/p3kdGucDC1u7sJ1LvwCcAj4Nn3+DcBXGZa7DVerox24ToU6XnuG2soDC4BfgPlAufT5YcBb6c/bAhtxtSLcCDzpdN1Z/VyB8cCd6c+DgE+A7cAqoKbTNediHyYDm9N/9ouAek7XfEX9HwEHgeT0/wtPAn2BvumvG1yDKe9I/93JtDWzh+/DgAyfwfdAW6drvqL+9rgaR2wA1qX/u82bPods7oOnfw6NgB/T92ETMDp9fp6+k9TTioiI+ARvO6UpIiKSKwo8ERHxCQo8ERHxCQo8ERHxCQo8ERHxCQo8ERHxCQo8ERHxCQo8ERHxCf8PkaUw7m3DMMEAAAAASUVORK5CYII=\n", 
"text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot dataset\n", "plt.figure(figsize=(7,5))\n", "plt.xlim(xmax = 3, xmin =-1)\n", "plt.ylim(ymax = 1.5, ymin = -1)\n", "plt.plot(dataset_linear['X_test'], dataset_linear['y_test'], color='green', linewidth=2, label=\"Ground Truth\")\n", "plt.plot(dataset_linear['X_train'], dataset_linear['y_train'], 'o', color='blue', label='Training points')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Hyperparameters\n", "dataset_linear['ALPHA'] = 2.0\n", "dataset_linear['BETA'] = 1/(2.0*sigma**2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will use the linear basis function:\n", "$$\\phi:x \\rightarrow (1,x)$$\n", "\n", "The design matrix $\\Phi$ defined on the training set is:\n", "$$ \\Phi=\n", " \\begin{bmatrix}\n", " 1 & x_1 \\\\\n", " \\vdots & \\vdots \\\\\n", " 1 & x_n\n", " \\end{bmatrix}\n", "$$\n", "\n", "**Question 1.1: Code the linear basis function.**" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "#TO DO: Define basis function\n", "\n", "def phi_linear(x):\n", "    \"\"\" Linear Basis Functions\n", "\n", "    Args:\n", "      x: (float) 1D input\n", "\n", "    Returns:\n", "      (array) linear features of x\n", "    \"\"\"\n", "    # TO DO\n", "    return 0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question 1.2: Recall the closed form of the posterior distribution in the linear case. Then, code and visualize posterior sampling.
What can you observe?**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#TO DO: Code and visualize posterior sampling by completing the code below\n", "\n", "plt.figure(figsize=(15,10))\n", "for count,n in enumerate([0,1,2,10,len(dataset_linear['X_train'])]):\n", "    cur_data = dataset_linear['X_train'][:n]\n", "    cur_lbl = dataset_linear['y_train'][:n]\n", "    meshgrid = np.arange(-1, 1.01, 0.01)\n", "    w = np.zeros((2,1))\n", "    posterior = np.zeros((meshgrid.shape[0],meshgrid.shape[0]))\n", "\n", "    # TO DO: code mu_N and sigma_N\n", "\n", "    # Compute values on meshgrid\n", "    for i in range(meshgrid.shape[0]):\n", "        for j in range(meshgrid.shape[0]):\n", "            w[0,0] = meshgrid[i]\n", "            w[1,0] = meshgrid[j]\n", "            posterior[i,j] = np.exp(-0.5* np.dot(np.dot((w-mu_N.reshape(2,1)).T, np.linalg.inv(sigma_N)) , (w-mu_N.reshape(2,1)) ) ).item()\n", "    # Normalization constant of a 2D Gaussian: 2*pi*sqrt(det(sigma_N))\n", "    Z = 2*np.pi * np.sqrt(np.linalg.det(sigma_N))\n", "    posterior[:,:] /= Z\n", "\n", "    # Plot posterior with n points\n", "    plt.subplot(231+count)\n", "    plt.imshow(posterior, extent=[-1,1,-1,1])\n", "    plt.plot(0.5,0.3, '+', markeredgecolor='white', markeredgewidth=3, markersize=12)\n", "    plt.title('Posterior with N={} points'.format(n))\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question 1.3: Recall and code the closed form of the predictive distribution in the linear case.**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#TO DO: Code the closed-form solution according to the requirements defined below\n", "\n", "def closed_form(func, X_train, y_train, alpha, beta):\n", "    \"\"\"Define the analytical solution to Bayesian Linear Regression, with respect to the chosen basis function, the\n", "    training set (X_train, y_train), the noise precision parameter beta and the prior precision parameter alpha.\n", "    It should return a function outputting both mean and std of the predictive distribution at a point x*.\n", "\n", "    Args:\n", "      func: (function) the basis function used\n", "      X_train: (array) train inputs, sized [N,]\n", "      y_train: (array) train labels, sized [N,]\n", "      alpha: (float) prior precision parameter\n", "      beta: (float) noise precision parameter\n", "\n", "    Returns:\n", "      (function) prediction function, which itself returns both mean and std\n", "    \"\"\"\n", "\n", "    #TO DO\n", "    def f_model(x) :\n", "        return 0\n", "\n", "    return f_model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f_pred = closed_form(phi_linear, dataset_linear['X_train'], dataset_linear['y_train'],\n", "                     dataset_linear['ALPHA'], dataset_linear['BETA'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question 1.4: Based on the previously defined ``f_pred()``, predict on the test dataset. Then visualize the results using ``plot_results()``, defined at the beginning of the notebook.**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# TO DO : predict on test dataset and visualize results\n", "\n", "# You should use the following parameters for plot_results\n", "# xmin=-10, xmax=10, ymin=-6, ymax=6, stdmin=0.05, stdmax=1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question 1.5: Analyse these results. Why does the predictive variance increase far from the training distribution?
Prove it analytically in the case where $\\alpha=0$ and $\\beta=1$.**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Bonus Question: What happens when Bayesian Linear Regression is applied to the following dataset?**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create training and test points\n", "sigma = 0.2\n", "dataset_hole = {}\n", "dataset_hole['X_train'] = np.concatenate(([np.random.uniform(-3, -1, 10), np.random.uniform(1, 3, 10)]), axis=0)\n", "dataset_hole['y_train'] = f_linear(dataset_hole['X_train'], noise_amount=1,sigma=sigma)\n", "dataset_hole['X_test'] = np.linspace(-12,12, 100)\n", "dataset_hole['y_test'] = f_linear(dataset_hole['X_test'], noise_amount=0,sigma=sigma)\n", "dataset_hole['ALPHA'] = 2.0\n", "dataset_hole['BETA'] = 1/(2.0*sigma**2)\n", "\n", "# Plot dataset\n", "plt.figure(figsize=(7,5))\n", "plt.xlim(xmin =-12, xmax = 12)\n", "plt.ylim(ymin = -7, ymax = 6)\n", "plt.plot(dataset_hole['X_test'], dataset_hole['y_test'], color='green', linewidth=2, label=\"Ground Truth\")\n", "plt.plot(dataset_hole['X_train'], dataset_hole['y_train'], 'o', color='blue', label='Training points')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# TO DO: Define f_pred, predict on test points and plot results\n", "\n", "# You should use the following parameters for plot_results\n", "# xmin=-12, xmax=12, ymin=-7, ymax=6, stdmin=0.0, stdmax=0.5" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part II: Non-linear models\n", "\n", "We now introduce a more complex toy dataset: an increasing sinusoidal curve. The goal of this part is to understand how the choice of basis function affects the behavior of the predictive variance."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Generate sinusoidal toy dataset\n", "def f_sinus(x, noise_amount,sigma=0.2):\n", "    y = np.sin(2*np.pi*x) + x\n", "    noise = np.random.normal(0, sigma, len(x))\n", "    return y + noise_amount * noise\n", "\n", "# Create training and test points\n", "sigma=0.2\n", "nbpts=50\n", "dataset_sinus = {}\n", "dataset_sinus['X_train'] = np.random.uniform(0, 1, nbpts)\n", "dataset_sinus['y_train'] = f_sinus(dataset_sinus['X_train'], noise_amount=1,sigma=sigma)\n", "dataset_sinus['X_test'] = np.linspace(-1,2, 10*nbpts)\n", "dataset_sinus['y_test'] = f_sinus(dataset_sinus['X_test'], noise_amount=0,sigma=sigma)\n", "\n", "dataset_sinus['ALPHA'] = 0.05\n", "dataset_sinus['BETA'] = 1/(2.0*sigma**2)\n", "\n", "# Plot dataset\n", "plt.figure(figsize=(7,5))\n", "plt.xlim(xmin =-1, xmax = 2)\n", "plt.ylim(ymin = -2, ymax = 3)\n", "plt.plot(dataset_sinus['X_test'], dataset_sinus['y_test'], color='green', linewidth=2,\n", "         label=\"Ground Truth\")\n", "plt.plot(dataset_sinus['X_train'], dataset_sinus['y_train'], 'o', color='blue', label='Training points')\n", "plt.legend()\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### II.1 Polynomial basis functions\n", "\n", "We will first use polynomial basis functions:\n", "$$\\phi:x \\rightarrow (\\phi_0,\\phi_1,\\dots,\\phi_{D-1})$$\n", "where $\\phi_j(x) = x^j$ for $0 \\leq j \\leq D-1$ and $D \\geq 1$\n", "\n", "\n", "The design matrix $\\Phi$ defined on the training set is:\n", "$$ \\Phi=\n", " \\begin{bmatrix}\n", " 1 & x_1 & x_1^2 & \\dots & x_1^{D-1} \\\\\n", " \\vdots & \\vdots & \\vdots & \\ddots & \\vdots \\\\\n", " 1 & x_n & x_n^2 & \\dots
&x_n^{D-1}\n", " \\end{bmatrix}\n", "$$\n", "\n", "**Question 2.1: Code the polynomial basis function.**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define basis function\n", "def phi_polynomial(x):\n", "    \"\"\" Polynomial Basis Functions\n", "\n", "    Args:\n", "      x: (float) 1D input\n", "\n", "    Returns:\n", "      (array) polynomial features of x\n", "    \"\"\"\n", "    D = 10\n", "    # TO DO\n", "    return 0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question 2.2: Code and visualize the results on the sinusoidal dataset using polynomial basis functions. What can you say about the predictive variance?**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# TO DO: Define f_pred, predict on test points and plot results\n", "\n", "# You should use the following parameters for plot_results\n", "# xmin=-1, xmax=2, ymin=-3, ymax=5, stdmin=0, stdmax=10" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### II.2 Gaussian basis functions\n", "\n", "Now, let's consider Gaussian basis functions:\n", "$$\\phi:x \\rightarrow (\\phi_0,\\phi_1,\\dots,\\phi_{M-1})$$\n", "where $\\phi_j(x) = \\exp \\Big ( -\\frac{(x-\\mu_j)^2}{2s^2} \\Big )$ for $0 \\leq j \\leq M-1$, with $M$ evenly spaced centers $\\mu_j$ and a common width $s$\n", "\n", "\n", "The design matrix $\\Phi$ defined on the training set is:\n", "$$ \\Phi=\n", " \\begin{bmatrix}\n", " \\phi_0(x_1) & \\phi_1(x_1) & \\dots & \\phi_{M-1}(x_1) \\\\\n", " \\vdots & \\vdots & \\ddots & \\vdots \\\\\n", " \\phi_0(x_n) & \\phi_1(x_n) & \\dots & \\phi_{M-1}(x_n)\n", " \\end{bmatrix}\n", "$$\n", "\n", "**Question 2.3: Code the Gaussian basis function.**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#TO DO: Define Gaussian basis function\n", "MU_MIN = 0\n", "MU_MAX = 1\n", "M = 9\n", "\n", "def phi_gaussian(x) :\n", "    \"\"\" Gaussian Basis Functions\n", "\n", "    Args:\n", "      x: (float) 1D input\n", "\n", "    Returns:\n", "      (array) gaussian features of x\n", "    \"\"\"\n", "    s = (MU_MAX-MU_MIN)/M\n", "    # M centers evenly spaced in [MU_MIN, MU_MAX); linspace avoids the float-step pitfalls of np.arange\n", "    mu = np.linspace(MU_MIN, MU_MAX, M, endpoint=False)\n", "    return np.exp(-(x - mu) ** 2 / (2 * s * s))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question 2.4: Code and visualize the results on the sinusoidal dataset using Gaussian basis functions. What can you say this time about the predictive variance?**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# TO DO: Define f_pred, predict on test points and plot results\n", "\n", "# You should use the following parameters for plot_results\n", "# xmin=-1, xmax=2, ymin=-2, ymax=3, stdmin=0.05, stdmax=0.2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question 2.5: Explain why, in regions far from the training distribution, the predictive variance converges to the constant value observed above when using localized basis functions such as Gaussians.**" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 4 }