{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Stochastic Gradient Descent" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Run once per session\n", "!pip install fastai -q --upgrade" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import the library we need (the star import also provides `torch`, `tensor`, `nn`, and `plt`, which we use below)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.basics import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Stochastic Gradient Descent (SGD):\n", "\n", "* An optimization technique (an **optimizer**)\n", "* Commonly used to train neural networks\n", "* We'll illustrate it with linear regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Linear Regression\n", "\n", "* Fit a line to 100 points" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n = 100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generate our data, starting from an `n` x 2 tensor of ones" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.ones(n,2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(100, tensor([[1., 1.],\n", " [1., 1.],\n", " [1., 1.],\n", " [1., 1.],\n", " [1., 1.]]))" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "len(x), x[:5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Fill the first column with random values drawn uniformly from -1 to 1; the second column stays at 1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[-0.0318, 1.0000],\n", " [ 0.2159, 1.0000],\n", " [-0.9402, 1.0000],\n", " [ 0.1420, 1.0000],\n", " [ 0.5565, 1.0000]]), torch.Size([100, 2]))" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "x[:,0].uniform_(-1., 1)\n", "x[:5], x.shape" ] }, { 
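"cell_type": "markdown", "metadata": {}, "source": [ "Why keep a column of ones in `x`? It lets a single matrix product also add an intercept. With weights `[3., 2.]`, multiplying `x` by the weights gives `3*x[:,0] + 2`, because column 1 is always 1. Below is a minimal sketch that checks this, assuming only plain PyTorch. The seed and the `_demo` names are illustrative choices, not part of the lesson.

```python
import torch

torch.manual_seed(0)  # illustrative seed so the check is reproducible

n_demo = 100
x_demo = torch.ones(n_demo, 2)
x_demo[:, 0].uniform_(-1., 1)  # column 0: random inputs; column 1 stays 1

m_demo = torch.tensor([3., 2.])  # slope 3, intercept 2

# Column 0 is multiplied by the slope, the all-ones column by the intercept
via_matmul = x_demo @ m_demo
by_hand = 3 * x_demo[:, 0] + 2

print(torch.allclose(via_matmul, by_hand))  # True
```
" ] }, { 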
"cell_type": "markdown", "metadata": {}, "source": [ "* A linear model has the form `y = mx + b`\n", "* Here `x` is a matrix, while `m` and `b` are vectors\n", "* We already have `x`; next we pick `m`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([3., 2.]), torch.Size([2]))" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "m = tensor(3.,2); m, m.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* `b` adds random noise: one value per point, drawn uniformly between 0 and 1, so the points don't lie exactly on a line" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([0.7173, 0.7303, 0.1615, 0.2098, 0.8227]), torch.Size([100]))" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "b = torch.rand(n); b[:5], b.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can make our `y`\n", "* Matrix multiplication is written with `@`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y = x@m + b" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we get a size wrong, PyTorch raises an error:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "ename": "RuntimeError", "evalue": "ignored", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mm\u001b[0m\u001b[0;34m@\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m: size mismatch, m1: [1 x 2], m2: [100 x 2] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:197" ] } ], "source": [ "m@x + b" ] }, { "cell_type": "markdown", "metadata": {}, 
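"source": [ "Why did `x@m` work while `m@x` failed? Matrix multiplication needs the inner dimensions to match. `x@m` is `(100, 2) @ (2,)`: it contracts the shared size-2 dimension and gives 100 predictions. In `m@x`, PyTorch treats the 1-D `m` as a `(1, 2)` row, and `(1, 2) @ (100, 2)` has mismatched inner sizes (2 vs 100). Below is a minimal sketch assuming only plain PyTorch. The `_demo` names are hypothetical stand-ins for the notebook's `x`, `m`, and `b`.

```python
import torch

x_demo = torch.ones(100, 2)
x_demo[:, 0].uniform_(-1., 1)
m_demo = torch.tensor([3., 2.])
b_demo = torch.rand(100)

# (100, 2) @ (2,) -> (100,): one prediction per row of x
y_demo = x_demo @ m_demo + b_demo
print(y_demo.shape)

# (2,) @ (100, 2) fails: m acts as a (1, 2) row and 2 != 100
try:
    m_demo @ x_demo
    failed = False
except RuntimeError as err:
    failed = True
    print('RuntimeError:', err)
print(failed)

# Transposing x to (2, 100) makes the inner dimensions agree again
y_t = m_demo @ x_demo.T + b_demo
print(torch.allclose(y_demo, y_t))
```
" ] }, { "cell_type": "markdown", "metadata": {}, 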
"source": [ "Plot our results" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAcH0lEQVR4nO3df4wcZ33H8c/Xl0tyhpJzGosmFzs2\nKnUKpLXhFGhdATElDr8SEyhJJFqgVO4PWrVpSWsEUhBqFbdRC62KSi1+F2QCCXFdAhioQxERMZxl\nQ0KCwYTS5Egb0+RQIca52N/+sbPOeG9m55ndZ2dndt8vyfLd7OwzT+Yu3338ne/zPObuAgA017Jh\ndwAA0B8COQA0HIEcABqOQA4ADUcgB4CGO20YFz3nnHN8zZo1w7g0ADTW/v37f+juKzuPDyWQr1mz\nRnNzc8O4NAA0lpl9P+t4lNSKmU2b2c1m9i0zu9fMfiVGuwCAYrFG5H8v6bPu/mozO13S8kjtAgAK\n9B3IzewsSc+X9HpJcvfHJD3Wb7sAgDAxUitrJR2R9AEzO2Bm7zWzJ3WeZGZbzWzOzOaOHDkS4bIA\nAClOID9N0rMl/ZO7b5D0E0nbOk9y9x3uPuvusytXLnnoCgDoUYxA/oCkB9x9X/L9zWoFdgBABfrO\nkbv7f5vZ/Wa2zt0PSXqRpHv67xoANN+uA/O6cc8h/WDhqM6bntJ1m9dpy4aZqNeIVbXyR5I+mlSs\n3CfpDZHaBYDG2nVgXm/55F06unhckjS/cFRv+eRdkhQ1mEepI3f3g0n++5fcfYu7PxKjXQBoshv3\nHDoZxNuOLh7XjXsORb0Oa60AwID8YOFoqeO9IpADwICcNz1V6nivCOQAMCDXbV6nqcmJU45NTU7o\nus3rol5nKItmAcA4aD/QbErVCgAgw5YNM9EDdydSKwDQcARyAGg4AjkANByBHAAajkAOAA1HIAeA\nhiOQA0DDEcgBoOEI5ADQcARyAGg4AjkANBxrrQBorCq2Uavz9dsI5AAaqapt1Op6/TRSKwAaqapt\n1Op6/TRG5AAaI53K8JxzYm+jlqeqbdxCEMgBNEJnKiNPrG3UivLf501PaT4jaMfexi1ElNSKmf2n\nmd1lZgfNbC5GmwCQlpXK6BRrG7X2h8Z8MvJv5793HZg/eU5V27iFiDkiv8TdfxixPQCQ1AqsWaPf\nNlNrJHzJhSt1455Duvamg8FVJFkj727573Z7VW3jFoLUCoBaa4+O88xMT+mObZt6qiLJe0/eyL8z\n/13FNm4hYlWtuKTPmdl+M9uadYKZbTWzOTObO3LkSKTLAhh13VIq6VRGL1Ukee+ZMMs8fxj57xCx\nAvmvufuzJb1E0pvM7PmdJ7j7DnefdffZlStXRrosgFHXrQrkhisvOjki7qWKJO+14+61yX+HiBLI\n3X0++fshSbdKujhGuwCQNwqemZ7Slg0z2nVgXhu3780tR+w2ip5ePpl5fMXySd1w5UWamZ6SJde6\n4cqLJEkbt+/V2m23aeP2vac8/BymvnPkZvYkScvc/f+Sry+V9I6+ewYAalWHdOat26PjopLEolG0\n50R/96X57zrN5OwUY0T+VElfNrOvS/qqpNvc/bMR2gUAbdkwkzk63rJhpmv+PH1enh8dXQw+XqeZ\nnJ36HpG7+32SfjlCXwAgU
151SF6O2yTdsW1TYbtlJvXUaSZnJ9ZaAdBYeTnuvOOdykzqycu116GS\nhUAOoLG65bhDdEvbdKrTTM5OTAgC0Fhlctx5Qif11GkmZycCOYDoqtpwoeqFq+oyk7MTgRxAVKFl\nennBvsyHQLfSxHFCIAcQVciCU3nBfu77D+uW/fPBtdp1TndUiUAOIKqQMr28YL9z3/063vGksvND\noFNd0x1VomoFQFQhZXrd1jjJUoda7TojkAOI6pILsxfFSx/PC/ZNW3WwLgjkAEppL1KVt3DU7d/K\nXqY6fTyvJvua566qba12nZEjBxAspCIlJEfe7SHl7AVnj/3Dy7II5ACChVSkhNZ25z2k5OFleaRW\nAATptm9merRd56nso4oROYBCRftmpkfbdajtrmpmaV0QyAEUCt03s61seqQo8JYJzHXeAGJQSK0A\nKBS6b2Yv2oF3fuGoXE8E3nY1TNHrneq8AcSgEMgBFCraN7MfRYG3bGCu8wYQg0IgB1Co2wPMorry\nInkPUNvHywbmOm8AMSgEcgCFwThvAwZJpdIeWfJmc7aPlw3M41g1w8NOYMyFPhzMeoC5cfvewrry\nInnrq7SPl12qtg5VM1UjkANjLmSST55e8tHpCpTp5ZMySVmhfCYZcfcSmMdtUhGBHBhhIWV7/Twc\nLLtDT+fo/5FHs7dk6xxxj1tgLitaIDezCUlzkubd/eWx2gXQm9CUSZlg3Dma/mlGbXm3tEe3evS2\nCbO+SxrHTcwR+R9LulfSUyK2CaBHoSmTvBz0JReu1Mbte0+O5i+5cOUpu/dkjaZXLJ/U9a94Zm4Q\nDhnln3AniJcUpWrFzM6X9DJJ743RHoD+haZMsipSXvWcGd2yf/6UapSP3vlfhaPp5aef1jUIh5QA\njnKZ4KDEGpG/S9KfS/qZvBPMbKukrZK0evXqSJcFkKdMyqQzB51VjZJdW3KqohF31ug/bdTLBAel\n7xG5mb1c0kPuvr/bee6+w91n3X125crsHUQAxNNPPXWvsyCLRtOdo/8Vyyc1PTV5Sm06aZXyYozI\nN0q63MxeKulMSU8xs4+4+2sjtA2gR/3UU+eN5vNKBaXwDwkqUOIzzynG76kxsxdKenNR1crs7KzP\nzc1Fuy6AuDorXqRWoH7Vc2Z0+7eOnKxacZd+dHRxLCbd1IGZ7Xf32c7j1JEDWGIcZ0c2WdQReShG\n5ABQXt6InEWzAKDhSK0AY2pQ26GN2zZrdUAgB8bQoLZDG8dt1uqA1Aowhga1Hdo4brNWBwRyYAwN\naju0cdxmrQ5IrQBjqNcVD4ty3mWXtUUcjMiBMRQ6fb/sDvbjuM1aHRDIgTGUtwdn50i7bM47tF3E\nRWoFGFMha570kvNmLZXqEciBGqhr7TU572YgtQIMWdk8dJXIeTcDgRwYsjrXXpPzbgZSK8CQ1b32\nmpx3/TEiB4YsL99MHhqhCOTAkGXloU2tXPnG7XsHlivfdWBeG7fv1dpttw30Ohg8UivAkKU3cZhf\nOHrKdmqDWnSKxa1GCyNyYEjSI+Ib9xzSdZvXaWZ6asmemIN48FnnB6wojxE5MAR5I+LO4NoW+uAz\ntB697g9YUQ6BHKhIOsguM9Pxjm0Wjy4e10TGcUmaXj4Z1H5ouoSJPqOF1ApQgc5JP1nBWsnxyQlb\ncvzHP3288GFkmXQJE31GC4EcqEBWkM0yMz2lJ52+9B/Kiye8MH9dJl3CRJ/R0ndqxczOlPQlSWck\n7d3s7tf32y4wbDHXPwnJPbdHxNfedLCnNs6amtTC0cXM41mY6DM6YuTIj0na5O4/NrNJSV82s8+4\n+50R2gaGInZ5Xl5OesJMJ9xP+aBolyFmtZHVz/aHjS3NyEhS7nGMjr5TK97y4+TbyeRPdgIQaIjY\n5Xl5Oem/fc0v63vbX6Y7tm06+QHR66YPJ3L+r1t4dOkoHaMlStWKmU1I2i/p5yW92933ZZy
zVdJW\nSVq9enWMywIDE7s8Lz3ppyhVE3puaN6dSpTRFyWQu/txSevNbFrSrWb2LHe/u+OcHZJ2SNLs7Cwj\ndtTaIMrzyuSk+9n0IY1KlPEQtWrF3Rck3S7pspjtAlVrQnle3ofKhBmVKGMmRtXKSkmL7r5gZlOS\nXizpr/vuGTBEZVIhw3Ld5nVLZoNOTU4QvMdQjNTKuZI+lOTJl0n6uLt/KkK7wFDVvTyvCR82qEbf\ngdzdvyFpQ4S+ACip7h82qAYzOwGg4QjkANBwrH6IsfO2XXdp5777ddxdE2a65rmr9JdbLop+nZhT\n/IFuCOQYK2/bdZc+cud/nfz+uPvJ72MGc3bgQZVIrWCs7Nx3f6njvWIHHlSJQI6x0m0d8JjYgQdV\nIpBjrEzkLAWYd7xXebMuWfcEg0Agx0hJb2i8cfveJbvqXPPcVZnvyzveqyZM8cfo4GEnRkbIA8b2\nA81BV60w6xJVMo+cGwwxOzvrc3NzlV8Xo23j9r2ZKxbOTE/pjm2bhtAjIC4z2+/us53HSa1gZPCA\nEeOK1ApGRtk1xJmwg1HBiBwjo8wDxs5t0tr59M6Ho0ATMCJH7fQ6Ui7zgLHbhB1G5WgaAjlqpd+p\n7aHLupbNp5OGQZ2RWkGtVDW1vcyEHdIwqDsCOWolRuVJ0aQgqVw+nXVTUHekVlAr/e5eH5qaKZNP\np6wRdUcgR63kbSjcbWp7On+9zGzJAlh5DzFD8+n9frgAg0ZqBbWyZcOMbrjyIs1MT8nUmpXZbVf4\nzvx13iqG/YyeWTcFdceIHLVTZkPhrPx1ln5Gz6ybgrrrO5Cb2SpJH5b0VEkuaYe7/32/7QIhQkba\nMUbP7FaPOouRWnlc0p+5+zMkPU/Sm8zsGRHaBQrljbQnzIJSM8Ao6HtE7u4PSnow+fr/zOxeSTOS\n7um3baBI3sNRgjfGSdQcuZmtkbRB0r6M17ZK2ipJq1evjnlZjDHy10DE9cjN7MmS/kPSX7n7J7ud\ny3rkkJj2DpSVtx55lBG5mU1KukXSR4uCOEZDXhAODc79rqkC4AkxqlZM0vsk3evuf9d/l1B3eUF4\n7vsP65b980HBmdUHgXhiVK1slPSbkjaZ2cHkz0sjtIuaygvCO/fdH7wmyaCmvYesswKMmhhVK1+W\nZBH6gobIC7ZlZlUOYto76RqMK6boo7Rutduh5w9i2jurFGJcEchRWl4Qvua5q4KDc9k1VUKwSiHG\nFWutoLSi2u2d++7XcXdNmOlVz8mf2h572jurFGJcEcjRk6wgvOvAvG7ZP38yV37cXbfsn9fsBWdX\nkqPuZQlcYBSQWkE0w85RDyJdAzQBI3JEE5qjHuSMTlYpxDgikI+x2AE1JEddtkSQafxAMVIrI6po\nYkwvO8MXtRlSUlgm/cLu9UAYAvkICgmAZfPZIW2G5KjLlAgOO+cONAWplREUso5J2Zrr0LVRinLU\nZUoEqQsHwjAiH0EhATCvtvqsqcme2wxRZkZnXh+pCwdORSAfQSEB8LrN6zS5bOmU+p889nhmDjqv\nzWVmpRaoKlMiyO71QJhoG0uUwcYSg9VZGSJlb3+24R2f0yOPLi55/8z0lO7YtqmwzU6D2GKNqhXg\nCQPdWAL1Err92UJGEJey0yWdbS4zW7La4SDWE6cuHChGIB9RIQGw7Nok6TbXbrst8xweRALVI0c+\npnYdmNdPjj2+5HhoDpoHkUB9EMjHUDvfvXD01NTKiuWTwTluHkQC9UFqZQxl1YS3heajQ/PwAAaP\nQD6G8vLYjzy6qF0H5ksFcwI3MHwE8gbqtyQv7yGnpNJVJ5QHAsNHjrxhstY8ufamg1pTYlJOtzx2\nmaoTFrUC6iFKIDez95vZQ2Z2d4z2kC8rv92u5g4NpFs2zGg6Zyp+maoTFrUC6iHWiPyDki6L1Ba6\nKBoxhwbSt1/+zL6rTljUCqiHKIHc3b8k6eEYbaG7kBF
zSCCNsS0ateRAPVT2sNPMtkraKkmrV6+u\n6rIjJ2uD4U6hgbTfqhM2OwbqobJA7u47JO2QWotmVXXduum3yiNdvz2/cFSmJ3LkUrWBlFpyoB6i\nrX5oZmskfcrdn1V07riufhi6KmHZNgmkwHjIW/2QQF6hjdv35tZvzxCEARQY6DK2ZrZT0gslnWNm\nD0i63t3fF6PtUdLtIWTRbvIxMYoHRkusqpVr3P1cd5909/MJ4tmKHkJWUYPNJB5g9DBFf0CyRr0h\nFSf91GCHjLRDN1EG0BxM0R+AvFGvpJO123l6rcEOHWkziQcYPQTyPuw6MK+N2/cu2Xy4aNR7x7ZN\netdV66Ou5x06XZ5JPMDoIZD3qNsIOGTUG2NmZV7b3Y6zIQQwesiR96jbCDh0L8yY63mXuabEJB5g\nlBDIe9RtBPzOq9ZXPnW9zHR5NoQARguplR51yzXHTpuk5eXlO685PTWpMyeX6dqbDgavUw6gmaLN\n7CxjFGZ2DmK6faxrDqNvAAYvb2YnI/IeDXLUnSe0MoUNH4DxQo68D1XnmkMrU6gVB8YLI/IGCa0B\np1YcGC8E8gYJrQGnVhwYL6RWIqhqNcHQGnBqxYHxQtVKnwZZIcJyswDSBroe+TgrWlel12Dc+QFR\n5XrlAJqFHHmfulWIZK3H8ic3HdSGd3yucIIOJYQAQhHI+9StQiQrGEvSI48uFm7mQAkhgFAE8j51\nqxDpFnSLRteUEAIINZaBPG+9kl50m+FZFHS7BXpKCAGEGpuHnbsOzOvtu7+phaOLpxzvfIgYs1Kk\naGu3boGeEkIAocYikO86MK/rPvF1LZ7ILrVMpznKVoqEVJdkfYCEjK5ZbhZAiCipFTO7zMwOmdlh\nM9sWo80iZdIjN+45lBvE236wcLSnSpGi92zZMKOD11+qd121vtIFtgCMj75H5GY2Iendkl4s6QFJ\nXzOz3e5+T79t5ylbYx1S6XHe9FRPlSJ5r3Xu1sPoGsCgxBiRXyzpsLvf5+6PSfqYpCsitJur7Mi5\n6KFjO83RS6VI3msmsZkDgErECOQzku5Pff9AcmxgikbOnWmXSy5cqclllvme6anJk2mOXipFrtu8\nTlktu8TkHQCVqKz80My2mtmcmc0dOXKkr7a6jZyzZlPesn9eV128StNTkyfPXbF8Uu+6ar0OXn/p\nyZRHL5tFbNkwo7zsO5N3AFQhRtXKvKRVqe/PT46dwt13SNohtRbN6ueC3TYazku73P6tIzp4/aWn\nHG+P3DvL+8rmsmcCd7AHgEGIMSL/mqSnm9laMztd0tWSdkdoN1e3kXPoA8uskXvRtPk8TN4BMEx9\nj8jd/XEz+0NJeyRNSHq/u3+z754VyBs5nxc4Oi5atbBsX9ptMnkHQNWiTAhy909L+nSMtkLlzcDs\nlnZJi70oFeWFAIalkWutdEuLdKZdpqcmdebkMl1708FTJg6xKBWAUdHIQB4ym/KObZv0zqvW69jj\nJ/TIo4tLAj55bQCjopGBPDQtUhTwz5x84j8/XU8OAE3SyEWzQh9odps+35lHP/b4icxz2TcTQN01\nckQekhbZdWBeyyx7NueEWdAU/5gligAwKI0M5EUzMNsB+LgvnXc0NTmReVwqn5oBgDpoZGpF6l7u\nl7dX5oSZbrjyIt2451BfqRmm3gOok0aOyLvZdWA+M0hL0gn3UotjUaIIoAlGKpC3Uyp52gE4dHEs\nShQBNEHjUivdqkjyUirS0gAcMhOTqfcAmqBRgbxoZ6Buuetea8SZeg+g7hoVyIuqSJaZZVakTOSU\nIQLAKGhUjrxogk9eWeFxd+q/AYysRgXys1I7/KSZlJsbb6P+G8CoalQgz8uQhG43RP03gFHUqEC+\n8OhiX++n/hvAKGpUIM8LxCuWTy6p9+5E/TeAUdWoQJ43Qef6VzxzyQSf1z5vdeGEHwAYBY0qPyya\noEOgBjCOGhXIJSb
oAECnRqVWAABLEcgBoOH6CuRm9htm9k0zO2Fms7E6BQAI1++I/G5JV0r6UoS+\nAAB60NfDTne/V5KMRakAYGgqy5Gb2VYzmzOzuSNHjlR1WQAYeYUjcjP7gqSfy3jpre7+r6EXcvcd\nknYkbR4xs+8H9/IJ50j6YQ/vGzT6VQ79Kq+ufaNf5fTbrwuyDhYGcnf/9T4umtfmyl7eZ2Zz7l67\nh6r0qxz6VV5d+0a/yhlUvyg/BICG67f88JVm9oCkX5F0m5ntidMtAECofqtWbpV0a6S+hNhR4bXK\noF/l0K/y6to3+lXOQPplnrM9GgCgGciRA0DDEcgBoOFqF8hD128xs8vM7JCZHTazbanja81sX3L8\nJjM7PVK/zjazz5vZd5K/V2Scc4mZHUz9+amZbUle+6CZfS/12vqq+pWcdzx17d2p48O8X+vN7CvJ\nz/sbZnZV6rWo9yvv9yX1+hnJf//h5H6sSb32luT4ITPb3E8/eujXn5rZPcn9+XczuyD1WubPtKJ+\nvT6ZD9K+/u+kXntd8nP/jpm9ruJ+vTPVp2+b2ULqtUHer/eb2UNmdnfO62Zm/5D0+xtm9uzUa/3f\nL3ev1R9JvyhpnaQvSprNOWdC0nclPU3S6ZK+LukZyWsfl3R18vV7JP1+pH79jaRtydfbJP11wfln\nS3pY0vLk+w9KevUA7ldQvyT9OOf40O6XpF+Q9PTk6/MkPShpOvb96vb7kjrnDyS9J/n6akk3JV8/\nIzn/DElrk3YmKuzXJanfod9v96vbz7Sifr1e0j9mvPdsSfclf69Ivl5RVb86zv8jSe8f9P1K2n6+\npGdLujvn9ZdK+owkk/Q8Sfti3q/ajcjd/V53P1Rw2sWSDrv7fe7+mKSPSbrCzEzSJkk3J+d9SNKW\nSF27ImkvtN1XS/qMuz8a6fp5yvbrpGHfL3f/trt/J/n6B5IektTTZLECmb8vXfp7s6QXJffnCkkf\nc/dj7v49SYeT9irpl7vfnvodulPS+ZGu3Ve/utgs6fPu/rC7PyLp85IuG1K/rpG0M9K1u3L3L6k1\ncMtzhaQPe8udkqbN7FxFul+1C+SBZiTdn/r+geTYz0pacPfHO47H8FR3fzD5+r8lPbXg/Ku19Jfo\nr5J/Vr3TzM6ouF9nWmutmzvb6R7V6H6Z2cVqjbK+mzoc637l/b5knpPcjx+pdX9C3jvIfqW9Ua1R\nXVvWz7TKfr0q+fncbGarSr53kP1SkoJaK2lv6vCg7leIvL5HuV9D2erNIq3fElu3fqW/cXc3s9y6\nzeST9iJJ6QlSb1EroJ2uVi3pX0h6R4X9usDd583saZL2mtldagWrnkW+X/8i6XXufiI53PP9GkVm\n9lpJs5JekDq85Gfq7t/NbiG6f5O0092PmdnvqvWvmU0VXTvE1ZJudvfjqWPDvF8DNZRA7v2v3zIv\naVXq+/OTY/+r1j9ZTktGVe3jfffLzP7HzM519weTwPNQl6ZeI+lWd19Mtd0enR4zsw9IenOV/XL3\n+eTv+8zsi5I2SLpFQ75fZvYUSbep9SF+Z6rtnu9Xhrzfl6xzHjCz0ySdpdbvU8h7B9kvmdmvq/Xh\n+AJ3P9Y+nvMzjRGYCvvl7v+b+va9aj0Tab/3hR3v/WKEPgX1K+VqSW9KHxjg/QqR1/co96upqZWv\nSXq6tSouTlfrh7bbW08PblcrPy1Jr5MUa4S/O2kvpN0lubkkmLXz0lvU2pSjkn6Z2Yp2asLMzpG0\nUdI9w75fyc/uVrVyhzd3vBbzfmX+vnTp76sl7U3uz25JV1urqmWtpKdL+moffSnVLzPbIOmfJV3u\n7g+ljmf+TCvs17mpby+XdG/y9R5Jlyb9WyHpUp36L9OB9ivp24VqPTj8SurYIO9XiN2SfiupXnme\npB8lg5U492tQT3F7/SPplWrliY5J+h9Je5Lj50n6dOq8l0r6tlqfqG9NHX+aWv+jH
Zb0CUlnROrX\nz0r6d0nfkfQFSWcnx2clvTd13hq1PmWXdbx/r6S71ApIH5H05Kr6JelXk2t/Pfn7jXW4X5JeK2lR\n0sHUn/WDuF9Zvy9qpWouT74+M/nvP5zcj6el3vvW5H2HJL0k8u97Ub++kPx/0L4/u4t+phX16wZJ\n30yuf7ukC1Pv/e3kPh6W9IYq+5V8/3ZJ2zveN+j7tVOtqqtFteLXGyX9nqTfS143Se9O+n2XUhV5\nMe4XU/QBoOGamloBACQI5ADQcARyAGg4AjkANByBHAAajkAOAA1HIAeAhvt/eS9BTnw6q+QAAAAA\nSUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "plt.scatter(x[:,0], y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our **weights** from last lesson should minimize the distance between the points and our line.\n", "\n", "* **mean squared error** (MSE): take the difference between each prediction (`y_hat`) and `y`, square it, then average" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def mse(y_hat, y): return ((y_hat-y)**2).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When we fit our model, we are trying to recover `m`.\n", "\n", "For example, start from a guess of `a = (0.5, 0.75)`:\n", "\n", "* Make a prediction\n", "* Calculate the error" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "a = tensor(.5, .75)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make a prediction" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y_pred = x@a" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Calculate the error" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(4.8796)" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "mse(y_pred, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What does that **mean**? 
Let's plot it" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO3df5BdZZ3n8fe3O53QwZl0kKwrTSKx\nloVVYYh0iWMslTgD6qJEdBFq3RV/VMbRcVfdYYw/SihrZsGhShlrnJ1NKaKrg7D8aKLoZNTguFIm\n2tlEIkIEcTRpdYhCsgME0ul89497bvfp0+fXvfe5P869n1dVKt3nnvuch9PN9z75nu/zPObuiIhI\ndQ11uwMiItIaBXIRkYpTIBcRqTgFchGRilMgFxGpuCXduOjJJ5/sp512WjcuLSJSWbt27fqNu69K\nHu9KID/ttNOYmprqxqVFRCrLzH6edjxIasXMxszsVjN7wMzuN7PfD9GuiIgUCzUi/yvg7939jWa2\nFFgeqF0RESnQciA3sxXAy4ArANz9KHC01XZFRKScEKmVtcBB4HNmttvMPmNmJyZPMrNNZjZlZlMH\nDx4McFkREYEwgXwJ8ELgf7j7OuAJYHPyJHff4u4T7j6xatWih64iItKkEIH8AHDA3XdG399KLbCL\niEgHtJwjd/dfm9l+MzvD3fcBrwR+3HrXRESqbXL3NNdt28cvDx3hlLFRrrzwDDauGw9+nVBVK+8B\nvhRVrDwMvDVQuyIilTS5e5oP3r6XIzOzAEwfOsIHb98LEDyYB6kjd/c9Uf77bHff6O6PhWhXRKSq\nrtu2by6I1x2ZmeW6bfuCX0trrYiItMEvDx1p6HgrFMhFRNrglLHRho63QoFcRKQNrrzwDEZHhhcc\nGx0Z5soLzwh+ra4smiUi0u/qDzSrVLUiIiIJG9eNtyVwJym1IiJScQrkIiIVp0AuIlJxCuQiIhWn\nQC4iUnEK5CIiFadALiJScQrkIiIVp0AuIlJxCuQiIhWnQC4iUnFaa0VEKqlT26j12rXTKJCLSOV0\nchu1Xrp2FqVWRKRyOrmNWi9dO4tG5CJSCfF0hmec045t1MpeoxPXzqJALiI9L5nOyBJqG7W8HPgp\nY6NMpwTtdmzhVlaQ1IqZ/ZOZ7TWzPWY2FaJNEZG6tHRGUqht1OofGtPRyL+eA5/cPQ10dgu3skKO\nyM93998EbE9EhMnd06kj4DqjNho+/8xVXLdtH++7eU/pSpK0kXdeDjy+44+qVkRESqiPjrOMj41y\nz+YNTVWSZL0na+Qfz4F3agu3skJVrTjwD2a2y8w2pZ1gZpvMbMrMpg4ePBjosiLSz/JSKvF0RjOV\nJFnvGTZLPb+bOfAioQL5S939hcCrgXeb2cuSJ7j7FnefcPeJVatWBbqsiPSzvEqQay45a25U3Ewl\nSdZrs+49lwMvEiSQu/t09PcjwB3Ai0K0KyKDLWsUPD42ysZ140zunmb9tdszyxHzRtFjy0dSj69c\nPsI1l5zF+NgoFl3rDeeOc922fazdfBfrr90+9+CzV7ScIzezE4Ehd/+X6OsLgI+13DMRGXhXXnjG\norx1fXRcVJJYNIr2jOjvvjAH3oszOZNCjMifBXzXzH4IfB+4y93/PkC7IjLgNq4bXzQ6rqdU8vLn\n8fOyHD4yU+p4L87kTGp5RO7uDwO/F6AvIiKLZ
FWIZJUkGnDP5g2F7Zad2NOLMzmTtNaKiFTO5O5p\n0mtLyleXlJ3Yk9VeL1WxKJCLSOVct21f6gNOg9LVJXlpm7henMmZpAlBIlI5WWkNp7EHkGUm9vTi\nTM4kBXIRCaoTmy5k5bfH25Tu6LWZnEkK5CISTNlSvaxgX/ZDIK8scRApkItIMEULTkF2sJ/6+aPc\ntmu6VL12FdIdnaRALiLBlCnVywr2N+3cz2xilk7yQyCu19MdnaSqFREJpkypXt4aJ2l6qV67VymQ\ni0gw55+ZviBe/HhWsK/iqoO9QoFcREqrL1KVtXjU3Q+kL1EdP55Vl335eat7vl67VylHLiKllKlI\nKZMjz3tQOfGck/QAswkK5CJSSpmKlLLrl2Q9qNQDzOYotSIihfL2zYyPtqswnb0faUQuIrmK9s2M\nj7a7Xd/diVmlvUiBXERyld03s67R9EhR8C0bnKuwAUS7KLUiIrnK7pvZjHrwnT50BGc++NarYYpe\nj6vCBhDtokAuIrmK9s1sRVHwbSQ4V2EDiHZRIBeRXHkPMIvqyotkPUCtH28kOFdhA4h2USAXGXBF\nwThrAwagdNojS9ZszvrxRoLzIFfM6GGnyAAr+4Aw7QHm+mu3F9aVF8laX6V+vJHlartdMdNNCuQi\nA6zMJJ8szeSk4xUoY8tHMEjdsq2+QUSjwXlQJxQpkIv0qTJle608ICw7izPen/jo+rEnZ1LPS464\nBzU4NyJYIDezYWAKmHb3i0K1KyKNK5syaSQYJ0fTT6XUluflpPPq0euGzVouaRxEIUfk/xW4H/jd\ngG2KSBPKpkzyctDxwL1idIQnjh5jZraWCEkbTa9cPsJVr31+ZhAuM8o/7q4g3oQgVStmdirw74HP\nhGhPRFpTNmVStiLl0JGZuSCeZfnSJblBuEwZ4CCUCrZDqBH59cCfAb+TdYKZbQI2AaxZsybQZUUk\nTSMpk7IVKUWKRtxpo/+4QSkVbIeWR+RmdhHwiLvvyjvP3be4+4S7T6xalb6LiIiE0WpNdTOzIYtG\n08nR/8rlI4yNjiz4l4DSKs0JMSJfD7zOzF4DnAD8rpl90d3fHKBtEWlCqzXVWSP6LGU/JFSB0h7m\nGQX5TTVm9grgT4uqViYmJnxqairYdUUkrGTVC8DIkPGME5Zw6MkZxpaP4A6Hj8wM1MSbbjOzXe4+\nkTyuOnIRWWSQZ0lWUdAReVkakYuINC5rRK5Fs0REKk6pFZEB1M4t0QZ1u7VuUiAXGTDt3BJtkLdb\n6yalVkQGTDu3RBvk7da6SYFcZMC0c0u0Qd5urZuUWhEZMM2ueFgm393o0rYShkbkIgOm7PT9Rnaw\nb7RtCUuBXGTAZK14mBxpN5PvLtu2hKXUisgAKrPmSbP5bq2n0nkK5CI9oBdrr5Xvrg6lVkS6rJlc\ndCco310dCuQiXdartdfKd1eHUisiXdbLtdfKd1eDRuQiXZaVc1YuWspSIBfpsrRcNMCTR4+1NU8+\nuXua9dduZ+3mu1h/7fau5+SleUqtiHRZPXVx9db7OHRkZu74Y0/OtG3BKS1u1V80IhfpguRoGODE\nZYvHVe166NmrD1ilORqRi3RY1mg4GVjrGnnoWbYevZcfsErjFMhFOqQeZNMm2RyZmWXYjNmUrReH\nzJjcPV2Y8mgkXaLJPv1FqRWRDohP+sky65760HPWvdQEoUbSJZrs018UyEU6IC3IJtUn3AybLXqt\nTP66kXSJJvv0l5ZTK2Z2AvAdYFnU3q3uflWr7Yp0W8j1T4pyz/XR8MZ147zv5j1NtbFidGRB1Uv8\neBpN9ukfIXLkTwMb3P1xMxsBvmtmX3f3HQHaFumK0OV5WTlpqI2G4x8SZfPXyQ+amdnjqe2nDPCl\nz7ScWvGax6NvR6I/i5/YiFRI6PK8rJz09W86h3s2b1jw4VAmf5220NYTR9NTN4eeXDxKl/4SpGrF\nzIaBXcC/A
T7t7jtTztkEbAJYs2ZNiMuKtE3o8rx6oC6Tqilzbpmce50qUfpfkEDu7rPAOWY2Btxh\nZi9w9x8lztkCbAGYmJjQiF16WjvK8xrJSRedW/YDRZUogyFo1Yq7HwLuBl4Vsl2RTuv18rysD5Sx\n0RFVogygEFUrq4AZdz9kZqPAHwIfb7lnIl3USCqkG6688IxFs0FHR4a5+nXP75k+SueESK08G/h8\nlCcfAm5x968GaFekq3q5PK/XP2iks1oO5O5+L7AuQF9EpAG9/EEjnaWZnSIiFadALiJScVr9UAbO\nRyb3ctPO/cy6M2zG5eet5s83nhX8OiGn+IvkUSCXgfKRyb18cccv5r6fdZ/7PmQw1w480klKrchA\nuWnn/oaON0s78EgnKZDLQEnbuCHveLO0A490kgK5DJS0tb7zjjcra+al1j2RdlAgl76S3NQ4uavO\n5eetTn1f1vFm9foUf+kvetgpfaPMA8b6A812V61o5qV0knng3GAZExMTPjU11fHrSn9bf+321BUL\nx8dGuWfzhi70SCQsM9vl7hPJ40qtSN/QA0YZVEqtSN9odA1xTdiRfqERufSNRh4wpm2V9sHb9y56\nOCpSBRqRS89pdqTcyAPGvAk7GpVL1SiQS09pdWp72aVdG82nKw0jvUypFekpnZra3siEHaVhpNcp\nkEtPCVF5UjQpCBrLp2vdFOl1Sq1IT2l19/qyqZlG8ukqa5Rep0AuPSVrU+Giqe31HHbah0DWQ8yy\n+fRWP1xE2k2pFekpG9eNc80lZzE+NopRm5V5zSVn5QbceA47SyujZ62bIr1OI3LpOY1uKpyWw05q\nZfSsdVOk17UcyM1sNfAF4FmAA1vc/a9abVekrKLRdojRs3asl14WIrVyDPhv7v484MXAu83seQHa\nFSklb7RdJjUjUnUtB3J3/5W7/9/o638B7gf0f410TFYO+/o3ncM9mzcoiEvfC5ojN7PTgHXAzpTX\nNgGbANasWRPysjLglMOWQRdsPXIzewbwj8BfuPvteedqPXKp09R3kfKy1iMPMiI3sxHgNuBLRUFc\n+kNeAC4bnFtdV0VEakJUrRjwWeB+d/9E612SXpcXgIHSwVkrEIqEEWJEvh74T8BeM9sTHfuQu38t\nQNvSg4rWHikbnNsx9V2pGhlELQdyd/8uYAH6IhXRTABOey301HelamRQaYq+NCxvCdhGlocNPfVd\nqxTKoFIgl4blBeBGgnMz66rk0SqFMqi01oo0rKhue+rnj3LTzv3MujNsxhvOzZ7eHnLqu1YplEGl\nQC5NyQrAk7unuW3XNLPR/IRZd27bNc3Ec05qe5662SVwRapOqRUJqpt56tCpGpGq0IhcgiqTp25n\niaBWKZRBpEA+4EIH1aI8daMlgqoLFymm1EofK9qEuJnd4YvaLKpaaST1ot3rRcpRIO9TZYJgo/ns\nMm0W5akbKRFUXbhIOUqt9Kky65g0Wndddm2UvDx1IyWCqgsXKUcj8j5VJgg2MguzbJtFGpkw1Gj/\nRAaVAnmfKhMEzz9zVeo5Wcez2hwyy8yZJzVSIqjd60XKUWqlT5WZHHP3AwdT35t1PK1NYG7yT9lF\nqsqWCGrnH5FyFMj7VJkg2GiqJNnmkNlcEK8LvZ646sJFiimQ97GiINjM2iTxNtduviv1HD2MFInc\newt862Nw+ACsOBVe+VE4+9Lgl1EgH2Dnn7mKL+34BfExdSM5aC1SJQMjGZBPvwAe/If8AH3vLfCV\n/wIz0f8jh/fXvofgwVyBfEDVF7eKB3GD3JUKk7RIlVTeV98Pu24EnwUbhnOvgIs+sTBwj66Eo4/D\n7NHaew7vh6nPzreRFaC/9bH5IF43c6R2XIFcQkirCXeyH3Sm0cNI6TmNpDK++v6FAdlna9//9iE4\n8P35IHzk0eLrpgXowwfSz8063gIF8gGVlhLJO55FDyOlY5JB+qTnwj/9H/D
j0QnDQGxwUpTK2HVj\n+nV+9o/N9S8ZoFecWutD0opTm2s/hwJ5RbW6mNRwSsVJ/Xgn+yGyIL2BwdLlcPTJ+RH1L3bArs/F\nAja1ALkoSM6ySF4qw1POb0UyQL/yowtz5AAjo7XjgSmQV1DaCoLvu3kP7715D+Mlg2laEM87XrYf\n2uxYctMbc6/tBxtaGJwBcDj6RO3Lw/vhjne2HnCzUhk2HC6YpwXo+n9zVapWzOwG4CLgEXd/QYg2\nJVtWfhvKB9PxjIqT8QYqTsquvSJ9ZEEgHp5/SOizsGJ1rZpj6nNAFKAP74fb/2j+/fER6qIgniJE\noM1KZZx7xcIced3aly/MkQMMjcCy34Ejj5WvWoHasTYE7qRQI/Ibgb8GvhCoPclRVKddJpiGqDjR\nolZ97t5b4OsfyH7YVw+y9b+T1RxzjsNX3gvLT1pcxdF2lp3KuOgTtb+LqlbaOJIOJUggd/fvmNlp\nIdqSYln123FFwTRExYnqyHtcUYojHqTraY4Vq+cD353vni+5a9XME3D4yTBtlWYw8bb8AHzRJ+YD\nelyHRtKhdCxHbmabgE0Aa9as6dRl+1LWmidxZYJpqxUnqiPvsrRA/YsdsQeHMfEKDlgcpP34wvOW\njIYL4nVZVRx5svLYIyfCqRMLq1aGlsKyZ8ynP3p8FB1SxwK5u28BtgBMTEyUf6LWh1qt9IiPpqcP\nHcGg6dmZrVAdeRuk5aDro+R4UEqbNVj0YLBewQH5QXrmSPgUiA2lV3GkWXpiStXKjYvTHzLHvIEq\nhdyGaqmVr5Z52DkxMeFTU1NBrls1yUoPqAXeVnZ7VwlgRSRnC8LCh2f33ZGdjx4Zhdd+aj6Yf/IF\njY9ugdr8XVj40R/Q0DAcT/kwmXh7IvecqFoZPQle/fGBGUE3y8x2ufvEouMK5J21/trtmfntsqWD\nUhF507ybsWI1vO9Hta+vHqOpYLxide3vog+B0ZPK9zf5LweNoNsmK5CHKj+8CXgFcLKZHQCucve0\nx9cDL+8hZCfrsDWKLyFvBF2Uf02mPspM8y4Sr4duJt8cr3XOe5A5MlobHUN21UreCPrsSxW4OyxU\n1crlIdoZBEUVJ52owx74iTwLZhJGknnovEBcZhW7tAWTWhWvh07LN+dNcEnLs+dVrdTPU6qjEjSz\ns82SI9/zz1zFbbumcytOmq3DLjvK7suJPPHca9zSE+Gi6+cDUnKhpLpkcC4KxEWr2IVeGCk5czBr\n1mDZtEbFyusknwJ5G6WNfG/bNc0bzh3n7gcOZo7Mm6nDbmSUXamJPGUmZiRHz3FHn4DJd9a+PvvS\n7IWSYGFwLhOI885pJvWRlDVKrksLxkprDCQF8hbljYKzRr53P3CQezZvyKxgaaZ0sJFRdk9N5EkL\n1BAbXceKKxtZ9znu+Ox8gC6a8l0PzmUCcd4qdmmpj/g077SqlTJTvkVSKJC3oGgUXDTyDVmH3cgo\nu+0TebIW669bkAZJBOo73w3ucHwmOjlRmdHIus9x9XOKFkqqB+eimueiVew6uGCSiAJ5C4pGwWVG\nvqHW825klN3SB0jahJU6G4Znng6/eWD+WH2xfpivI14QIBOBuky5W9l1n5PnQPZCSbAwOCcDcaNV\nK/U2FLilAxTIW1A0Cu7kFPZGrzX3AVIfPd85C1sLan6TQTg5svXZhUE8bteNtXZDVHOUWfc5bmh4\nPkAnF0qaazMlD61ALBWhQN6ColFwu6awF+Xlf3noCJedsIM/4wbGJh/H74SjI2Mse+11C8vrklO6\nk6PnpFaC8NwKeS1WcxSu+1xQtQLZCyWJVFSwmZ2N6JeZne2Ybl/mmt+942/4kN3ISh4H4AmWMbJ0\nlGUzh8CGcZ/FHYYSm/3M2gjDr/+b2jd563LYMFyVMgmk2dmE8TaLppYPL03kyGEuj55VvSEyINo6\ns3NQdWTRqERVx+z/ez7/3b7JUjs2d8o
zeBpmnq5947MYkLZj27DPzC+alPfAL3NSSQsldedeUfs7\nNQ2SCNSgh4QiDdCIvBPydlUZPal2TtpDtJT66OMpI+3GlFg0KWtEnlevHXfymfDbB0tUrShQizRC\nI/J2ySqlqy9qDzB1w/zx5K4qeVO/U3LSrQVx5h8U5o2s66PnpGQuOq1qpcwCSXqIKBKURuRpkiPG\nrMkaZUeojaqvctdKTjpFqRz5yWfCn+wMdk0RCUcj8roFCyYltmQYORF+7zL44d8tXLA/XnccHzW3\nY2EkKJxd6MwnSPIcByw6eVHVCtT2UZyJdiy3ITj3rarmEKmg6o/IM0fPLa5zUWTF6ijgtuH+1Ufk\n997CsTvfw5LZp+ZeOjZ8AkvW/ceFmxCMnAhLltW+z9tVRkQqrb9G5Fm7e2fu4t0G9Q+O0B8YsTrp\nydn1fHfmHbyXL3OK/ZZf+jO5/vhlvHT8XWzUyFlEItUL5PfeApPvStQZd0E9V56TI3cHHxpmKJaL\nrv8DaNaGGOY4llO1ct22fUwffQm38pIF7X6vysvNikhw1Qvk3/pYB4J4cjvjhPqoOQq4v779Q/wr\nPzj3TqjlpyeHXsXe4TN5x9Evzo2o//LYpWw9/lIgmjx0QfbkoUotNysiXVO9QB56wf5FhmDirQur\nVPKWGD37UnbMrk+f4bnxLD5/8x5u5EWpVyrazKGnlpsVkZ5VvUAeMC/tGI7PVXYcGxplZOOnGn5A\nmDfD87pt+3K3dssbXXdy0S0Rqa7qBfJXfjQ/R7725fDow5nB/ikbZfPRtzE5u37Ra6Mjw1wzexY0\nsTFxMphft20fkB6M4/JG1x1ZAkBEKq+a5YdpVSt5u3pHPjK5ly/u+EVu0yuXj/DUzPGGF8LKW0AL\n4Oqt93HoyMIPn3YvsCUi/SWr/HAoUOOvMrN9ZvaQmW0O0Wausy+FD/yMyYt/zPoT7mDtU3/Hev9s\n6ig77qadxSmZx56cydwsIk/RJhN7rrqA6990DuNjoxgwPjaqIC4iQbScWjGzYeDTwB8CB4AfmNlW\nd/9xq23naWSz4brZFv71UVQpUqbCJNRuQCIicSFG5C8CHnL3h939KPBl4OIA7ebKGwFnGU5b2zVm\ndGSYsdGR1NeKKkXGlqe/L+u4iEgoIR52jgPxnMUB4LwA7eYqGgGn7aJz+XmrM3Pk49E5QFOVIlmD\n/S48ghCRAdOxqhUz2wRsAlizZk3L7eXVWGelXeoPHm/auZ9Zd4bNuPy81fz5xrMWtdNopcjhI+lV\nNFnHRURCCRHIp4HVse9PjY4t4O5bgC1Qq1pp9aJ5NdZ5aZd7Nm9YELgnd0+z/trti4J2o7lsTd4R\nkW4JkSP/AXC6ma01s6XAZcDWAO3m2rhunGsuOSu1CqTs1Pb6yH360BGc+ZH75O5Fn0OFrrzwDEZH\nhhcc0+QdEemElkfk7n7MzP4E2AYMAze4+30t96yErJFz2dFxUclgo32pt6nJOyLSSUFy5O7+NeBr\nIdpqVNpDzbJT20MvSqXyQhHphiATgrolKzUCLEi7jI2OcMLIEO+7eQ/rr90+lzrJyl8rry0iVVLp\nQF6UGrln8wY++aZzePrYcR57cmZRHlx5bRHpB5UO5GVSI0UTh04Ymb8FY6MjmjYvIpVTvdUPY8o8\n1MwK9vWReTzIP33seOq5aXl4BXsR6RWVHpGXSY1kTZEfNis1xT9kiaKISDtUOpDn1ZJDLQg//tSx\nRe8bGbbMBbSSI/hm1nQREemkSqdWIL/k77pt+5g5vjhgn7h0CScuW1Kq1lz7ZopIr6v0iDzP5O7p\nzC3WDh+ZKV2xohJFEel1fRnI63ntLKeMjRamZepUoigiva6SqZWiKpK0vHZdPAiXmYmpqfci0usq\nF8jL7AyUl79upk5cU+9FpJdVLrVSVEUyuXuaoYydgIp2CBIRqaLKBfK8CT710XpWaeGsu2rARaTv\nVC6
QZ1WLGHD11vsyc+N1qgEXkX5TuUB+5YVnkJYgceBQyW3VVAMuIv2kcoF847pxWt0nTjXgItJP\nKhfIoVbznWbl8pFFNd9JqgEXkX5TyUCeNUnnqtc+f9Eknze/eE3hpB8RkSqrXB05FE/SUaAWkUFS\nyUAOmqQjIlJXydSKiIjMUyAXEam4lgK5mf0HM7vPzI6b2USoTomISHmtjsh/BFwCfCdAX0REpAkt\nPex09/sBTItRiYh0Tcdy5Ga2ycymzGzq4MGDnbqsiEjfKxyRm9k3gX+d8tKH3f3Oshdy9y3AlqjN\ng2b289K9XOxk4DctvL8derFPoH41ohf7BOpXo3qxX6H69Jy0g4WB3N3/IMDFk22uauX9Zjbl7j31\ncLUX+wTqVyN6sU+gfjWqF/vV7j6p/FBEpOJaLT98vZkdAH4fuMvMtoXploiIlNVq1codwB2B+tKI\nLV24ZpFe7BOoX43oxT6B+tWoXuxXW/tknrEtmoiIVINy5CIiFadALiJScT0ZyMuu4WJmrzKzfWb2\nkJltjh1fa2Y7o+M3m9nSQP06ycy+YWYPRn+vTDnnfDPbE/vzlJltjF670cx+FnvtnE71KzpvNnbt\nrbHj3bxf55jZ96Kf971m9qbYa8HuV9bvSuz1ZdF/+0PRvTgt9toHo+P7zOzCZvvQZL/eb2Y/ju7N\nt8zsObHXUn+eHerXFdF8kPr13xF77S3Rz/xBM3tLB/v0yVh/fmJmh2KvteVemdkNZvaImf0o43Uz\ns09Ffb7XzF4Yey3cfXL3nvsD/DvgDODbwETGOcPAT4HnAkuBHwLPi167Bbgs+vpvgT8O1K+/BDZH\nX28GPl5w/knAo8Dy6PsbgTe24X6V6hfweMbxrt0v4N8Cp0dfnwL8ChgLeb/yfldi57wL+Nvo68uA\nm6OvnxedvwxYG7UzHOj+lOnX+bHfnz+u9yvv59mhfl0B/HXKe08CHo7+Xhl9vbITfUqc/x7ghg7c\nq5cBLwR+lPH6a4CvAwa8GNjZjvvUkyNyd7/f3fcVnPYi4CF3f9jdjwJfBi42MwM2ALdG530e2Bio\naxdH7ZVt943A1939yUDXz9Jov+Z0+365+0/c/cHo618CjwAtTRhLkfq7ktPXW4FXRvfmYuDL7v60\nu/8MeChqryP9cve7Y78/O4BTA127pX7luBD4hrs/6u6PAd8AXtWFPl0O3BTgurnc/TvUBmtZLga+\n4DU7gDEzezaB71NPBvKSxoH9se8PRMeeCRxy92OJ4yE8y91/FX39a+BZBedfxuJfpr+I/on1STNb\n1uF+nWC19W521NM99ND9MrMXURtt/TR2OMT9yvpdST0nuheHqd2bMu9tVqNtv53a6K4u7efZyX69\nIfrZ3Gpmqxt8b7v6RJR+Wgtsjx1u170qktXvoPepa1u9WaA1XELL61f8G3d3M8us3Yw+dc8C4pOk\nPkgtoC2lVlf6AeBjHezXc9x92syeC2w3s73UAlbTAt+v/wW8xd2PR4ebvl/9xszeDEwAL48dXvTz\ndPefprcQ3FeAm9z9aTP7I2r/mtnQoWsXuQy41d1nY8e6ea/armuB3Ftfw2UaWB37/tTo2G+p/fNl\nSTSyqh9vuV9m9s9m9mx3/zG7K+oAAAI+SURBVFUUeB7JaepS4A53n4m1XR+dPm1mnwP+tJP9cvfp\n6O+HzezbwDrgNrp8v8zsd4G7qH2I74i13fT9Ssj6XUk754CZLQFWUPtdKvPeZpVq28z+gNoH48vd\n/en68YyfZ4jgVNgvd/9t7NvPUHseUn/vKxLv/XYn+hRzGfDu+IE23qsiWf0Oep+qnFr5AXC61Sou\nllL74W312pOEu6nlpwHeAoQa4W+N2ivT7qIcXRTM6nnpjdQ25uhIv8xsZT01YWYnA+uBH3f7fkU/\nuzuo5RFvTbwW6n6l/q7k9PWNwPbo3mwFLrNaVcta4HTg+032o+F+m
dk64H8Cr3P3R2LHU3+eHezX\ns2Pfvg64P/p6G3BB1L+VwAUs/Fdp2/oU9etMag8Pvxc71s57VWQr8J+j6pUXA4ejAUrY+9SOJ7mt\n/gFeTy1n9DTwz8C26PgpwNdi570G+Am1T9YPx44/l9r/bA8B/xtYFqhfzwS+BTwIfBM4KTo+AXwm\ndt5p1D5xhxLv3w7spRaQvgg8o1P9Al4SXfuH0d9v74X7BbwZmAH2xP6cE/p+pf2uUEvTvC76+oTo\nv/2h6F48N/beD0fv2we8OvDvelG/vhn9P1C/N1uLfp4d6tc1wH3R9e8Gzoy9923RfXwIeGun+hR9\nfzVwbeJ9bbtX1AZrv4p+hw9Qe47xTuCd0esGfDrq815iVXgh75Om6IuIVFyVUysiIoICuYhI5SmQ\ni4hUnAK5iEjFKZCLiFScArmISMUpkIuIVNz/B0bdp6aV+sWfAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "plt.scatter(x[:,0],y)\n", "plt.scatter(x[:,0],y_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model doesn't seem to fit very well. What's next? **Optimization**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Walking down Gradient Descent\n", "\n", "* Goal: minimize the loss function (`mse`)\n", "* Gradient descent:\n", "  * Starts with some initial parameters\n", "  * Repeatedly updates them to reduce the loss\n", "  * Each step moves the parameters in the direction of the negative gradient, scaled by a learning rate `lr`\n", "* Below, each step uses all 100 points (full-batch gradient descent); *stochastic* gradient descent instead computes each step on a random mini-batch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First let's make `a` a `Parameter`, so PyTorch tracks its gradient" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Parameter containing:\n", "tensor([0.5000, 0.7500], requires_grad=True)" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "a = nn.Parameter(a); a" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "Next let's create an `update` function that takes one gradient descent step: predict, compute the loss, let `loss.backward()` compute the gradient, then move `a` a small step against that gradient.\n", "\n", "We'll print the loss every 10 iterations to see how we are doing" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def update(i):\n", "    y_hat = x@a\n", "    loss = mse(y_hat, y)\n", "    if i % 10 == 0: print(loss)\n", "    loss.backward()\n", "    with torch.no_grad():\n", "        a.sub_(lr * a.grad)\n", "        a.grad.zero_()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* `torch.no_grad`: turns off gradient tracking, so the update to `a` itself is not recorded by autograd\n", "* `sub_`: subtracts `lr * a.grad` from `a` in place (a trailing `_` marks an in-place operation)\n", "* `grad.zero_`: resets the gradient to zero, since `backward()` adds new gradients onto whatever is already stored in `a.grad`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr = 1e-1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(3.7350, grad_fn=)\n", "tensor(0.5128, grad_fn=)\n", "tensor(0.1768, grad_fn=)\n", "tensor(0.1048, grad_fn=)\n", "tensor(0.0890, grad_fn=)\n", "tensor(0.0856, grad_fn=)\n", "tensor(0.0848, grad_fn=)\n", "tensor(0.0847, grad_fn=)\n", "tensor(0.0846, grad_fn=)\n", "tensor(0.0846, grad_fn=)\n" ] } ], "source": [ "for i in range(100): update(i)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's see how this new `a` compares. 
\n", "\n", "* Detach removes all gradients" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO3dfXxcdZX48c+ZyaSdFkgo7c+2acuT\nWNcFpBARbZddnkRESwB50FVEdKuLyyJdgSJQCj/YFvAllt/6VAuriCIBIdZ2tSDFh1aLpBQKgpWn\nRZoWoZQEaNM2mZzfH/PQOzP33rkzc2cyMznv1wuT3Lm599tpPPn23PM9X1FVjDHG1K/IcA/AGGNM\neSyQG2NMnbNAbowxdc4CuTHG1DkL5MYYU+eahuOm48eP1wMOOGA4bm2MMXVr3bp1W1V1Qu7xYQnk\nBxxwAN3d3cNxa2OMqVsi8pLb8VBSKyLSKiL3isifReQZEflAGNc1xhhTWFgz8sXAL1X14yLSDIwJ\n6brGGGMKKDuQi0gLcCxwPoCq7gZ2l3tdY4wxwYSRWjkQeA34bxFZLyJLRWRs7kkiMkdEukWk+7XX\nXgvhtsYYYyCcQN4EHAl8W1VnANuBebknqeoSVW1X1fYJE/IeuhpjjClRGIF8E7BJVR9JfX0vycBu\njDGmCsrOkavqKyLysohMV9WNwAnA0+UPzRhj6lvX+h5uXrmRzb39TG6Nc+nJ0+mY0Rb6fcKqWrkI\n+FGqYuUF4LMhXdcYY+pS1/oerrjvSfoHEgD09PZzxX1PAoQezEOpI1fVx1P578NVtUNV3wjjusYY\nU69uXrkxE8TT+gcS3LxyY+j3sl4rxhhTAZt7+4s6Xg4L5MYYUwGTW+NFHS+HBXJjjKmAS0+eTjwW\nzToWj0W59OTpod9rWJpmGWNMo0s/0KynqhVjjDE5Oma0VSRw57LUijHG1DkL5MYYUykbOuGWQ2FB\na/Ljhs6K3MZSK8YYUwkbOuHn/w4DqXLDvpeTXwMcfnaot7IZuTHGVMJD1+0J4mkD/cnjIbNAbowx\nldC3qbjjZbBAbowxldAypbjjZbBAbowxlXDCfIjlrOKMxZPHQ2aB3BhjihWkGuXws+Fjt0LLVECS\nHz92a+gPOsGqVowxpjjFVKMcfnZFAncuC+TGGFPIhs5ktUnfJpAIaHZ72kw1ShWCthsL5MYY4yd3\nBp4bxNMqUI0SlAVyY0xdqtY2aq714G4qUI0SlAVyY0zdqfg2as5UClrw9H5t5qmDL+J95d+5JFa1\nYoypOxXdRi2dSul7Gb8gPqgRhlTYNDSeywc+z5efPqT8e5fIZuTGmLrgTKV4hddQtlELkErZoc3M\nG/g8y4ZmZY5JBbZwC8oCuTGm5uWmUryEso2a70NL4RXG858DZ2UF8dDuXaJQArmI/C/wFpAABlW1\nPYzrGmMMuKdScpW0jdryubDu+8lKFInCUeezIz6RMf1b8k7dEZ/EmMv/zNr1PTx435MwtGc8ldrC\nLagwc+THqeoRFsSNMWHqWt9Dj0/aQoC21jhnHtXGzSs3cuC8FcxctIqu9T3+F14+F+2+bU85oSbQ\n7tt4sn8CO7Q569Qd2sxNA+cAyYepC884jLbWeObeC884rCo7AXmx1IoxpmalUype2lrjrJl3fElV\nLEPr/jtv
JivAUfoUcwe+yGVNnUyW19ms+3HT4Nn8fNfRLEidV60t3IIKK5Ar8ICIKPBdVV2Se4KI\nzAHmAEybNi2k2xpjGplfSsWZzvCrYvEKuKJDrsejDLFsaBbLdmfnwNuGMQdeSFiplVmqeiRwCvAl\nETk29wRVXaKq7araPmHChJBua4xpZH5VKM50htd5ft+fUPfwlyBCPBbNOjbcOfBCQgnkqtqT+vgq\ncD9wdBjXNcaMbF6VIG2tcTpmtNG1voeZi1Z5liP6VZLcKyehOd+omjyemwMvOv9eZWWnVkRkLBBR\n1bdSn38ICH8vI2PMiHPpydPzyg7Ts+NCJYmFZtGL5PPsGkzwz9FVRBkiQYQfJY7nltjnedyRA6/4\nKtIQhJEjfwdwv4ikr/djVf1lCNc1xoxw6UDp1lNl5qJVWUF8dmQ1C2J3sC9vg8DuWAujol8D3DsS\n9vUPcA0XcM3gBVnHZXAg6+tS8u/VVnYgV9UXgPeGMBZjjMnjVSHiLEmcHVnN12JLaJbBzLFRA33Q\ndWHyC5f2spNb465ljbnpmFLy79VmvVaMMXWna30P4vj6sqbOrCCeMTTguWv9pSdPD/RQ0yvPPpwr\nOXNZIDfG1J2bV25kQdPtPD/qn3lx1Cdpk63eJ3ssuQ+6sCdowB9OtiDIGFN3vvD2N/l09FeIFD7X\nr094kIU9fnn6WmGB3BgTqopt+ODoEf7pJiVIDCcSC2XX+lpbyZnLArkxJjRBS/W8gr3nL4Gc7dYC\nBfH4ODjlxmHbR7OaRHMr4qugvb1du7u7q35fY0xlzVy0yrUSJN0TBdxb0sZjUc48qo2fruvJO778\noPs5+KWfBB9Ey1S45KnS/xA1TETWuTUmtBm5MSY0QUr1vOqy73rkZRI5E8t5+j0OeulXwQcg0VBS\nKfXGqlaMMaEJUqrnFexzgzjAP0dXBUujAMTGwunfGRGplFw2IzfGhOa4d0/gzrV/dT2e5rUQJyrC\nqfK7VPvYrWzW8URx71AIQCwOH7t1RAbuXDYjN8YElm5S5dU86uE/v+b6fc7jXnXZCw95hhtjS5kS\n2UpEYErEpzYcLIg72IzcGBNIkIqUIDny3Lrsr8Xv4HQeIPLXobxyFM868fbPWRB3sEBujAkkSPOo\noP1LOqJr6JDLYfQ2PHvQOkk0a19NPvr1Uv8YDckCuTGmIL99M52zbb+2s0CyHvwXl0P/tuA3b+By\nwrBYIDfG+Cq0b6Zztu27nD1nUU8gsXhR5YQVW1Va4yyQG2N8Bd03M81zOftD1xUXxFumwgnz6UrM\n5OZFqwoG53rYAKJSrGrFGOMr6L6ZBXl0IcylCs/vfy5c8hRdiZlccd+T9PT2o+wJzm5brfnl8Bud\nBXJjjK9C+2bm2dAJtxwKC1qTHzd0Jo/7dCGEZABPqHBH4kTO+9s5QHHBuR42gKgUS60YY3wV2jfT\nmZP+xnue5X2PzYP0Qp6+l+G+LyQ/P2F+Xo48vZhzm+7FtYPnsWxoVvJAKvgWE5yDVsw0Igvkxoxw\nhR4Qej3ABPJy0n+37mqQ3NWYQ/DzL8OVm5NfplrR0jKFS7bOpisxM29M0VQBeTHBuWDFTAOzQG7M\nCBb0AaHbA8zczY8BxrLL/UYD25MfDz87ayFP17wVrqen+64UE5zrYQOISrFAbswIVs4O8Zt7+5kd\nWZ3VG6UQ5+y/dUwMwX09UFtqxl1scK71DSAqxQK5MQ0qSE11yQ8IN3SyfvRcWvStzDL6KbIVz+0N\nJJI3+39jx4Drqbkz7pEanIsRWiAXkSjQDfSo6kfDuq4xpnhBUybF5KC71vfw+Iol/PvAUvaVt2kF\n194omn8Yjvqsbz16WlSkuJJGA4Q7I78YeAbYJ8RrGmNKEDRlErQi5dzRa7lSv8tp7Cq44bGAa2+U\nzavd8+FOQ6oWxEsQSiAXkSnAqcANwNwwrmmMKV3QlEmhipSTEr9hWfMdjN
O3g+1YD569Ubxm/7nn\nmOKFNSP/BnAZsLfXCSIyB5gDMG3atJBua4xxU0zKxKsiZZ5+j0/HfkUk8BY9+PZGcZv9O42UUsFK\nKHtlp4h8FHhVVdf5naeqS1S1XVXbJ0yY4HeqMaZMXps3BA2U7W8+yKejwYO4KvSyt+9mDx0z2lh4\nxmG0tcYRYN8xMVrjycqVtta45cbLEMaMfCYwW0Q+AowG9hGRO1X1UyFc2xhTgnJrqq9ovifQLE81\nuSpzIeczq+NCOg73v75VoFSGqGe9UAkXE/kn4CuFqlba29u1u7s7tPsaY8KlC1qRAjs+KPDDwRP5\n7l5fGjELb4abiKxT1fbc41ZHbsxIt6Eza9k8J8xHWqYk+6R4iY9DTrmR8w4/m/OqN1LjIdRArqq/\nBn4d5jWNMRXitltP38vJxlbv/SQ88eOc/uEC7RfYNms1yNrYGjMSLZ8L981x33JtoB+efSD54LJl\nKiDJj2cssSBeoyy1YsxIs6ET7b4tf/WlU9+mvAZXQY3U7daGkwVyY0aYHb+Yz5hCJxXYBMLLSN5u\nbThZasWYEWZ0/yv+JxS54bHTSN5ubThZIDemUXlsubZ5aD/X01WB+DjfRT2FjOTt1oaTpVaMaUQb\nOrO3VUtXowBLmz/FZQPfYozszpw+pHB/5MOcefndWZcpNt89krdbG042IzemkSyfC9eOg/v+Jad0\nkOTXD13HEafOYb7OYdPQeIZU2DQ0nsv034jOzq5ISee7g+xgn1ZuawBTGpuRG9Mols+F7tv8z+nb\nlJpRX8g5K0/wnWmXsnvQSN5ubThZIDemUaz7fuFzUtUoQXqelJrvtn4q1WeB3JgaUFLtde7SevXf\nfafYahTLd9cPy5EbM8xKyUVnHmb2vQyof18USK7MLLIaxfLd9cMCuTHDrKTa64euy3+Y6aX9c8kd\ne4osKcztH249w2uXpVaMGWYl5aL7Nnm/5rJfZqks310fLJAbM8xKykV7tZn12C/TNDZLrRgzzNxy\n0bMjq7ln57+gOasyM06Yn3x46VTkw8yu9T3MXLSKA+etYOaiVf45eVPTbEZuzDBLpy4WLPsTx+56\nmAWxO9iXt/d0J3SsyszkudMfczaECJoHt+ZWjcUCuTHDwLXcMLqGy2JLs5bOZ6RWZWYF6hLbzEJp\ni31M7bJAbkyVuc2GV9//LRbJt2iSIe9v9HvA6bh2kHp0a27VWCyQG1Ml6SCb+2BzdmQ118lS/yAO\nBXuEF5MuscU+jcUedhpTBc5FP5AM3qub/50XRn2Sr8e+455OcejXZh49+CLfc4qpR7fFPo3FZuTG\nVIEzyM6OrGaRIxcewXsmrgrbdC+uHTyPdU8fwprZ3vcoJl1iza0aS9mBXERGA78FRqWud6+qXlPu\ndY0ZbmHuPbm5t5/ZkdVc1tRJm2xFfDfMTBrUCHMHvsiyoVkASIH8dUs8Rm//gOtxN7bYp3GEMSPf\nBRyvqm+LSAxYLSK/UNW1IVzbmGERdnneZ/b6I5cNeFSkuNihzcwb+HwmiEN+/jr3F81Awn1mH+SX\nhqlvZQdyVVXg7dSXsdR/Wu51jRlOoZTnOboTzhchUuhhpkRBh9gRn8j87WeybOiDmZdy89duv2i8\n9O7In6WbxhJKjlxEosA64J3AN1X1EZdz5gBzAKZNmxbGbY2pmLLL83K2WotogblNLJ7pTjgGmLW+\nhz/4pHXcftF4sUqUxhdKIFfVBHCEiLQC94vIoar6VM45S4AlAO3t7TZjNzWt7PK8YroTtkzNW5VZ\nKH8d9BeKVaKMDKGWH6pqL/Aw8OEwr2tMtZVdnhdg8Q6xOJzxvZJazHr9QmmNx6zt7AgURtXKBGBA\nVXtFJA6cBNxY9siMGUZll+d5dSdM5cGL7Y2S69KTp2flyCH5i2bB7L+3wD0ChZFamQT8IJUnjwCd\nqro8hOsaM6zKKs87YX5WjhzIyoOHMT
awOnCTFEbVygZgRghjMaZxlNmdMAirAzdptrLTmEopozuh\nMcWwXivGBLGhM7nBg9dGD8YMI5uRmxHnqq4nueuRl0moEhXhE++fyvUdh3l/Q05NuOtGDy7CXOJv\njB8L5GZEuarrSe5c+9fM1wnVzNd5wTyzMtOl+sRtowcH24HHVJOlVsyIctcjLkHZ7Xh6Fu4WxNN8\nasWLaSlrTLkskJsRJeGxVD7veJCVmT4bPdgOPKaaLJCbESXq0Qow73ihlZkFdqz3WnlpfU9MJVgg\nNw2la30PMxet4sB5K5i5aBVd63uyXv/E+6e6fl/ecb9t1VqmFlzYYzvwmGqyh52mYQR5wJh+oHnX\nIy9zqvyOy2OdTJbXkRenwAbHgp0yV2bayktTTaKF2mtWQHt7u3Z3d1f9vqaxzVy0yrVjYVtrnDXz\njt9zYPlc6L6dvLb5uYHa0U+8EiszjSmWiKxT1fbc4zYjNw0j0APG5XOh+zb3C+SWFNrKTFMnLEdu\nGkagB4zrvu9/kSDtZ42pMRbITcMI9IBR/XfV2RGfWImhGVNRlloxNafUpe2BHjBK1DOY79Bmbho4\nhwVh/CGMqSIL5KamlLu0vSO6hg65HEZvg53AA+MgeuOeXPdR5+flyFVhO6P56sAF/HzX0a6B3Pqm\nmFpmqRVTU8pa2r6hE+77AvRvc3zzNvjZl/Z0K/zo16H9cwwSQRUGNcIdiRM5dNftLBua5ZpnT/9y\n6entR9nzyyW3Rt2Y4WKB3NSUkpe2b+iE++YAQ/mvJXYnq1HSPvp1lp/2FO8Zupt37rqTawYvALwX\n7FjfFFPrLJCbmlLS0vZ0g6vcunAHzalG6ZjRxsIzDgu0UbH1TTG1znLkpqZ4bSrsu7Q9QIOrvzGe\n3HqUoFulTW6Nuy40sr4pplbYjNzUlGJmymm5s+281xUW7j6r5DFZ3xRT62xGbmpOsZsKJ2fbr7m+\npgp3JE6ke5+TyhoPWN8UU7vKDuQiMhW4A3gHySTlElVdXO51jQlq4e6zWBhbyhjZnTmmCm+wFwsG\nzuPB6D+ysMzZs+1Yb2pZGDPyQeA/VPUxEdkbWCciD6rq0yFc24x0ARpXde9zEvPehMuakp0MN+t+\n3DR4NsuGZiVTMzZ7Ng2u7ECuqluALanP3xKRZ4A2wAK5KU/ATY+TD0h3s2z3rMyxeCzKN87yz60b\n0yhCzZGLyAHADOARl9fmAHMApk2bFuZtTaMpctNjy2GbkS60fuQishfwG+AGVb3P71zrR27Scpe+\nf+M9z/K+J68pUE4osKC3amM0plZUtB+5iMSAnwI/KhTETWPw6z0StC+JW1+VyetuAil902NjRqIw\nqlYEuA14RlW/Xv6QTK3za2wFBG56dfPKjZyU+A2XNXcyWbayWcczia3+Ny+w6bExI1EYM/KZwKeB\nJ0Xk8dSxr6rq/4RwbVODCvUe8XotE8hTOfDV/S+jMYikNrCfIlsZ8sv0tUwtuN2adSk0I1EYVSur\nAQlhLKZOlNJ7JPOaoxJFJP8HJyLJtldZS44DbnpcbgtcY+qVLdE3RfNrbFWw6VWAvigCydk3kvwY\ncOd661JoRipbom+KVqixlW/TqwB7YkrLVLjkqaLHZV0KzUhlM3JTNL/GVh0z2jjzqDaikkyaREU4\n8yjH8vZCFSdlPMwsqQWuMQ3AZuSmJF69R7rW97DzsZ/wm9hPMpUo33jsXLr2H5c8/4T52as1gWQy\nRQM9zPRTUgtcYxqABXITqsdXLOEG+TajJBlMp8hWbtBvs3BFEx0zrt0TpAv0TymFrfA0I5UFchOq\ni3YvZVQk+4HjKElw0e6lwLUAdCVmcvOuW9m8s5/Jo+NcmphOR0j3ty6FZiSyQD7ClV13ndOdcFzk\nbdfT0seLLRG0unBjCrNA3sAKBcFS6q7T12x/80Gubf4hLby1pxbcrclVDr8Swdx7Wl24McFY1UqD\nSg
fBnt5+lD1BsGt9T+acYuuu09c86s0HWRhbSqsziKd4rQyT+DiguBJBqws3JhgL5A0qSBAstu46\n3Rvl67HvZO3GU1C0GU65ESiuRNDqwo0JxgJ5gwoSBIutu25/80EWxZbSJEOFB+BcmXnaNzNVKcVs\nZGx14cYEYznyBjW5NU6PSzB3BsHj3j2BO9f+Ne+c4949wfWaVzTfwxgCzMTj4zxXZhZTImh14cYE\nY4G8QQUJgg//2X3nea/j7yjUYhbYrVGe+Lt5vM/nnKAlglYXbkwwFsgbVJAgWGwOWlqmuFampDeZ\n6tHx3DR4NuuePoQ1s8v8A6RYXbgxhVkgb2CFgmCQ9EsWl+X1O7SZeQOfZ9nQno2PxR5GGgPAihdW\nsPixxbyy/RUmjp3IxUdezKkHnRr6fSyQj2DHvXsCP1r7V5x7OfjmoHOW17/CeP5z4KysIA72MNI0\nntyAfOyUY/ntpt/6BugVL6xgwe8XsDOxE4At27ew4PcLAEIP5hbIR6h0c6sNzd9jL9kFwBDCH/fr\n4AMzPuz9jYefnQnoa9f38OB9T8KQPYw09en6tddzz1/uYUiHiEiEs951Flcdc1VW4N6neR92DO5g\nYGgASAbkuzfenbmGV4Be/NjiTBBP25nYyeLHFlsgN+F4fMUSFsl/0eRYwRNFOWbb/bB8HHy08Par\n9jDS1JpiUhnXr70+KyAP6RB3b7ybl/pe4vHXHs8E4b7dfQXv6xagX9n+iuu5XsfLYYF8hPqP3d+m\nyWUVgQCs+36gQA72MNJUT26Q3n/v/XnklUfQVHIwQoQh9qxxKJTKuOcv97jeZ+0ra0saX26Anjh2\nIlu2b8k7b+LYiSVd348F8jpVVjOpDZ2ZdIorTXi/FuY4jCE7vQEwpmkM/YP9mRn1+lfX07mxMxOw\nIRmkc4OkM4in+aUy0vcLS26AvvjIi7Ny5ACjo6O5+MiLQ70vWCCvS27NpC65+3G+fPfjtAUJpg9d\nh/htly1Rnxf9x2FNrYxfeiP92pbtWxAkKzin7RjcASSD9Vd/91XXAF0Mr1RGRCKhBXO3AJ3+M9dN\n1YqI3A58FHhVVQ8N45rGm1sflfT/HbKCaXSN+wYOPvtmKiBHnV/yOLw6GZrG4AzE6UCY/jhp7CSO\nnXJs1ux5y/YtXPG7KzLf75yhugXxXOUGcfBOZZz1rrOycuRpx0w8JitHDtAkTezVvBd9u/oCV61A\nMphXInDnCmtG/n3gv4A7Qrqe8VGoaVT/QILEsrmgK8mE+L6XkzXgkAzqHgt7Xh1/DO8ImB+3plaN\nbcULK1j0x0X07up1fT09m01/zK3mSFOUa39/La2jW/OqOKrBK5Vx1TFXARSsWqnkTDosoQRyVf2t\niBwQxrVMYV4LedJmR1Zz+tAv83vKDvQnZ+guC3uGgBcPOJeDP/vdssdhdeS1oVCKwxmk02mOSWMn\nZQLf1WuuzpTclas/0c/O7dUP4udMP8c3AF91zFWZgO5UrZl0WKqWIxeROcAcgGnTplXrtg3JrY8K\nJAP4ZU2dtMlW7xx43ybXfTMjJ8zn4CL3zbSmVsPLLVCvf3V91oPDNGcFB+QHaWcqZMHvFzAqOiq0\nIJ7mVcXhJ7cSJS0ejfPeCe/NqlqJSYyxzWMz6Y9an0WHSVQL56kCXSg5I18eJEfe3t6u3d3dody3\nHoVR6ZG+Rk9vP6dFVnNN7A725W3/h5iQbCvr0ZmwFFa1Ei63HHR6luwMSrmrBsE76DlNGjsJoOiA\nWi5BWPgPC/PG7MatasUt/TESicg6VW3PO26BvLpyKz0gOYtdeMZhpQXADZ0M/uwimgLlHgXOWBLK\njvWmeLmrBUUk6+HZL1/8pefik9HR0Sz44IJMMP/QvR8qKRhLKt8W5EFjKaISJeFSvnrO9HOycs+5\nVSstzS1c8f4rRswMulQWyGvEzEWrPPPbgUoH0zKbHhfeJzNJoP2C
wAt9TPn8lnmXYtLYSTzw8QcA\nOPwHh5cUjIPOyFtHtbJ9YHug8eb+y8Fm0JXjFcjDKj+8C/gnYLyIbAKuUdXbwrh2o/Gr6Ahch718\nLnTfDkH/j9wydU/pYYqlRApzBuKWUS2oKm/ufjNQ/jU39RFkmXchznroUvLNzlpnvweZo6OjmXf0\nPADPqpXWUa3MO3qeZ8mdBe7qCqtq5RNhXGckKFRxUrAOe0Nn8CAei8PHbs1LpYz0hTy5KwmBvDx0\nbiB2BrMgXezcGiaVy1kP7bZq0C9H7pZn96taSZ9nqY76YCs7Kyx35nvcuyfw03U9eRUnTlmz9kwK\nJbWoZ/d2CgVxVeiTvXn2sKt5n0s+vBEX8jhzr05jmsYw/wPzMwEpt1FSWm5wLhSIC3WxC7sxUu7K\nQa9Vg0HTGvVWXmf8WSCvILeZ70/X9XDmUW08/OfXPGfmmTrsDZ3Z9d4F8uGqe3bpWTY0i/ijURZO\n7ckLzvW0kCfIwgy3Co60HYM7uHL1lUAyeHk1SoLs4BwkEPudU0rqI5fXLDnNLRhbWmNkskBeJr9c\ns9fM9+E/v8aaecd7VrBk6rAfui5r0Y6fIeDLAxdmbfLgNcuupYU8boEacJ1dF9P32SmhiUyALtRb\nIx2cgwRivy52bqkP5zJvt6qVIEu+jXFjgbwMhXLNhWa+Bft5+/REySbcOXhC3k49zns5VXohj1ez\n/jSvNMiW7Vu4es3VqCqDOuh67WL6PjulzynUKCkdnN0CsVOhLnbVbJhkjAXyMhTKNQeZ+fr28/bo\niUJ8HDSPzWqG9d3/GQ8BZ9nlbAjhtmAlLSIRDtz7QJ5/8/nMsXSzfiBTR+wXIIOUuwXt+5x7Dng3\nSoLs4JwbiIutWklfwwK3qYbQ6siL0Sh15AfOW+H62FGAFxedGmzxT+7DTGeZYG6OHAJXorjey0Wh\n2bNToSDsJyIRnjjviZIXsjg566mDjCsqUW6YdUPWA89CVSvG1KKK1pGPVIVm3AVnvm4PM9MdCh17\nY+YG+q7ETG5etMozL7+5t5+9x29Ax3Vx1RM7uPoJGNO0D/M/+NWs8rrcXs+5s+dc5ZTUpYNmudUc\nhfo+F6paAe9GScbUK5uRl6Hk5faFVmX69EPpWt/DVx/4ATL+Z0g02YCfoRhjm+PsSLyZSXeoktd3\nJSpN3DDregDfhv3p2XOuUlcTOq9ZaEYei8Q8c+Q2azYjnc3IK6DoXPOGTvjF5dC/zf/CjoecuVUd\nr/3tICL/Zy0SceTmowPsSCRzy+mZr1vzrIQOsvixxcnzfJoreT0MLKek7qx3nQX4P0R0tlC1h4TG\nBGcz8ipY8cIKFq9dyJbdvURIlgqmP7YkhhCBvkiEiYMJLn6jl1Ob9oNLnnLN/brNtIsRpGmS14w8\naI784H0O5sW3XixYtWKB2pji2Iy8Qvz2IDxn+jkAe6okRDLz4PTHvqY9+2NuiTWxYPx+cODpnIp7\nTrqcIA57qjf8Ztbp2XOu3Fy0W9VKkAZJVs1hTLhsRu4id8botVijnCoOP+mqjHJy0m6C5MgP3udg\nuk7vCu2expjw2Iw8xVl6ljuDjkfjzH7nbH723M8ywTl3H0Ln6sJKNEaC4lYX+tLUZsrkV60AXPv7\na+lPJCtmBOHs6WdbNYcxdQqoLCIAAA1YSURBVKjuZ+Res+dK74AyaewkXtn+SkUa9Kdn5CteWMHV\nq69hQHdlXovJKM54Vwcr/3dlpnNdPBpnVNMoenf1+u4qY4ypbw01I1/xwgoWPrIwr8ez1y7elZD+\nxRH2LwxnnfRA3xHs3HIGMu4XSKwXHWhl57ZTOPSwz3DVuTZzNsYk1V0gX/HCCq5afZVnL45qSefK\nF6y+mp3qsaxclSiQcD6hTP0LSIigorQ0t2Q1T3LOom9euZEdve+FN96bddl6bjdrjAlf3QXyxY8t\nHvYgnp41n/r2dtj6Oov3GcOW
pihCdqfwc958iyN27ebWfVt5pSnKxMEEF23rZdubH+SawQuIx6J8\nxWfxUD21mzXGDJ+6C+RhN+zPlX7o56xS8WwxesuhnPpmL6e+mb8VVpoqHPHWGCbL62zW/bhp8PxM\nl8JCmznUUrtZY0ztqrtAHnpe2lHZEZNR/N9/uDb4A8IAbWb7ZG9m7brV83W/2XWl280aYxpD3QXy\ni4+82DdHfszEY3jprZc8g32EUfT3dLD7zRl5r8VjUQYOPyxvs4hvvOdZ3vf8/8vvUOjVZjZliAjP\nHnk18Uejnlu7+c2uy2k3a4wZOeqy/NCtasVvV++0q7qe5M61f/W99r5jYuwcGMoE3tmR1dwYW0pc\ndu85Kd1KFvLazKbfzm26Fws5n1mnXwjAgmV/orc/+6FooAZbxhiTUtHyQxH5MLAYiAJLVXVRGNf1\nkl7i7Zw579MaZ2C6f8rhrkf897wEeGNHdrC9rKkzO4hDMnA/dN2eDoUPXcdQ3yY2D+2X2S8z7Q8r\nN7Jm3vF0zGjz3RbOGGNKVXYgF5Eo8E3gJGAT8KiILFPVp8u9tp9C26y5SQT818fsyGoua+pksmzF\ns7VJOj+e6ht+sMcmE84cuO9uQMYYU6JICNc4GnhOVV9Q1d3AT4DTQriuL79t1rxEC3SciseifGL0\nWhbFljIlspWI+DSpapmS9WXrmJjraV7HjTEmLGGkVtoAZ85iE/D+EK7rq1CNtVsa4xPvn+qZI29L\nnfOhB77MmP7drudkxOLJB54OXpP9YXgEYYwZYapWtSIic4A5ANOmTSv7en411l5pl4VnHAYkc+UJ\nVTqia7hh9J2MTfTBTuCBcQU2fZD8fTVT+vrdV3d6HTfGmLCEkVrpAaY6vp6SOpZFVZeoaruqtk+Y\nMKHsm1568nTisWjWsXSNtV/a5fqOw3h+4Uf4309u5xvN30kG8cxJ27xbYLVMhQW9yQecOUEcvMsI\nbfGOMabSwgjkjwKHiMiBItIMnAssC+G6vjpmtLHwjMNoa40jJFMj6VI+37TL8rlw7Ti4719A82u7\nc5fZA66plFx+v1iMMaaSyk6tqOqgiPwbsJJk+eHtqvqnskcWgFcViFfa5WvxO6D7lwWvq4C0TM1f\nAFRgLGCLd4wx1VeXC4Kc3B5qAlk58tmR1VweK1BO6LBpaDxTrns+lPEZY0xYvBYEhZFaGTbph5o9\nvf0o2bXk6bTLHbEbWNz8LdoCBnFVWNr8qYqO2xhjwlR3vVac/B5qrjl0OR07b0ejGiiAQzKI/1hP\n4ohT54Q/WGOMqZC6DuRuDzVnR1Zzw47boDu5PVqhIK6p/+nR8Xwz8kne3/FFy2sbY+pKXQfy3Iea\nsyOrWRRbypjc3iguVCFBhB8ljueawQuAZJWJ20om65FijKlldZ0jzy35u6ypM3AQv2TwS7xz152Z\nIA7uS/y98vBd6/NK5Y0xZljU9Yy8I7qGU0Z/heZo357dIQpQhdV6KF2Jma6v56Zr/PLwNis3xtSC\n+gzkGzrh51+Gge2MSh8rEMQ1tRPQDxMnckvsC7S1NgXaRs32zTTG1Lr6S61s6ISuC2Fge6DTVSGh\nwh2JEzlo14+5ZvAC+voHAq/EtKX3xphaV38z8oeugyH/RlQKqEpqs+PsjR4gGYSDrsS0fTONMbWu\n/gJ5gA2P/8YEjtm12PU1ZxAOstGDLb03xtS6ugvkO+ITGdPvvrEyANFmFvaf5flyKXtk2s4+xpha\nVnc58psGzmGXRvOOK0DzWB597/Us139w/d5COwQZY0w9qrtA/oO3j+bSgS/w+tBeyUoUTe5Yf/Hu\nC+k65VHOe3R/z705E6pWA26MaTh1l1qZ3BpnWe8slu3OfoApwG+X/Smv5juX1YAbYxpN3c3ILz15\numvJuAK9AbdVsxpwY0wjqbtA3jGjzXs7toCsBtwY00jqLpBDcls3N/uOieUt8sllNeDGmEZTl4
Hc\na1XmNR/7+7x9PD91zDTXfT2NMaZR1N3DTii8SMcCtTFmJKnLQA62SMcYY9LqMrVijDFmDwvkxhhT\n58oK5CJyloj8SUSGRKQ9rEEZY4wJrtwZ+VPAGcBvQxiLMcaYEpT1sFNVnwEQa0ZljDHDpmo5chGZ\nIyLdItL92muvVeu2xhjT8ArOyEXkV8BEl5euVNWfBb2Rqi4BlqSu+ZqIvBR4lPnGA1vL+P5KqMUx\ngY2rGLU4JrBxFasWxxXWmPZ3O1gwkKvqiSHcPPeaE8r5fhHpVtWaerhai2MCG1cxanFMYOMqVi2O\nq9JjsvJDY4ypc+WWH54uIpuADwArRGRlOMMyxhgTVLlVK/cD94c0lmIsGYZ7FlKLYwIbVzFqcUxg\n4ypWLY6romMS9dgWzRhjTH2wHLkxxtQ5C+TGGFPnajKQB+3hIiIfFpGNIvKciMxzHD9QRB5JHb9b\nRJpDGtc4EXlQRJ5NfdzX5ZzjRORxx387RaQj9dr3ReRFx2tHVGtcqfMSjnsvcxwfzvfrCBH5Q+rv\ne4OInON4LbT3y+tnxfH6qNSf/bnUe3GA47UrUsc3isjJpY6hxHHNFZGnU+/NQyKyv+M117/PKozp\n/NRakPS9P+947TOpv+9nReQzYY0p4LhucYzpLyLS63itUu/V7SLyqog85fG6iMitqTFvEJEjHa+F\n916pas39B/wdMB34NdDucU4UeB44CGgGngDek3qtEzg39fl3gH8NaVw3AfNSn88Dbixw/jhgGzAm\n9fX3gY9X4P0KNC7gbY/jw/Z+Ae8CDkl9PhnYArSG+X75/aw4zrkQ+E7q83OBu1Ofvyd1/ijgwNR1\noiG9P0HGdZzj5+df0+Py+/uswpjOB/7L4+f9hdTHfVOf71utceWcfxFweyXfq9R1jwWOBJ7yeP0j\nwC8AAY4BHqnEe1WTM3JVfUZVNxY47WjgOVV9QVV3Az8BThMRAY4H7k2d9wOgI6ShnZa6XtDrfhz4\nharuCOn+XoodV8Zwv1+q+hdVfTb1+WbgVaCsBWMuXH9WfMZ6L3BC6r05DfiJqu5S1ReB51LXq8q4\nVPVhx8/PWmBKSPcueUw+TgYeVNVtqvoG8CDw4WEa1yeAu0K6tydV/S3JyZqX04A7NGkt0Coikwj5\nvarJQB5QG/Cy4+tNqWP7Ab2qOphzPAzvUNUtqc9fAd5R4Pxzyf9huiH1T6xbRGRUlcc1WpL9btam\n0z3U0PslIkeTnG097zgcxvvl9bPiek7qvegj+d4E+d5SFXvtz5Gc3aW5/X1Wa0xnpv5e7hWRqUV+\nbyXHRSr9dCCwynG4Eu9VEF7jDvW9Grat3iSkHi5h8xuX8wtVVRHxrN1M/dY9DHAukrqCZEBrJllX\nejlwXRXHtb+q9ojIQcAqEXmSZMAqWcjv1w+Bz6jqUOpwye9XoxGRTwHtwD86Duf9farq8+5XCNXP\ngbtUdZeIfIHkv2SOr8J9gzoXuFdVE45jw/VeVcWwBXItv4dLDzDV8fWU1LHXSf7zpSk1s0ofL3tc\nIvI3EZmkqltSgedVn0udDdyvqgOOa6dnp7tE5L+Br1RzXKrak/r4goj8GpgB/JRhfr9EZB9gBclf\n4msd1y75/crh9bPids4mEWkCWkj+LAX53lIFuraInEjyF+M/ququ9HGPv89yg1PBManq644vl5J8\nFpL+3n/K+d5flzmewONyOBf4kvNAhd6rILzGHep7Vc+plUeBQyRZcdFM8i9vmSafJDxMMj8N8Bkg\nrBn+stT1glw3L0eXCmbpvHQHyY05qjIuEdk3nZoQkfHATODp4X6/Un9395PMI96b81pY75frz4rP\nWD8OrEq9N8uAcyVZ1XIgcAjwxxLHUfS4RGQG8F1gtqq+6jju+vdZpTFNcnw5G3gm9flK4EOpse0L\nfIjsf5FWdFypsb2b5MPDPziOVeq9CmIZcF6qeuUYoC81QQ
n3varEk9xy/wNOJ5kz2gX8DViZOj4Z\n+B/HeR8B/kLyN+uVjuMHkfw/23PAPcCokMa1H/AQ8CzwK2Bc6ng7sNRx3gEkf+NGcr5/FfAkyYB0\nJ7BXtcYFfDB17ydSHz9XC+8X8ClgAHjc8d8RYb9fbj8rJNM0s1Ofj0792Z9LvRcHOb73ytT3bQRO\nCflnvdC4fpX6/0D6vVlW6O+zCmNaCPwpde+HgXc7vveC1Hv4HPDZar5Xqa8XAItyvq+S79VdJCut\nBkjGrM8BXwS+mHpdgG+mxvwkjiq8MN8rW6JvjDF1rp5TK8YYY7BAbowxdc8CuTHG1DkL5MYYU+cs\nkBtjTJ2zQG6MMXXOArkxxtS5/w8jDzXc2rhxsgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "plt.scatter(x[:,0],y)\n", "plt.scatter(x[:,0], (x@a).detach())\n", "plt.scatter(x[:,0],y_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our line now fits the data **much** better" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Animate the process" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from matplotlib import animation, rc\n", "rc('animation', html='jshtml')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's redo the process and animate our predicted line closing in on the data. First, reset `a` to a fresh starting guess:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Parameter containing:\n", "tensor([0.5000, 0.7500], requires_grad=True)" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "a = nn.Parameter(tensor(0.5, 0.75)); a" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each frame runs one `update()` step and then redraws the line with the new predictions `x@a`. We `detach()` the predictions from the gradient graph so Matplotlib can convert them to a NumPy array." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def animate(i):\n", "    update()  # one gradient descent step on `a` (defined earlier)\n", "    line.set_ydata((x@a).detach())  # redraw the line with the new predictions\n", "    return line," ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's create a base figure: the data as an orange scatter plot, plus the line we will update. `plt.close()` stops the static figure from being displayed on its own." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig = plt.figure()\n", "plt.scatter(x[:,0], y, c='orange')\n", "line, = plt.plot(x[:,0], (x@a).detach())\n", "plt.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And animate! `FuncAnimation` calls `animate` once per frame: 100 frames, 20 ms apart." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "
\n", " \n", "
\n", " \n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
\n", "
\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "animation.FuncAnimation(fig, animate, np.arange(0,100), interval=20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Ideally we split things up into batches of data to fit, and then work with all those batches (else we'd run out of memory!\n", "\n", "If this were a classification problem, we would want to use `Cross Entropy Loss`, where we penalize incorrect confident predictions along with correct unconfident predictions. It's also called `negative loss likelihood`" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 1 }