{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Populating the interactive namespace from numpy and matplotlib\n" ] } ], "source": [ "# Probabilistic curve fitting: express uncertainty over the\n", "# target t using a Gaussian (aka normal) distribution.\n", "%pylab inline\n", "from IPython.display import display, Math, Latex\n", "\n", "# Seed the global RNG so the noise drawn with randn in the\n", "# data-sampling cell below is reproducible on Restart & Run All.\n", "np.random.seed(0)\n", "\n", "\n", "def gaussian(x, mu, sigma):\n", "    \"\"\"Normal density N(x | mu, sigma), vectorised over x and sigma.\n", "\n", "    x - mu is flattened into a column and sigma into a row, so the\n", "    result has x varying along axis 0 (down) and sigma along axis 1\n", "    (right): shape (size of x - mu, size of sigma), e.g. (N, 1) for a\n", "    length-N x and a scalar sigma.\n", "    \"\"\"\n", "    sigma = np.asarray(sigma, dtype=float)\n", "    # np.outer replaces the deprecated np.matrix idiom mat(a).T * mat(b)\n", "    z = np.outer(x - mu, 1.0 / sigma)\n", "    return np.exp(-0.5 * z**2) / (np.sqrt(2 * np.pi) * sigma)\n", "\n", "\n", "def cond_pdf_t(t, x, w, sigma):\n", "    \"\"\"Density p(t | x, w, sigma) = N(t | polyval(w, x), sigma).\n", "\n", "    w are polynomial coefficients in np.polyval order (highest power\n", "    first). The result is a density, not a probability, so values can\n", "    exceed 1. Shape follows gaussian(): one row per element of\n", "    t - polyval(w, x), one column per sigma.\n", "    \"\"\"\n", "    return gaussian(t, polyval(w, x), sigma)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 0, 'x')" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Do polynomial fitting like before, for now just\n", "# trusting the built-in polyfit\n", "# We sample from a sin curve with added gaussian noise\n", "# of std dev 0.3\n", "\n", "N = 15 # training points\n", "M = 4 # polynomial degree\n", "\n", "x = arange(0, 2*pi, (2*pi)/N)\n", "noise = 0.3 * randn(N)\n", "t = sin(x) + noise\n", "plot(x,t,'ro')\n", "\n", "w = polyfit(x, t, M)\n", "xs = arange(0,2*pi,.01)\n", "plot(xs, polyval(w, xs), 'g')\n", "ylabel(\"t\")\n", "xlabel(\"x\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, '$p(t|x,$w$,\\\\sigma)$')" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEdCAYAAAAcmJzBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAgAElEQVR4nO3deXycdbn//9c1k7VpkzRNmqRZmu5buodSVinQ0palgghUQVEEkYN6vvjVH3r8IXIOHlzOQT2iiIqIHDahQildECiLhS7pnnRN96TZuqVpsyfX94+ZQAhJM0kmuWe5no/HPGa577nv687ynns+9+f+3KKqGGOMCX4upwswxhjjHxboxhgTIizQjTEmRFigG2NMiLBAN8aYEGGBbowxIcIC3RhjQoQFujHGhAgLdBMQRORpEakQkbhuvu9BEVERyembyoKHiMz0/izucLoW4wwLdOM4EckDbgUeUdWzHUy/zxtUX+j/6oKHqm4EXgH+Q0QGOl2P6X8W6CYQ/AQ4Dfyuk+kzvPcb+6ecoPafQBrwLacLMf3PAt04SkTGAlcCL6pqbSezzQTOAHv7rbAgparrgV3A10XE7XQ9pn9ZoBu/EpELvc0jvxaRxSLyvohUiUidiKwXkTnt3vJVQIAXOljWT0VEgfHAQKDZu2wVkVt9qOUN77w3tHtdROQp77RHurFtA0WkQUTWtHs91rt9KiK3tZt2j/f1r/q6ng7W+YCIbBaR6jbb3/42tM3bngey8XxQmjBigW78rbV5ZC7wV+AE8HvgfeA8YLmIZLeZ/0qgGVjbwbI2An/xPv4A+HGb27s+1PJdoAVPm3LbvdVfAF8G/qCq9/uwHABU9QywHpglIoPaTLoIiPY+vqLd2y733r/l63paeUN6A57tdQGPA/8DlHlnaQT2AetUtaLNW1s/cOZ2d50myKmq3ezmtxvwJKBAFXBxu2m/9k77hfd5HNAEbD/H8u7yvueuTqY/6J2e08n0p7zTb/c+/4H3+QuAqwfb95D3/Ve3ee0/vdvxNnCkzesu4Biwr4c/yze96/opIG1ezwLq8AR6cgfvS/C+b73Tfw9269+b7aEbf2vdQ79XVf/ZbtofvfcTvfcZgBso9WF5m3pYzw/xhN+DInIv8DCwCr
hNVVt6sLzWPe22e+JX4Pk28TKQ6T0uADANGELP9s7nepe7Bvi+epMaQFWP4PnGE+FdxyeoahWebc5uP82ENgt04zciEo0nrI8A/9vBLMe995He+yHe+5PnWOwMPHui23tSk6oWA78EhuNprvgAuEFVG3qyPOBDoBZvoItIgrfGt/DsocPHYd/a3PI23dd6jODRTj54qrz3nf0PnwCSe7BeE8Qs0I0/TcYT1qs6CaEc7/1h731rr5aYjhYmIhHeZe5Q1fpe1FXZ5vEdqlrT0wV5Pwj+CUz2tnFfhudbxluquhM4yseBfgWepo+eBPoleNr/V3YyPdN7X9TJ9Fg+/vmaMGGBbvyptXnkYCfTr/bev+G9bz2QN6SDecGztx9Dz5tbEJHFeA6Cth5I/HZPl9XG23h65lyOJ7Tr+fhA5GpgjvfbyiVAoX7ygKUvNbvwfKOo0I5PtErFc4D5gKru7+T9iXz88zVhwgLd+FNroCe2nyAiSXgOcBYDr3pfLsWz9zyuk+W1tg9v7kkxIrIQTy+ZQmAKnv7ZXxOR8T1ZXhtt29EvB9aoal2baUnAN/Ac9O12+zmevXqAQd5wbu97eP53f9/J+8fh+cDZ0oN1myBmgW78aab3/sa2Y7J4T0N/Fs+e+Ldbw897oO89IFlERnewvNY999PdLURELgZewvMBMk9VK4H/H8+BxE/1PW/TL/12Hxa/ETgFLAIm8cnQbn38fe/9p5pbulqX9+eyFc8HwuJ2770R+Fc8H06/6qS+2d771V1tiAktEU4XYEKDt707F08QJQDbRGQpnv7Zi4BheHprLGn31peBzwFX8en24NZT/R8WkVzgLJ4mjL91UctUYBmeA4dzVbUUQFVfEpF8YJGIXKKq77d5W+vOTVNX26qqLSLyrne7oE2gq+phEdkHjMLTv76j/vK+rOshYAnwZxGZj+dA83l4+u3vBRa2+VbQ3jzvul/tZLoJVU73m7RbaNyAqXiaCn4HjAFW4NmzrsbTn3peJ++LwtO+va6T6fcCu/F0w1Pg4XbTH6RNP3RgtHd5J4EpHSzvSu/8a9u9vtlb72Aft/ebfNzf3t1u2u+90zrbJp/WhecD4wM8H2Q1eJpQfgAMPMd7EvAcDH3F6b8Ju/X/Tbx/BMb0ioh8Bc9JRV9X1Se6+d7v4xmga4aqdqu9XEQeBH4EjFDVg915b5tlJOLpUvlfqvq9niwjUNYlIt/EcwLXpfrJbyAmDFgbuvGX1gOiPTmA+SierowP+a+cbrkET1/3/w7mdYlILJ62+5ctzMOTtaEbf5mBp9222ycAqWqdd1CrOSISpx101etLqvoanfSFD7J15QBP4BnuwIQhC3TTa96udVOBXdr5gbpzUtX38PR4MT2knhObHnS6DuMcC3TTa+o5K9SpK+S8470/5dD6jQkYjh0UTU5O1pycHEfWbYwxwWrjxo3HVDWlo2mO7aHn5OSQn5/v1OqNMSYoicihzqZZLxdjjAkRFujGGBMiLNCNMSZEWKAbY0yIsEA3xpgQYYFujDEhwgLdGGNChJ0paozX6bpG1u8/wYFjZ6lrbCZlUDQzhg9mzNCBiIjT5RnTJQt0E/YOH6/hV2/t5bWtR2lo/vS1rcemDuTrl47i+ukZuFwW7CZwWaCbsKWq/OH9/fzijT24Rbj5vCyunpLOhPR4YiPdlFbV8s+iY/zv2sN8529beXb9YX558zSykgY4XboxHXJsLJe8vDy1U/+NU+oam/nX57ewsrCMqyal8uPrcklL6HhUW1Xl5U0l/Pi1QgR4/LaZXDgquX8LNsZLRDaqal5H0+ygqAk7NQ1N3P7n9azaUcYPr57A47fO7DTMAUSEG2dmsuybF5MaH8Ptf97A27vK+7FiY3xjgW7CSkNTC3c/s4n1B07wy5un8bVLRvp8wHP4kDhe/PoFjE8bxDee2cS6/cf7uFpjuscC3YSVHy0t5L09lfzk+sksmpbR7fcPjoviqa/MImNwLHc+nc+h4/
16cSVjzskC3YSNFzYc5rn1h/nGZaO4ZVZ2j5eTFBfFX74yC5dLuOvpjdQ2NPuxSmN6zgLdhIWDx87y4NIdXDImmf87b1yvl5eVNIBf3zKd3eXV/GzVLj9UaEzvWaCbkNfSonz3pa1EuIWf3zgVt5/6kl86NoXbL8zhz2sOstba000AsEA3Ie/JNQfYcPAkD1476Zy9WXrie/PHMXzIAL770lbO1jf5ddnGdJcFuglpB4+d5eerdnPlhKHcMKP7B0G7MiAqgl98firFJ2v52UprejHOskA3Ie0ny3cS4RIevn5yn43Hcl5OErfNHs5f1x5iT3l1n6zDGF9YoJuQ9cG+Y7yxo5x75owmNd6/TS3t/Z8rxzIwOoL/eH1nn67HmHOxQDchqblF+Y9lO8lIjOWOi0f0+foGx0XxrSvG8N6eSt7ZXdHn6zOmIxboJiQt2VTMjtLTfG/+OGIi3f2yzi9dkEPOkAE8/PpOmjoYtdGYvmaBbkJOfVMz//2PPUzNSuS6qcP6bb1RES7uXzCBvRVnWLKppN/Wa0yrLgNdRJ4UkQoRKehk+hdFZJv39oGITPV/mcb47qWNxZRW1fGduWP7/cIUV01KJTcjnsfeKbK9dNPvfNlDfwqYf47pB4DPqOoU4N+BJ/xQlzE90tjcwm9X72NaViKXjOn/IW5FhG9ePoZDx2t4bdvRfl+/CW9dBrqqvgecOMf0D1T1pPfpWiDTT7UZ021/31RCyalavnXFaMcuGzd3Qirj0wbxP28X0dzizPUGTHjydxv6HcAKPy/TGJ80Nbfw2DtF5GbEM2fcUMfqcLk8e+n7K8+yfHupY3WY8OO3QBeROXgC/f87xzx3iUi+iORXVlb6a9XGAPD69lIOHa/h3jljHL+o84LcNMYMHchjq4tw6qpgJvz4JdBFZArwR2CRqnY6SpGqPqGqeaqal5KS4o9VGwN4LhP3p38eYGRyHPMmpjpdDi6XcOelI9lVVs2H+2zgLtM/eh3oIpINLAFuU9U9vS/JmO7bdPgk24qr+MpFObj8NJpib103dRhD4qJ4cs0Bp0sxYcKXbovPAR8C40SkWETuEJG7ReRu7ywPAEOA34rIFhGxKz+bfvfkPw8SHxPBDTMC55h8TKSbL84ezlu7KjhwzK5sZPqeL71cFqtquqpGqmqmqv5JVR9X1ce907+mqoNVdZr31uHVqI3pK8Una1hRUMri87OJi45wupxPuHV2NhEu4SnbSzf9wM4UNUHv6Q8PISJ86YIcp0v5lKGDYrh26jD+trGYqtpGp8sxIc4C3QS1usZmnl9/mPmT0shIjHW6nA599aIR1DQ089LGYqdLMSHOAt0Etde3lXK6rolbZw93upRO5WYkMC0rkefWH7YujKZPWaCboPbs+sOMTI5j9sgkp0s5py+cn01RxRk2HDzZ9czG9JAFuglau8uq2XjoJItnZTt+IlFXrpmSzqDoCJ5bf9jpUkwIs0A3Qeu59YeJcrv43MzA6arYmQFREVw/I4PXt5dy8myD0+WYEGWBboJSbUMzSzYVMz83jaS4KKfL8ckXzs+moamFJZttrHTTNyzQTVBavt1zMHTxrGynS/HZ+LR4pmfbwVHTdyzQTVB6eVMxw4cMCPiDoe3dnJdFUcUZthVXOV2KCUEW6CbolJyq5cP9x7lhembAHwxtb+GUdKIjXCzZZH3Sjf9ZoJug88rmElTh+ukZTpfSbfExkcyblMbSrUdpaLJL1Bn/skA3QUVVeXlTMbNyksgeMsDpcnrkhhkZnKxpZPXuCqdLMSHGAt0Ela3FVeyvPMsNM4Jv77zVJaOTSR4Ybc0uxu8s0E1QWbKpmKgIFwunpDtdSo9FuF18dtow3t5VYX3SjV9ZoJug0dDUwtKtR5k3MZX4mEiny+mVG2Zk0tisLNt21OlSTAixQDdBY/XuCk7VNAbFmaFdmTgsnvFpg3h5k51kZPzHAt0EjSWbikkeGM0lo5OdLsUvPjcjky1HTrGv8o
zTpZgQYYFugsLJsw28vauCz04bRoQ7NP5sF00bhkuwg6PGb0LjP8OEvOUFpTQ2K9cHce+W9obGx3DxmBRe3XLUhgIwfmGBboLCsq2ljEyJY2J6vNOl+NW1U9IpPlnLVhsKwPiBBboJeBWn61h74DjXThkWdKf6d2XepDSi3C5e22q9XUzvdRnoIvKkiFSISEEn00VEfi0iRSKyTURm+L9ME86Wby9FFa6dGrx9zzuTEBvJpWNTeH1bKS0t1uxieseXPfSngPnnmL4AGOO93QX8rvdlGfOx17aVMj5tEKOHDnK6lD5x7dR0yk7XkX/ILk9neqfLQFfV94AT55hlEfC0eqwFEkUk9HaljCNKTtWy8dBJrp06zOlS+syVE1KJiXTZSUam1/zRhp4BHGnzvNj72qeIyF0iki8i+ZWVlX5YtQl1r3tD7pogPtW/K3HREVw+fijLt5fS1GwjMJqe80egd3SUqsPGQFV9QlXzVDUvJSXFD6s2oW7ZtlKmZCYwfEic06X0qWunDOPYmQbWHTjXl2Fjzs0fgV4MZLV5ngnYd0fTa4eOn2VbcVVI7523mjN+KHFRbmt2Mb3ij0BfCnzJ29tlNlClqqV+WK4Jc8u2ef6Mrp4Suu3nrWIi3cydmMqKgjIardnF9JAv3RafAz4ExolIsYjcISJ3i8jd3lmWA/uBIuAPwD19Vq0JK69tPcrM4YPJSIx1upR+cc2UYZyqaeSfRcecLsUEqYiuZlDVxV1MV+Bf/FaRMcC+yjPsKqvmR9dOdLqUfnPJ2GQGxUSwbGspc8YNdbocE4TsTFETkFYWlAGwIDf0289bRUe4mTshlTd3lluzi+kRC3QTkFYUlDI9O5G0hBinS+lX83PTqKpt5MN9x50uxQQhC3QTcI6cqKGg5DQLctOcLqXfXTo2hbgoNyu831CM6Q4LdBNwwrG5pVVMpJs544fyRmEZzTa2i+kmC3QTcFYUlDJpWDxZSQOcLsURC3LTOX62gfV2kpHpJgt0E1DKqurYdPgU8yeFX3NLq8vGpRAd4WJlgZ3OYbrHAt0ElFWF3uaWyeEb6HHREXxmbAorC8tsSF3TLRboJqCsKChl9NCBITtUrq8WTk6n/HQ9m4+ccroUE0Qs0E3AOH6mnvUHToRl75b2Lp8wlEi3sGK7NbsY31mgm4Dxxo5yWtTTFzvcxcdEcvHoZFYUlNkFpI3PLNBNwFhRUEZ20oCQuxB0Ty3ITafkVC0FJaedLsUECQt0ExCqahr5oOgYC3LTQu5C0D01d2Iqbpewwnq7GB9ZoJuA8ObOcppa1Jpb2hgcF8XskUnW7GJ8ZoFuAsKKgjLSE2KYmpnodCkBZUFuOgeOnWV3ebXTpZggYIFuHHemvon39lZy1aQ0XC5rbmlr3qRURGDFdhvbxXTNAt04bvWuChqaWqy7YgeGDorhvOFJH41vY8y5WKAbx60sKCN5YBR5OUlOlxKQrspNY3d5NfsrzzhdiglwFujGUXWNzazeXcHciWm4rbmlQ60HilcVljtciQl0FujGUe/tqaSmodmaW84hIzGWKZkJNliX6ZIFunHUysIyEmIjuWDUEKdLCWjzc9PYWlxFyalap0sxAcwC3TimoamFN3eUc+WEVCLd9qd4Lq3DCb9RaAdHTed8+i8SkfkisltEikTk/g6mZ4vIahHZLCLbRGSh/0s1oWbt/uOcrmuyk4l8MDJlIGNTB9ql6cw5dRnoIuIGHgMWABOBxSIysd1sPwReVNXpwC3Ab/1dqAk9KwvLGBDl5pIxyU6XEhTm56az4eAJKqvrnS7FBChf9tBnAUWqul9VG4DngUXt5lGgdUSlBOCo/0o0oai5RXmjsJw544YSE+l2upygMH9SGqqeYRKM6YgvgZ4BHGnzvNj7WlsPAreKSDGwHPhmRwsSkbtEJF9E8isrK3tQrgkVmw6f5NiZeq6y5hafTUgfRHbSAGt2MZ3yJdA76hzcfqSgxcBTqpoJLAT+KiKfWraqPq
Gqeaqal5KS0v1qTchYWVBGlNvFnHH2d+ArEWFBbhofFB2jqrbR6XJMAPIl0IuBrDbPM/l0k8odwIsAqvohEANYw6jpkKqysqCMi8ckMygm0ulygspVuWk0tShv77JmF/NpvgT6BmCMiIwQkSg8Bz2XtpvnMHAFgIhMwBPo1qZiOlR49DQlp2qtd0sPTMtMJC0+xgbrMh3qMtBVtQm4F1gF7MTTm6VQRB4Skeu8s30HuFNEtgLPAberDeBsOrGioBS3S7hyQqrTpQQdl0u4alIq7+6ppKahyelyTICJ8GUmVV2O52Bn29ceaPN4B3CRf0szoWplQRnnj0giKS7K6VKC0lW5afzlw0O8u7uSBZPTnS7HBBA7Pc/0q6KKavZVnrXmll6YleP5MLTeLqY9C3TTr1rH9Z430QK9pyLcLuZOSOXtXRXUNzU7XY4JIBbopl+tLCxjenYiaQkxTpcS1ObnpnGmvokPio47XYoJIBbopt8cOVFDQcnpjwaaMj134eghDIqOYIUNqWvasEA3/WaVd6RAaz/vvegIN5dPGMo/dpTT1NzidDkmQFigm36zqrCMCenxDB8S53QpIWFBbhonaxpZf/CE06WYAGGBbvpFRXUd+YdOWnOLH106NoWYSJddQNp8xALd9Is3CstRteYWfxoQFcFnxqawqrCMlhY7j89YoJt+sqqwjBHJcYxNHeh0KSFlQW465afr2VJ8yulSTACwQDd97lRNAx/uO85Vk9IQ6WjwTtNTc8YPJdIt1uxiAAt00w/e2llBU4tac0sfSIiN5MJRyawsKMOGTzIW6KbPrSwsIz0hhqmZCU6XEpIW5KZx+EQNO0urnS7FOMwC3fSps/VNvLen0ppb+tDciam4BFbaSUZhzwLd9Km3dlVQ39TCAmtu6TNDBkYza0QSKwutHT3cWaCbPvX6tqMMHRRNXk6S06WEtPmT0thTfoZ9lWecLsU4yALd9Jkz9U28s7uShZPTcbusuaUvtV5s23q7hDcLdNNn3tpZTn1TC1dPsYsw9LX0hFimZSVaoIc5C3TTZ17fVkpqfDQzswc7XUpYWDg5je0lVRw6ftbpUoxDLNBNn6iua+SdPZ7mFpc1t/SLq6cMA2DZNuvtEq4s0E2feGtnBQ1NLVxjzS39JiMxlpnDB/Pa1qNOl2IcYoFu+sSybaWkJ8QwPcuaW/rTNVPS2VVWTVGF9XYJRz4FuojMF5HdIlIkIvd3Ms9NIrJDRApF5Fn/lmmCyem6Rt7bU8mCXGtu6W8LJ6cjAsu22V56OOoy0EXEDTwGLAAmAotFZGK7ecYA3wcuUtVJwL/2Qa0mSLy5o5yGZuvd4oTU+Bhm5STx2tajNrZLGPJlD30WUKSq+1W1AXgeWNRunjuBx1T1JICqVvi3TBNMXt9WyrCEGKZnJTpdSli6Zuow9lWeZVeZje0SbnwJ9AzgSJvnxd7X2hoLjBWRNSKyVkTmd7QgEblLRPJFJL+ysrJnFZuAVlXbyPt7j1nvFgctyE3D7RJrdglDvgR6R/+V7b/LRQBjgMuAxcAfReRTu2eq+oSq5qlqXkpKSndrNUHAmluclzwwmgtHDWHZtlJrdgkzvgR6MZDV5nkm0P6jvxh4VVUbVfUAsBtPwJsw8/r2UjISPWctGudcMyWdQ8drKCg57XQpph/5EugbgDEiMkJEooBbgKXt5nkFmAMgIsl4mmD2+7NQE/iqahp5f28lV09Jt6FyHXbVpDQirNkl7HQZ6KraBNwLrAJ2Ai+qaqGIPCQi13lnWwUcF5EdwGrgu6p6vK+KNoFpRUEpjc3K1ZOtucVpiQOiuGRMsjW7hBmf+qGr6nJVHauqo1T1Ye9rD6jqUu9jVdX7VHWiqk5W1ef7smgTmF7ZUsLI5Dim2JWJAsK1U4dRcqqWTYftAtLhws4UNX5x9FQt6w6cYNG0DGtuCRBzJ6YSFeGyZpcwYoFu/GLp1qOowqJpw5wuxXgNionksrEpvL6tlOYWa3YJBx
boxi9e2VzCtKxEcpLjnC7FtLFoWgYV1fV8sO+Y06WYfmCBbnptd1k1u8qquX56+/PNjNOumDCUQTER/H1TidOlmH5ggW567ZUtJbhdYicTBaCYSDfXTElnZWEZNQ1NTpdj+pgFuumVlhbl1c0lXDImmeSB0U6XYzpw/fRMahqaWVVol6cLdRboplc2HDzB0ao6PjvNmlsCVd7wwWQOjmWJNbuEPAt00yuvbDnKgCg38yalOl2K6YTLJVw/PYM1RccoP13ndDmmD1mgmx6rb2pm+fZS5k1MZUBUhNPlmHO4fnoGLQqvbrG99FBmgW567B87yqmqbeSGGZlOl2K6MDJlINOyEq3ZJcRZoJse+1t+McMSYrhodLLTpRgf3DAjg11l1ewstREYQ5UFuumR0qpa3ttbyY0zM3HbhSyCwjVThhHhEv6+2fbSQ5UFuumRJZtKUIUbZ2Z1PbMJCElxUcwZP5Qlm0pobG5xuhzTByzQTbepKi/mH2H2yCSyhwxwuhzTDbecl8WxM/W8vcsu+xuKLNBNt60/cIJDx2u4Kc/2zoPNZ8amkBofzQsbjnQ9swk6Fuim2/62sZiB0REsyLVT/YNNhNvF52dm8c7uCkqrap0ux/iZBbrpljP1Tby+rZRrp6YTG+V2uhzTAzflZdGinl5KJrRYoJtuWb6tlNrGZj5vzS1BK3vIAC4aPYQXNhyhxcZJDykW6KZbXsg/wqiUOKZnJTpdiumFW87LpuRULWtsnPSQYoFufLbj6Gk2HjrJ4lnZdpm5IDdvUiqJAyJ53g6OhhSfAl1E5ovIbhEpEpH7zzHfjSKiIpLnvxJNoHhm3SFiIj0H1Uxwi45wc8P0TN4oLOP4mXqnyzF+0mWgi4gbeAxYAEwEFovIxA7mGwR8C1jn7yKN807XNfLK5hKumzqMhAGRTpdj/ODm87JobFYb3yWE+LKHPgsoUtX9qtoAPA8s6mC+fwd+Btj4nCHo75tKqGlo5rbZOU6XYvxkXNogZg4fzDPrDtnB0RDhS6BnAG0b2oq9r31ERKYDWaq6zI+1mQChqvx17SGmZiUyOTPB6XKMH91+YQ6Hjtfwzh47czQU+BLoHR39+ujjXERcwKPAd7pckMhdIpIvIvmVlZW+V2kctXb/CYoqznDb7OFOl2L8bH5uGqnx0fx5zUGnSzF+4EugFwNtj4JlAkfbPB8E5ALviMhBYDawtKMDo6r6hKrmqWpeSkpKz6s2/eqZtYdIHBDJNXYR6JAT6XZx6/nDeX/vMYoqzjhdjuklXwJ9AzBGREaISBRwC7C0daKqVqlqsqrmqGoOsBa4TlXz+6Ri06/KT9exqrCMm/KyiIm0M0ND0eLzs4lyu3j6w4NOl2J6qctAV9Um4F5gFbATeFFVC0XkIRG5rq8LNM56bv1hmlX54vnZTpdi+kjywGiumZrOyxuLOV3X6HQ5phd86oeuqstVdayqjlLVh72vPaCqSzuY9zLbOw8NdY3NPLP2EJeNTWH4kDinyzF96PYLczjb0MxLNr5LULMzRU2nXt1SwrEzDdx5yUinSzF9bEpmIjOyE3n6w4PWhTGIWaCbDqkqf3z/ABPT47lg1BCnyzH94MsX5nDQujAGNQt006F39lSyt+IMd146wsZtCRMLctNJT4jh9+/ud7oU00MW6KZDf3x/P2nxMVw9eZjTpZh+EhXh4o6LR7DuwAk2Hz7pdDmmByzQzacUlFSxpug4X74wh6gI+xMJJ7fMyiY+JoLH393ndCmmB+y/1XzKY6uLGBQTwRdnW1fFcDMwOoIvXZDDGzvK2VdpJxoFGwt08wl7y6tZWVjGly/IIT7GRlUMR7dflEOU28UT1pYedCzQzSf89p19xES4+erFI5wuxTgkeWA0N+VlsWRzMSWn7ELSwcQC3Xzk0PGzLN16lC+en01SXJTT5RgH3X3ZKAB+u7rI4UpMd1igm4/87vqSCaoAABCHSURBVJ19uEW481I7kSjcZSTGclNeFi/mH7
G99CBigW4Az975SxuLufm8LFLjY5wuxwSAe+aMBjwHyU1wsEA3APzyzb1EuIVvXj7a6VJMgMhIjOXm87L4W/4Rik/WOF2O8YEFumFPeTWvbCnhyxfkMNT2zk0b91w2GkH49Vt7nS7F+MAC3fBfb+wmLiqCuz8zyulSTIAZlhjLrbOH89LGYvaWVztdjumCBXqY23rkFKsKy/naJSMYbD1bTAfuvXw0cVER/HTlLqdLMV2wQA9jqsq/L9tB8sAo7rB+56YTSXFR3H3ZKN7cWcH6AyecLsecgwV6GFu+vYz8Qye5b+44BtlZoeYcvnrRCFLjo/nPFTtRtfHSA5UFepiqa2zmkZU7GZ82iJvPy+r6DSasxUa5+c7ccWw+fIpXtxzt+g3GERboYeqpDw5y5EQtP7x6Im6XjXduunbjzEymZCbwk+U7OVPf5HQ5pgMW6GGotKqW/3lrL1eMH8rFY5KdLscECZdL+PF1k6iorud/rBtjQLJAD0MPvbaDphblwesmOV2KCTLTswfz+ZmZPLnmgA2vG4B8CnQRmS8iu0WkSETu72D6fSKyQ0S2ichbIjLc/6Uaf3h7VzkrCsr41hVjyEoa4HQ5Jgh9b/54YiLd/PDvBXaANMB0Gegi4gYeAxYAE4HFIjKx3WybgTxVnQK8BPzM34Wa3qttaOaBVwsZPXQgd15iA3CZnkkZFM33F0zgw/3HeWHDEafLMW34soc+CyhS1f2q2gA8DyxqO4OqrlbV1sEe1gKZ/i3T+MMv3thN8claHv5srl1azvTKLedlMXtkEg8v30n56TqnyzFevvxXZwBtP4aLva915g5gRUcTROQuEckXkfzKykrfqzS9tm7/cZ5cc4DbZg/n/JFDnC7HBDmXS3jkhik0NLXww1es6SVQ+BLoHfVp6/C3JyK3AnnAzzuarqpPqGqequalpKT4XqXplbP1Tfzfl7aSnTSA+xeMd7ocEyJykuP4zryx/GNHOUs2lThdjsG3QC8G2p55kgl86swCEbkS+DfgOlWt9095xh8eXr6T4pO1/OLzU4mLjnC6HBNC7rh4JOePSOKBVws4dPys0+WEPV8CfQMwRkRGiEgUcAuwtO0MIjId+D2eMK/wf5mmp1ZsL+XZdYe565KRnJeT5HQ5JsS4XcKjN0/D7RK+/fwWGptbnC4prHUZ6KraBNwLrAJ2Ai+qaqGIPCQi13ln+zkwEPibiGwRkaWdLM70o0PHz/K9l7YxLSuR78wb53Q5JkQNS4zlJzdMZsuRU/zyzT1OlxPWfPr+rarLgeXtXnugzeMr/VyX6aX6pmb+5dlNiMBvvjDderWYPnXNlGG8v+cYj63ex9TMROZNSnO6pLBk/+UhSFX50auFFJSc5r9umkbmYDuByPS9Hy+axJTMBO57caudReoQC/QQ9OSagzy/4Qj3zhnN3ImpTpdjwkRMpJvf3TqTqAgXX//rRqrrGp0uKexYoIeY1bsrePj1HVw1KZX75o51uhwTZjISY/nN4ukcOHaWbzyziYYmO0janyzQQ0hBSRXffHYz49PiefTmabhsWFzjgAtHJ/OfN0zmn0XHuH/JNjvpqB9Zp+QQUVRxhi89uZ6E2Ej+dHseA6LsV2ucc1NeFqWn6nj0zT2kxcfwvfl2Qlt/sP/6EFB8sobb/rQOlwjPfO180hNinS7JGL51xWjKTtfx23f2ER3h5ttXjnG6pJBngR7kDh0/yxf+sI6z9U08f9cFjEiOc7okYwAQER7+bC4NTS08+uYeItzCv8wZ7XRZIc0CPYjtLa/mi39cR2NzC8/eOZuJw+KdLsmYT3C5hJ/dOIUWVX6+aje1Dc18Z95YROz4Tl+wQA9SGw+d5M6n83G7hOfvuoBxaYOcLsmYDrldwi8+P5Uot4vfrC6iorqOn1w/mQi39cnwNwv0IPTqlhK++9I20hNieOors6yZxQQ8t0t45HOTSY2P5tdvF1FRXc+vbplOQmyk06WFFPuIDCJNzS38fNUuvv
38FqZlJfLKPRdZmJugISLcN28cD1+fyz/3HuOzj61hT3m102WFFAv0IFFWVccX/riOx1bv4+a8LJ6543wGx0U5XZYx3fbF84fz7J2zqa5r4rOPrWHJpmLrq+4nFugBTlVZtu0oC3/9PgUlVfz3TVP56Y1TbLAtE9RmjUhi2TcvZtKweO57cSv3PruZUzUNTpcV9KwNPYBVnK7jh68U8MaOciZnJPDozdMYPXSg02UZ4xdpCTE8f9cF/P69fTz6jz2sP3iCH149geumDrNeMD0kTn3VycvL0/z8fEfWHejqGpt5cs0BHnu7iKYW5b65Y7nj4hHWK8CErIKSKn7w9+1sK67iwlFD+NG1k6znVidEZKOq5nU4zQI9cDQ2t/DK5hJ+9dZeik/WMndiKv+2cAI5duDThIHmFuXZ9Yf5+cpdVNc3cf30DP7PlWPJSrLhn9uyQA9wNQ1N/H1zCY+/u48jJ2qZNCyeHyycwEWjk50uzZh+d/JsA4+/u4+nPjhIc4ty3bRh3HXpSMan2YlzYIEesPaWV/Ni/hFe2HCE03VNTMlM4FuXj+GKCUOtDdGEvdKqWp54bz8vbDhCTUMzF4wcwufzMlmQm05slNvp8hxjgR5ASk7V8trWo7y65Sg7S0/jdgnzJ6Xx5QtzOC9nsAW5Me2cqmngf9cd5oUNRzh8ooZB0RFcMzWdhZPTOX/EkLDr8WWB7qD6pmbyD57kvT2VvLunkl1lnhMppmUlsmjaMK6eks7QQTEOV2lM4GtpUdYfPMHf8otZvr2U2sZmBkZHcOnYZC4fn8rskUlkJMaG/E6RBXo/aWpu4cjJWraXVLHl8Cm2HDlJwdHTNDS1EOkW8oYncenYFBZOTmP4EDvQaUxP1TY0s6boGG/tKuetnRVUVNcDkJ4Qw8zhg5k5fDAT0+MZnxZPwoDQGl6g14EuIvOBXwFu4I+q+ki76dHA08BM4Dhws6oePNcygzXQ6xqbKauqo7SqjtKqWg4eO0tR5Rn2VZzlwLGzNDR7LrkVE+lickYCUzMTuWDUEGaPHEJctHX7N8bfWlqUXWXV5B86wYaDJ8k/eILSqrqPpg9LiGFM6iCykwaQnTSArKQBZCXFkpk4gPjYiKDbo+9VoIuIG9gDzAWKgQ3AYlXd0Waee4Apqnq3iNwCXK+qN59ruf0V6KpKi3q6RDW1tFDX2EJtYzO1Dc3UNTZ/9Li20fP8bH0zVbWNnKptoKqmkVM1nsenahopP13HyZpPXvjWJTB8SByjUgYyamgco1MGMiE9nnFpg4i0fuPGOKL8dB07S0+zs7SaXWWnKao4w+ETNVTXNX1ivii3i+SBUSQPiiZ5YDRD4qIYFBPJoJiINjfP89hIN9ERbqIjXURHuDyPI1ze527c/XTJx3MFui+7jLOAIlXd713Y88AiYEebeRYBD3ofvwT8RkRE+6A9553dFfz7sh20KDS1tNDcrDSregPbc9/6uMV73xNRbheJAyI9t9gospIGkJczmPSEWNLiY0hPiCE1IYbMwbFER4TvEXdjAlFqfAyp8TFcNm7oJ16vqmnkyMkaDp+o4eipWirP1HOsuoFjZ+opq6pjx9HTnKlv4kx9UydL7pzbJbhdQoRLcIvg8j7+6F4+nr54VjZ3XjrSX5v7EV8CPQM40uZ5MXB+Z/OoapOIVAFDgGNtZxKRu4C7ALKzs3tU8KCYSManxX/0g2n9Ybnb3D75ugu3CBFuz7TYSLfnkzbS5Xkc5Xke4308IMpNYmwUMZGuoPsqZow5t4QBkSQMSCA3I+Gc8zW3KGfqm6iua6S6ronquibqm5qpb2yhvqnF87iphfpG7733tbY7ki0tH+9sNn/iNRgaH90n2+dLoHeUau13e32ZB1V9AngCPE0uPqz7U1oPeBhjTF9xu4SE2MigG6/dl0beYiCrzfNM4Ghn84hIBJAAnPBHgcYYY3zjS6BvAMaIyAgRiQJuAZa2m2cp8GXv4xuBt/ui/dwYY0znumxy8b
aJ3wuswtNt8UlVLRSRh4B8VV0K/An4q4gU4dkzv6UvizbGGPNpPnWMVtXlwPJ2rz3Q5nEd8Hn/lmaMMaY7rKO0McaECAt0Y4wJERboxhgTIizQjTEmRDg22qKIVAKHgGTanVEaZmz7w3f7w3nbwba/p9s/XFVTOprgWKB/VIBIfmcDzYQD2/7w3f5w3naw7e+L7bcmF2OMCREW6MYYEyICIdCfcLoAh9n2h69w3naw7ff79jvehm6MMcY/AmEP3RhjjB9YoBtjTIjo90AXkSQR+YeI7PXed3i1ChFpFpEt3lv74XqDjojMF5HdIlIkIvd3MD1aRF7wTl8nIjn9X2Xf8GHbbxeRyja/7685UWdfEZEnRaRCRAo6mS4i8mvvz2ebiMzo7xr7ig/bfpmIVLX53T/Q0XzBSkSyRGS1iOwUkUIR+XYH8/jv96+q/XoDfgbc7318P/DTTuY709+19eE2u4F9wEggCtgKTGw3zz3A497HtwAvOF13P2777cBvnK61D38GlwIzgIJOpi8EVuC58tdsYJ3TNffjtl8GLHO6zj7c/nRghvfxIGBPB3//fvv9O9Hksgj4i/fxX4DPOlBDf/voQtuq2gC0Xmi7rbY/l5eAKyQ0Lmrqy7aHNFV9j3NfwWsR8LR6rAUSRSS9f6rrWz5se0hT1VJV3eR9XA3sxHMN5rb89vt3ItBTVbUUPBsLDO1kvhgRyReRtSIS7KHf0YW22/9SP3GhbaD1QtvBzpdtB/ic9+vmSyKS1cH0UObrzyhUXSAiW0VkhYhMcrqYvuJtRp0OrGs3yW+/f58ucNFdIvImkNbBpH/rxmKyVfWoiIwE3haR7aq6zz8V9ju/XWg7CPmyXa8Bz6lqvYjcjeebyuV9XlngCNXfvS824Rmb5IyILAReAcY4XJPfichA4GXgX1X1dPvJHbylR7//Pgl0Vb2ys2kiUi4i6apa6v1aUdHJMo567/eLyDt4PtmCNdC7c6Ht4hC70HaX266qx9s8/QPw036oK5D48vcRktqGm6ouF5HfikiyqobMoF0iEoknzP9XVZd0MIvffv9ONLm0vaD0l4FX288gIoNFJNr7OBm4CNjRbxX6XzhfaLvLbW/XXngdnnbGcLIU+JK3t8NsoKq1WTLUiUha67EiEZmFJ5OOn/tdwcO7bX8Cdqrqf3cym99+/32yh96FR4AXReQO4DDea5GKSB5wt6p+DZgA/F5EWvD8gh9R1aANdA3jC237uO3fEpHrgCY82367YwX3ARF5Dk9vjmQRKQZ+BEQCqOrjeK7XuxAoAmqArzhTqf/5sO03At8QkSagFrglRHZkWl0E3AZsF5Et3td+AGSD/3//duq/McaECDtT1BhjQoQFujHGhAgLdGOMCREW6MYYEyIs0I0xJkRYoBvThogkisg9TtdhTE9YoBvzSYl4Rr40JuhYoBvzSY8Ao7xjc//c6WKM6Q47sciYNrwj4i1T1VyHSzGm22wP3RhjQoQFujHGhAgLdGM+qRrPpcKMCToW6Ma04R2bfY2IFNhBURNs7KCoMcaECNtDN8aYEGGBbowxIcIC3RhjQoQFujHGhAgLdGOMCREW6MYYEyIs0I0xJkT8P3wSrzo1XfjxAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Consider the probability distribution for\n", "# the predicted value of t at a certain x\n", "# position, e.g. one of our training points\n", "\n", "xvalue = x[5]\n", "\n", "# From above, we are assuming it is distributed\n", "# normally, with mean given as the value of our\n", "# polynomial, and for now let's pick 0.3 as a standard\n", "# deviation (cheating, because we know that is the\n", "# real value from the noise)\n", "\n", "tmean = polyval(w, xvalue)\n", "sigma = 0.3\n", "\n", "# We plot over the domain +-4 standard deviations,\n", "# using the conditional probability function defined earlier\n", "ts = arange(tmean-4*sigma, tmean+4*sigma, 0.01)\n", "pdf_t = cond_pdf_t(ts, xvalue, w, sigma)\n", "\n", "plot(ts, pdf_t)\n", "\n", "xlabel(\"t\")\n", "# Raw string: '\\s' is not a valid Python escape sequence\n", "title(r\"$p(t|x,\\mathbf{w},\\sigma)$\", fontsize=20)\n", "\n", "# As expected, it is just a normal distribution" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text([3.4242027], 0.8019421131641933, ' $p(t|x,$w$,\\\\sigma)$')" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot(x,t,'ro')\n", "plot(xs, polyval(w, xs), 'g')\n", "\n", "# Now, to illustrate, we plot this gaussian sideways\n", "# over our graph, showing how its width (standard deviation)\n", "# expresses our uncertainty in the value of t\n", "\n", "def plot_sideways(x, y, zeropos, *args, **kwargs):\n", "    \"\"\"Plot the curve (x, y) rotated by 90 degrees, standing at x = zeropos.\n", "\n", "    x (here the t values) goes on the vertical axis and y (the density)\n", "    is drawn horizontally, offset by zeropos; a vertical line at zeropos\n", "    spanning the range of x marks the curve's baseline. Extra args and\n", "    kwargs (e.g. a format string such as 'b') go to both plot and vlines.\n", "    \"\"\"\n", "    plot(y + zeropos, x, *args, **kwargs)\n", "    vlines(zeropos, np.min(x), np.max(x), *args, **kwargs)\n", "\n", "plot_sideways(ts, pdf_t, xvalue, 'b')\n", "\n", "ylabel(\"t\")\n", "xlabel(\"x\")\n", "# Use the scalar peak height as the label offset: builtin max over the\n", "# (n, 1) pdf_t column yields a 1-element row array, not a number\n", "text(xvalue + pdf_t.max(), polyval(w, xvalue), r\" $p(t|x,\\mathbf{w},\\sigma)$\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also show the probability density $p(t|x,\\mathbf{w},\\sigma)$ as color over the whole $(x, t)$ plane." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Evaluate the density on a regular (x, t) grid and show it as filled contours\n", "xgrid = arange(0, 7, 0.1)\n", "tgrid = arange(-1, 2.1, 0.1)\n", "X, Y = meshgrid(xgrid, tgrid)\n", "# cond_pdf_t returns one row per point (and one column for a scalar sigma),\n", "# so evaluate every grid point at once and fold back into the grid shape\n", "Z = cond_pdf_t(Y.ravel(), X.ravel(), w, sigma).reshape(X.shape)\n", "jet()\n", "contourf(X, Y, Z, levels=40)\n", "colorbar()\n", "xlim([0, 6.9])\n", "ylim([-1, 2])\n", "xlabel(\"x\")\n", "ylabel(\"t\")\n", "scatter(x, t, c='w')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[1.32553585]\n", " [0.11718557]\n", " [1.22835099]\n", " [0.61759085]\n", " [0.53579397]\n", " [1.23106725]\n", " [0.9629209 ]\n", " [0.05338001]\n", " [1.08959831]\n", " [0.0064887 ]\n", " [1.32973981]\n", " [1.32810898]\n", " [0.79024069]\n", " [1.23107637]\n", " [1.31759643]]\n" ] } ], "source": [ "# So, now the question is how do we learn the weights w\n", "# and the standard deviation sigma from our training data?\n", "\n", "# Well, we can see how probable each training point is given some\n", "# weights and standard deviation using our conditional probability\n", "# density function (now working on a vector of training data rather than a scalar)\n", "# (note these values are densities, not probabilities, hence may be >1)\n", "\n", "w_guess = [0.01, -0.02, -0.5, 1.7, -0.4]  # hand-picked quartic coefficients\n", "print(cond_pdf_t(t, x, w=w_guess, sigma=0.3))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/latex": [ "$\\displaystyle p(\\mathbf{t}|\\mathbf{x},\\mathbf{w},\\sigma) = \\prod_{n=1}^{N} Norm(t_n|y(x_n,\\mathbf{w}),\\sigma) = 0.000064$" ], "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# But we want to know how probable all the data is -- assuming the\n", "# points were drawn independently, the joint density is simply the\n", "# product of all the previous values\n", "\n", "# prod, not product: np.product was deprecated and removed in NumPy 2.0\n", "p = prod(cond_pdf_t(t, x, w=w_guess, sigma=0.3), axis=0)\n", "\n", "\n", 
"# p is a 1-element array (one column per sigma); .item() extracts the float\n", "Math(r\"$p(\\mathbf{t}|\\mathbf{x},\\mathbf{w},\\sigma) = \\prod_{n=1}^{N} Norm(t_n|y(x_n,\\mathbf{w}),\\sigma) = %.4g$\" % p.item())\n", "\n", "# This is the value we want to maximise - try tweaking the parameters\n", "# to get a high value; it probably isn't easy!" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'Likelihood')" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# When considered as a function of w and sigma, this is our\n", "# likelihood function, and is *not* a probability distribution\n", "# over these variables. Note that syntactically it is the same\n", "# expression as a conditional probability dist on t, it is just\n", "# a matter of which variables we \"fix\" and which we \"slide around\".\n", "\n", "def likelihood_w_sigma(w, sigma):\n", "    \"\"\"Likelihood p(t | x, w, sigma) of the training data (x, t).\n", "\n", "    Multiplies the per-point densities over the training points (axis 0),\n", "    giving one likelihood value per entry of sigma.\n", "    \"\"\"\n", "    # prod, not product: np.product was deprecated and removed in NumPy 2.0\n", "    return prod(cond_pdf_t(t, x, w, sigma), axis=0)\n", "\n", "# We can plot likelihood against different standard deviation\n", "# to see where the maximum is - how does this compare to the sd\n", "# we actually set for the noise?\n", "\n", "sigs = arange(0.02,1,0.01)\n", "plot(sigs, likelihood_w_sigma(w, sigs))\n", "xlabel(r\"$\\sigma$\", fontsize=22)\n", "ylabel(\"Likelihood\")" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'NLL')" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# In practice we consider the log of the likelihood, which lets us\n", "# turn the product into a sum and avoids numerical underflow\n", "# from multiplying small values. We want to maximise the log likelihood,\n", "# or equivalently, minimise the negative log likelihood\n", "\n", "def nlog_likelihood_w_sigma(w, sigma):\n", "    \"\"\"Negative log likelihood -log p(t | x, w, sigma) of the training data.\n", "\n", "    Uses the Gaussian log density in closed form,\n", "        N/2 * log(2*pi*sigma^2) + sum_n (t_n - y(x_n, w))^2 / (2*sigma^2),\n", "    instead of log(pdf), so a point far from the curve cannot underflow\n", "    to log(0) = -inf. Returns one value per entry of sigma.\n", "    \"\"\"\n", "    sigma = np.atleast_1d(np.asarray(sigma, dtype=float))\n", "    residuals = np.ravel(t - polyval(w, x))\n", "    sq_err = np.sum(residuals**2)\n", "    return 0.5 * residuals.size * np.log(2 * np.pi * sigma**2) + sq_err / (2 * sigma**2)\n", "\n", "plot(sigs, nlog_likelihood_w_sigma(w, sigs))\n", "xlabel(r'$\\sigma$')\n", "ylabel('NLL')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Maximum Likelihood Solution\n", "===========================\n", "\n", "The Maximum Likelihood solution can be found in closed form.\n", "\n", "First we can use the properties of the (natural) log to simplify the formula:\n", "\n", "$\\log(p(\\mathbf{t}|\\mathbf{x}, \\mathbf{w},\\sigma)) = \\log(\\prod_{n=1}^{N} Norm(t_n|y(x_n,\\mathbf{w}),\\sigma))$\n", "\n", "$ = \\sum_{n=1}^{N} \\log[(2\\pi\\sigma^2)^{-1/2} \\exp(-\\frac{(t_n-y(x_n,\\mathbf{w}))^2}{2\\sigma^2})]$\n", " \n", "$ = N \\log((2\\pi\\sigma^2)^{-1/2}) -\\frac{1}{2\\sigma^2} \\sum_{n=1}^{N} (t_n-y(x_n,\\mathbf{w}))^2 $\n", " \n", "$ = -\\frac{1}{2\\sigma^2} \\sum_{n=1}^{N} (t_n-y(x_n,\\mathbf{w}))^2 - \\frac{N}{2}\\log(\\sigma^2) - \\frac{N}{2}\\log(2\\pi) $\n", "\n", "Now we can find the gradient with respect to $\\mathbf{w}$ and $\\sigma^2$.\n", "\n", "The gradient with respect to $\\mathbf{w}$ is derived below; the gradient with respect to $\\sigma^2$ is still **TODO**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In polynomial regression, the hypothesis, $y$, is of the form:\n", "\n", "$y(x_n, \\mathbf{w}) = \\sum_{i=0}^{M} w_i x_n^{i}$, where $M$ is the degree of the polynomial (the sum starts at $i=0$, so there are $M+1$ weights).\n", "\n", "Using matrix notation, let $\\mathbf{w} = (w_0, \\dots, w_M)^T$ be the vector of weights. 
Also, let $\\mathbf{x_n} = (x_n^{0}, \\dots, x_n^{M})^T$ be the vector of the first $M+1$ powers of $x_n$ (starting from $x_n^{0} = 1$).\n", "\n", "Then, $y(x_n, \\mathbf{w}) = \\mathbf{w}^{T}\\mathbf{x_n}$\n", "\n", "The derivative of $y(x_n, \\mathbf{w})$ with respect to $\\mathbf{w}$ will come in handy soon, and we can see that $\\frac{d y}{d \\mathbf{w}} = \\mathbf{x_n}$, since $\\mathbf{w}^{T}\\mathbf{x_n}$ is linear in $\\mathbf{w}$.\n", "\n", "One more piece of notation: the design matrix $X$ has one row per training example, $\\mathbf{x_n}^{T}$, and one column per power of $x_n$, so it is $N \\times (M+1)$.\n", "\n", "Taking the derivative of $-\\frac{1}{2\\sigma^2} \\sum_{n=1}^{N} (t_n-y(x_n,\\mathbf{w}))^2 - \\frac{N}{2}\\log(\\sigma^2) - \\frac{N}{2}\\log(2\\pi) $ with respect to $\\mathbf{w}$ gives $\\frac{1}{\\sigma^2}\\sum_{n=1}^{N}(t_n - y(x_n, \\mathbf{w})) \\frac{d y}{d \\mathbf{w}}$ (the last two terms do not depend on $\\mathbf{w}$). Maximising the log likelihood is the same as minimising its negative. The positive factor $\\frac{1}{\\sigma^2}$ does not change where the gradient is zero, so we drop it. We then work with the gradient of the sum-of-squares error $E(\\mathbf{w}) = \\frac{1}{2}\\sum_{n=1}^{N}(t_n-y(x_n,\\mathbf{w}))^2$:\n", "\n", "$ -\\sum_{n=1}^{N}(t_n - y(x_n, \\mathbf{w})) \\frac{d y}{d \\mathbf{w}}$\n", "\n", "Replacing $\\frac{d y}{d \\mathbf{w}}$ with what we computed earlier,\n", "\n", "$ = -\\sum_{n=1}^{N}(t_n - y(x_n, \\mathbf{w}))\\mathbf{x_n}$\n", "\n", "$ = -\\sum_{n=1}^{N}(t_n - \\mathbf{w}^{T}\\mathbf{x_n})\\mathbf{x_n}$\n", "\n", "Now we use the design matrix to simplify things, and let $\\mathbf{t} = (t_1, \\dots, t_N)^T$ be the vector of targets. We can then write this $(M+1)$-dimensional gradient of the error with respect to the weights as:\n", "\n", "$ = -\\sum_{n=1}^{N}(t_n \\mathbf{x_n}) + \\sum_{n=1}^{N}(\\mathbf{w}^{T}\\mathbf{x_n})\\mathbf{x_n}$\n", "\n", "$ = -X^{T}\\mathbf{t} + X^{T}(X\\mathbf{w}) $\n", "\n", "And since we are at an optimum when this gradient is zero,\n", "\n", "$ 0 = -X^{T}\\mathbf{t} + X^{T}(X\\mathbf{w}) $\n", "\n", "$ X^{T}\\mathbf{t} = (X^{T}X)\\mathbf{w} $\n", "\n", "Almost there! Assume $X^{T}X$ is invertible. This holds when $X$ has full column rank, e.g. when there are at least $M+1$ distinct $x_n$; otherwise use the pseudoinverse. Then\n", "\n", "$ \\mathbf{w} = (X^{T}X)^{-1}X^{T}\\mathbf{t} $\n", "\n", "And we have a closed form solution for the weights. This solution is referred to as the \"normal equation\". 
Note that for big-data problems this equation scales poorly. Forming $X^{T}X$ costs $O(NM^2)$ and inverting it costs $O(M^3)$, and the system can be numerically ill-conditioned. For such problems, a much faster alternative is an iterative method such as gradient descent." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.2" } }, "nbformat": 4, "nbformat_minor": 1 }