{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "WNixalo | 20181112 | fast.ai DL1v3 L2\n", "\n", "---" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "from fastai import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this part of the lecture we explain Stochastic Gradient Descent (SGD), an **optimization** method commonly used in neural networks. We will illustrate the concepts with concrete examples." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Linear Regression problem" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The goal of linear regression is to fit a line to a set of points." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "n=100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create a column of numbers for the x's, and a column of 1's.\n", "\n", "Rather than treating $y = ax+b$ as a special case, we'll always have a second x value that is always 1, which lets us use a simple matrix-vector product: $y = a_1x_1 + a_2x_2$ with $x_2 = 1$.\n", "\n", "Ahh, interesting. So our biases are encoded into the weights then?" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.6318, 1.0000],\n", " [ 0.0486, 1.0000],\n", " [-0.5819, 1.0000],\n", " [-0.1692, 1.0000],\n", " [-0.3766, 1.0000]])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = torch.ones(n,2) # create an n x 2 tensor of 1's\n", "x[:,0].uniform_(-1.,1) # replace col0 with uniform random numbers\n", "x[:5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create some coefficients; a1 is `3`, a2 is `2`.\n", "\n", "This creates a 'vector' or rank-1 tensor. 
`3` & `2` represent the coefficients: the slope (3) and intercept (2) of our line.\n", "\n", "Ref: Fast.ai DL1v3 Lesson 2 [[1:21:35](https://youtu.be/Egp4Zajhzog?t=4895)]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([3., 2.])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = tensor(3.,2); a" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "generate data by creating a line via `x@a` and add some random noise to it.\n", "\n", "The columns of 1s is just to make the linear function convenient." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "y = x@a + torch.rand(n)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAHLBJREFUeJzt3XGQHOdZ5/Hfo/XaXpOQtbEAe21FylWQic+FNtkyLlQFkS5lJZjYix2wDbkLd6FU4TiK+IJgXUmd5au78nI+Lqm7owoEhMAllSixHSGjBGFYpVJR4YTVSYqs2EoUOwSvfVgh3oCjjbySnvtjeuTe2e6Z7p63Z3p6vp8qlXZ7erpf9a6efffp533a3F0AgPpY0+8BAADCIrADQM0Q2AGgZgjsAFAzBHYAqBkCOwDUDIEdAGqGwA4ANUNgB4CauagfJ73yyit9/fr1/Tg1AAysQ4cOfcvd13bary+Bff369Zqfn+/HqQFgYJnZ32XZj1QMANQMgR0AaobADgA1Q2AHgJohsANAzRDYAaBm+lLuCADDYs/hBT24/4SeW1zS1eNj2rFto6YnJ0o9J4EdAEqy5/CC7n3kmJaWz0mSFhaXdO8jxySp1OBOKgYASvLg/hMXgnrT0vI5Pbj/RKnnZcYOABnlTas8t7iUa3sozNgBIINmWmVhcUmuV9Iqew4vpL7n6vGxXNtDIbADqJU9hxe0eXZOG2b2afPsXNvAm0eRtMqObRs1NjqyYtvY6Ih2bNsYZExpSMUAGFitqZEt163Vw4cWSrlZWSSt0jwnVTEAkEFSxcnHHv+mvGW/5qy622B69fiYFhKCeKe0yvTkROmBvFWQVIyZfcPMjpnZETOjHy+A0iWlRlqDelOIm5X9SqsUEXLGvsXdvxXweACQKk+wDnGzsl9plSJIxQAYSGmpkVajayzYrDpLWqUfK01bhaqKcUl/aWaHzGx7oGMCQKqk1MjoGtPIGluxbfm8a+fe48GqY9opUhJZhlCBfbO7v1HS2yT9qpn9ZOsOZrbdzObNbP7UqVOBTgtgWE1PTuiB22/QxPiYTNLE+JhedelFOnd+daZ9cWm5JwG2XytNWwVJxbj7c9HfL5jZpyXdKOnzLfvskrRLkqamp
tLucQBAZq2pkQ0z+1L3DVUd006/Vpq26nrGbmbfZ2avbn4s6WZJT3R7XADIq9NN0iIBNs+Cp36tNG0VIhXzQ5K+YGZHJX1J0j53/4sAxwUw5PKuIk3Ku8flDbB5c+ZVKYnsOhXj7k9L+rEAYwEwJLJUjhRpedvcfv+jx/Xi6eUVrxUJsO1y5kljqEpJJOWOAHoqa8DOG1Sbmnn3EGWHRdsI9Lu2ncAOoKfSAvb9jx5fERC7vREZIsAWbSPQb3R3BNBTaYH5xdPLK3LXVbgRWZWceV4EdgA91S4wx+u9qxBUk2rlH7j9hr6nWjohFQOgp3Zs26j37j6S+Fp8Nl+VG5FVyJnnRWAH0FPTkxPaufe4FpeWV73WOpsfxKBaBaRiAPTczluv73uapc6YsQPoidbywzveNKEDT52qfAvcQURgB5Bb3hrxpNr1hw8tDMSNyEFEKgZALkVa01al6+GwYMYOIJciK0I7LTaqwsMp6oTADiCXIitC263gbNdiQOp/ueMgIrADyKXIMvsd2zauCN7SK1Uwab8B7Nx7XGfOns/UBIwZ/0rk2AHkUmRFaLsVnGkz/cWl5cSA/75PHl2Rz6/K4+iqhBk7gFzyrAjNMpPO+lDqpnPuK2buRbtA1hmBHUBuWVaEZm3Pm5amuXR0zaqe6k3xwF2Vx9FVCakYAKXIWuKYlqa57+2rV6fGNQN3FbpAVg0zdgClyDOTbvcbwPs+eVTn3FdtbwbudjdmhxUzdgClCDGTnp6c0O/8/I+1vVk7qK11y8SMHUApQs2ks9yspQvkSgR2AIW1q3oJ2U+dwJ0PgR1AIVmqXroJyCw6Ko4cO4BCymzsxaKj7hDYARRSZv043SC7EywVY2YjkuYlLbj7z4Q6LoBwQqY3ivSMyYpFR90JmWP/dUlPSvr+gMcE0IV4IH/N2Ki++/JZLZ9r1IS3a6qVRZn142X+0BgGQVIxZnaNpFsk/WGI4wHoXmueenFp+UJQb+omvZGnfnzP4QVtnp3Thpl92jw71zFXXqTRGF4Rasb+IUm/KenVaTuY2XZJ2yVp3bp1gU4LIE1SnjpJN+mNpKqX1nTPluvW6uFDC5na78aP2/w3UBWTX9eB3cx+RtIL7n7IzN6ctp+775K0S5KmpqZWrw8GEFTWgB0yvZFUAvmxx7+p1v/wWbovUrteXIhUzGZJt5rZNyR9QtJWM/togOMC6EKWgB06vZH0W0LaLI4boeXpOrC7+73ufo27r5d0l6Q5d39n1yMD0JWkPPXoGtPll41KkkbMLsycO+W8s+bI8wRrboSWh5WnQE2l5aklZeqT3pS1r7qU/aEZoyPGjdASmSe0wyzb1NSUz8/P9/y8AKTNs3OJwXdifEwHZ7Z2tX/rD4E042OjOnLfzTlHDjM75O5TnfZjxg4MmbyLf/L2VZde+S0hbdr4naXkJyMhDAI7UEPtVpjmXfyTd/94NUvabJ/8ernoFQPUTKcGWju2bdToGlvxntE16TnvbhYLsdCoP5ixA32QtWdLkd4u7RpoXXivtbyp9fOYbhYLsdCoP7h5CvRY0g3GsdGRVcvxk/YzNerCJ9oEyA0z+xJz2ybpmdlbct88RXVkvXlKKgbosawtadst9mnXn7zTs0bpnFh/BHagx7IG1k6BNq2BV6e8doiHTOdt6oXeIscO9Fhalcn4ZaPaPDt3IRf9mrFRLXYoC8xSctia1+623W6eBUvoD2bsQI8lLvUfMb30vbMrKlm++/LZVdUrrYqUDeZpt5uEpxtVHzN2oMeSZtTfPXN21ex8+Zzr8stGddnFF2lhcenCjdOmtFl22Q+ZJkdffczYgT6YnpzQwZmtemb2Fh2c2Zq6EnPx9LIOzmzVh+7cpPGoeZfUWJKfNssue0YdIkePchHYgQpoFyybM/AXT78S/M+cPZ96rLJn1Cw6qj4CO1AB7YJl3hl41hl10cqWbnP0KB85dqAC2lWy3LP7SOJ70mbgWapeuq1s4elG1UZgByoiL
VgWacIlNX5ILCwurXigRvP1TG0HMLBIxQAVlzen3ewv06ykORe1DVlYXNI9u49o/cy+1IdhUNlSD8zYgYrL00irNcXS2jOmU2coKlvqgcAODICsOe2kFEtWVLbUB4EdqLC8bXuLpFJMop1uzRDYgYoqUrmS9WHSTbTqrSdungIVVWQFadKNVmv5u4nUS30R2AFVsw1tkcqV6ckJ3fGmCY1YI4yPmOkXb1qnb8zeog/euYlFRUOCVAyGXhXb0O45vLCq6VdTu8qVPYcX9PChhQsljufc9fChBU299goWFQ2RrmfsZnapmX3JzI6a2XEzuz/EwIBeqWIb2gf3n0h9vF279EkV/y3ovRAz9jOStrr7S2Y2KukLZvZZd388wLGB0sQX8iTp52KdtHO72v8WQUtdSAFm7N7wUvTpaPSn90/IBnLYc3hBOx462raCpJ+LddLOPdFhTLTUhRTo5qmZjZjZEUkvSHrM3b8Y4rhAWe5/9LiWz6XPP/pdMVK0NS4tdSEFunnq7uckbTKzcUmfNrN/6e5PxPcxs+2StkvSunXrQpwWKCze27zVRMHFOnkXE7WTp41AiPehXsw9bNbEzO6T9F13/+9p+0xNTfn8/HzQ8wJ5rJ/Zl/raN2ZvyX281soaqTFTTispDPlDAMPDzA65+1Sn/UJUxayNZuoyszFJb5H0VLfHBco0Pjaaa3sneapRmj8E4g+uvveRY5WonUc9hEjFXCXpT8xsRI0fFJ909z8PcFygNDtvvV47PnVUy+df+Y11dI1p563XZ5pNt+6TpbKmXRUOvdARUteB3d2/LGkywFiAnknLRUvquFgpaUFTmmY1SlKqphUliQiFlacYWkkrMTfPznV8slDW1rjxapQs76EkEaEQ2IGYtFnzwuKSNszs69g9ccRM591XpXA6zcYpSURIBHYgpl3gbt7oTOvhIknn3fVMQlVNu+MWLa8E0tDdEYhJWuDTql2BcFo6JW3h0Ifu3KSDM1sJ6giKGTuGXmuFyx1vmtCBp07puagcMU3rzL1dOoWFQ+il4AuUsmCBEqqi08KizbNziSkUk/SLN6278AOAQI1eyLpAiRk7hlq7hUXTkxPasW2j7tl9ZNXM3SUdeOoUj5VDJZFjx1Dr1OZ2enIiNR1D3Tmqihk7hlpatUr8JuhEm33o+YIqYsaOoZalzW3aPluuW0vPF1QSgR1DbXpyQg/cfkPbhzyn7XPgqVM8hg6VRCoGQ6+1FLEZmFuDe2uK5Z7dRxKPR+4d/caMHUOvaBtdHkOHqiKwY+jl6aUex2PoUFWkYjD0OpU8pmE1KaqKwI6hl6XkMU1S7h3oNwI7KiOpJlwqf0a8Y9vGxLYCpFQwqAjsqISkpxLt+NRRyaTlc35hW+vTjEIgpYK6IbCjEpJuYMafR9oU4tmgaatFCeSoCwI7SpV1yX2e2u9u6sSTfjMo47cAoJ8od0Rp8tSH56n97qZOvGhpIzBICOwoTZ4gmlQTPrrGNDpiK7Z1e1OzaGkjMEhIxaA0eYJo2g3MpG3dpEy6KW0EBgWBHaXJG0TTbmCGzH1T2ohh0HUqxsyuNbMDZvakmR03s18PMTAMviouuc/SzREYdCFm7Gclvc/d/6+ZvVrSITN7zN2/EuDYGGB56sN7+cAKShtRd10Hdnd/XtLz0cf/bGZPSpqQRGAfcFmDbbv9sgTRsksQecoRhk3QHLuZrZc0KemLCa9tl7RdktatWxfytChB1mAbIih3eqB0kbE3A/n4ZaN66XtnLyx2om4dwyBYuaOZvUrSw5Le6+7/1Pq6u+9y9yl3n1q7dm2o06IkWUsVu6kL/8CeY/oX934m8QarVKwEsbV2/sXTy6tWsFK3jroLMmM3s1E1gvrH3P2REMdEf2UtVSxaF/6BPcf00ce/2XafIiWIST9oklC3jjrrOrCbmUn6I0lPuvv/6H5IqIKspYp5SxqbaZK0WXpT0eqZrAGbunXUWYhUzGZJ/1rSVjM7Ev356QDHR
R/t2LZx1arP0RFbFWzzlDTG0yTtdFOCmCVg97vkEihbiKqYL0iyjjti8LQ2V1zdbDFXSWOWNMmImQ7ObJVUrJolaQHS6Ijp+y6+SN9ZWqYqBkOBladIDKAP7j+x6qbj8nlPrFTJWheeJU1y949fe2FMRapt6K0OENiHXloATZtZd3PTMS0fLzVm6nf/+LX6L9M3SOquBJIFSBh2BPYhlxZAR8x0zlfnXrq56ZjWpyUpnx6yBBIYNgT2IZcWKM+5a2x0JGizrKxpkj2HF2RKTOlTzQJkQGAfcmnpkYlYrj1krjpLmuTB/ScSg7pJVLMAGRDYh1y7NrZFctUh+rKk/Rbhog0AkAWBfciFrCIJ1cyr3W8RADojsCNYFUmoZl48DAPoDoG9hvrVpjbU80SpRQe6Q2CvmbJ7m7cT8nmi1KIDxQVr24tqSEuH7Nx7XJtn57RhZp82z85pz+GF4Oeu4qPwgGHEjL1m0tIei0vLWlxallTeLJ4UClANBPaaabdsP66bJxS1QwoF6D8Ce03E+5ynrdpsxfJ8oJ4I7DXQesPUpQvBfWJ8TKdfPqsXTy+veh/L84F64uZpDSTdMG0G9YMzW3Xf26/npiYwRJix10Cn+nFuagLDhcBeA1nqx7mpCQwPUjE1QP04gDhm7H0Wavn/JRetuZBnv/yyUd339uvbHqf1vFuuW6sDT50iVQPUAIG9j0Is/289hiR9b/l87vN+9PFvXni9l20IAIRHKqaP2nVDLPMYSe9plXccAKqDwN5HIbohFjlG1uMvLC6V1lcGQHmCBHYz+7CZvWBmT4Q43rBIWyC0xixzME07RrvFR3kWJjXTMgR3YHCEmrF/RNJbAx1raCRVs0iNB0lnDaZZKmL2HF5Y0dlxy3VrE8+bhrQMMFiCBHZ3/7ykb4c41jCZnpzQA7ffoBGzVa9lDabNY0yMj8nUWG36wO03XLjp2bxRurC4JFdjBv7woQXd8aaJFe95503r2j56jr4ywOCgKqbPpicndM/uI4mvZQ2m7RYfpd1cPfDUKR2c2bpq/82zc8EelgGgP3p289TMtpvZvJnNnzp1qlenHQhF8uSdNNMvaS18035osNgJGHw9C+zuvsvdp9x9au3atb067UAIHUzj6Zc0aT80OqV2AFQfqZgKCN2kq1OdeqcfGvSVAQZbkMBuZh+X9GZJV5rZs5Luc/c/CnHsYREymLbLzU/QLgCovSCB3d3vDnEchJHW7bHZnx1AvbHytIa4AQoMN3LsNcSDNYDhRmCvKW6AAsOLVAwA1Awz9goK9fANAMOJwF4xIR6+AWC4Edgrpt2DM9ICOzN8AHEE9orJ++AMZvgAWnHztGLyNgQL8Xg9APVCYK+YvIuLQjxeD0C9kIrpkax58LyLi9LaB9A/HRhezNgLaH3UXKdH2CU9xajdo++mJyd0cGarPnjnJknSPbuPpJ6H9gEAWhHYc8obpKViefCs56F/OoBWpGJyKlKOWCQPnuc8tA8AEEdgV7468CJBukgenJuiAIoa+lRM3tRKkeeTbrku+VGAaduLngcAJAJ77vx3kZuVB55Kfnh32vai5wEAiVRM7pRHkV7nRdIq9FQHUNTQB/Yi+e+8NyuL1ppzUxRAEUOfigmd8kiqcSetAqCXhj6wh6wDT7sRK4lacwA9Y+7e85NOTU35/Px8z89bRJ5SyM2zc4kpl4nxMR2c2Vr2UAHUnJkdcvepTvsNRY69aL/yvC1x026GJgV7AChL7VMxRVoANGUthWzm1dN+97FoHwDohSCB3czeamYnzOykmc2EOGYo3fQrz1KmGP/BkcajcQBAL3Qd2M1sRNLvSnqbpDdIutvM3tDtcUPpZml+ltWfST84ip4PAEIIMWO/UdJJd3/a3V+W9AlJtwU4bhDdLM3PUqaYNWDTCgBAr4QI7BOS/j72+bPRthXMbLuZzZvZ/KlT6UvpQ+umhjxLKWSWgE3NOoBeClEVYwnbVt1HdPddknZJjXLHAOfNp
Nul+Z1Wf+7YtnFF5Ywkja4xverSi7R4eplWAAB6LkRgf1bStbHPr5H0XIDjBlPm0nx6ugComhCB/W8lvd7MNkhakHSXpF8IcNyBQU8XAFXSdWB397Nm9h8k7Zc0IunD7n6865EBAAoJsvLU3T8j6TMhjgUA6E7tV54CwLAhsANAzRDYAaBmCOwAUDMD37a3aEteAKirgZ6xJ7Xkfe/uI9p0/1/SJhfA0BrowJ7WWXFxaTlzz3UAqJuBDuztOitm7bkOAHUz0IG9U2dFeqADGEYDHdiTWvLG0QMdwDAa6KqYZvXL/Y8e14unl1e8Rg90AMNqoGfsUiO4H/5PN+tDd25q+0AMABgWAz1jj6N1LgA0DPyMHQCwEoEdAGqGwA4ANTNwOXZ6wwBAewMV2Ju9YZptBBYWl3TvI8ckieAOAJGBSsUk9YahdQAArDRQgT2tRQCtAwDgFQMV2NNaBNA6AABeMVCBPak3DK0DAGClgbp52rxBSlUMAKTrKrCb2c9J2inpRyXd6O7zIQbVDq0DAKC9blMxT0i6XdLnA4wFABBAVzN2d39SkswszGgAAF0bqJunAIDOOs7YzeyvJP1wwkvvd/c/y3oiM9suabskrVu3LvMAAQD5dAzs7v6WECdy912SdknS1NSUhzgmAGA1UjEAUDPmXnzybGY/K+l/SVoraVHSEXffluF9pyT9XcvmKyV9q/Bgylfl8TG2YhhbMVUem1Tt8XU7tte6+9pOO3UV2EMys3l3n+r3ONJUeXyMrRjGVkyVxyZVe3y9GhupGACoGQI7ANRMlQL7rn4PoIMqj4+xFcPYiqny2KRqj68nY6tMjh0AEEaVZuwAgAB6GtjN7OfM7LiZnTez1DvDZvZWMzthZifNbCa2fYOZfdHMvmZmu83s4oBju8LMHouO/ZiZXZ6wzxYzOxL78z0zm45e+4iZPRN7bVOosWUdX7TfudgY9sa29/vabTKzv4m+/l82sztjrwW/dmnfQ7HXL4muw8nouqyPvXZvtP2EmXUs3y1hbP/RzL4SXae/NrPXxl5L/Pr2cGy/ZGanYmP45dhr74q+B75mZu/qw9g+GBvXV81sMfZa2dftw2b2gpk9kfK6mdn/jMb+ZTN7Y+y18NfN3Xv2R432vhslfU7SVMo+I5K+Lul1ki6WdFTSG6LXPinprujj35P0KwHH9t8kzUQfz0j67Q77XyHp25Iuiz7/iKR3lHjtMo1P0ksp2/t67ST9iKTXRx9fLel5SeNlXLt230Oxff69pN+LPr5L0u7o4zdE+18iaUN0nJEej21L7PvqV5pja/f17eHYfknS/0547xWSno7+vjz6+PJejq1l/1+T9OFeXLfo+D8p6Y2Snkh5/aclfVaSSbpJ0hfLvG49nbG7+5Pu3unJ0zdKOunuT7v7y5I+Iek2MzNJWyU9FO33J5KmAw7vtuiYWY/9DkmfdffTAcfQTt7xXVCFa+fuX3X3r0UfPyfpBTUWtpUh8XuozZgfkvSvout0m6RPuPsZd39G0snoeD0bm7sfiH1fPS7pmoDn72psbWyT9Ji7f9vdX5T0mKS39nFsd0v6eMDzt+Xun1djopfmNkl/6g2PSxo3s6tU0nWrYo59QtLfxz5/Ntr2A5IW3f1sy/ZQfsjdn5ek6O8f7LD/XVr9jfNfo1+zPmhmlwQcW57xXWpm82b2eDNNpIpdOzO7UY1Z19djm0Neu7TvocR9ouvyHTWuU5b3lj22uHerMdNrSvr69npsd0Rfq4fM7Nqc7y17bIpSVxskzcU2l3ndskgbfynXLfij8az7bpBJzd29zfYgY8t5nKsk3SBpf2zzvZL+nxoBa5ek35L0n/swvnXu/pyZvU7SnJkdk/RPCfv189r9H0nvcvfz0eaur13raRK2tf57S/s+6yDz8c3snZKmJP1UbPOqr6+7fz3p/SWN7VFJH3f3M2b2HjV+69ma8b1lj63pLkkPufu52LYyr1sWPf1+Cx7YvftukM9Kujb2+
TWSnlOjv8K4mV0UzbCa24OMzcz+wcyucvfno+DzQptD/bykT7v7cuzYz0cfnjGzP5b0G3nGFmp8UZpD7v60mX1O0qSkh1WBa2dm3y9pn6QPRL+ONo/d9bVrkfY9lLTPs2Z2kaTXqPGrdJb3lj02mdlb1Pih+VPufqa5PeXrGypAdRybu/9j7NM/kPTbsfe+ueW9nws0rkxji7lL0q/GN5R83bJIG38p162KqZi/lfR6a1RxXKzGF2mvN+40HFAjty1J75KUuR98BnujY2Y59qr8XRTQmvnsaTUeGxhSx/GZ2eXNNIaZXSlps6SvVOHaRV/LT6uRZ/xUy2uhr13i91CbMb9D0lx0nfZKussaVTMbJL1e0pe6HE+usZnZpKTfl3Sru78Q25749e3x2K6KfXqrpCejj/dLujka4+WSbtbK32hLH1s0vo1q3IT8m9i2sq9bFnsl/ZuoOuYmSd+JJjTlXLcy7xQn3Bn+WTV+Qp2R9A+S9kfbr5b0mZY7yF9V4yfq+2PbX6fGf7KTkj4l6ZKAY/sBSX8t6WvR31dE26ck/WFsv/WSFiStaXn/nKRjagSlj0p6VeBr13F8kn4iGsPR6O93V+XaSXqnpGVJR2J/NpV17ZK+h9RI79wafXxpdB1ORtfldbH3vj963wlJbyvh/0Gnsf1V9P+jeZ32dvr69nBsD0g6Ho3hgKTrYu/9d9H1PCnp3/Z6bNHnOyXNtryvF9ft42pUei2rEePeLek9kt4TvW6Sfjca+zHFqgLLuG6sPAWAmqliKgYA0AUCOwDUDIEdAGqGwA4ANUNgB4CaIbADQM0Q2AGgZgjsAFAz/x8hhMUohzj9uwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(x[:,0], y);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You want to find **parameters** (weights) `a` such that you minimize the *error* between the points and the line `x@a`. Note that here `a` is unknown. For a regression problem the most common *error function* or *loss function* is the **mean squared error**. \n", "\n", "Now, we're going to pretend we don't know the values of the coefficients (`a`) are 3 & 2. And we have to figure them out.\n", "\n", "DL1v3 Lesson 2 @ [[1:26:39](https://youtu.be/Egp4Zajhzog?t=5199)]\n", "> if we can find a way to find those 2 parameters to fit that line to those 100 points, we can also fit arbitrary functions that convert from pixel values to probabilities.\n", ">\n", "> The techniques we're going to learn to find these 2 numbers, work equally well for the 50 million numbers in ResNet34.\n", "\n", "*parameters* in machine learning are *coefficients* in statistics.\n", "\n", "A **regression** problem is where the dependent variable is continuous. In mathematics the actual is $y$ and the prediction: $\\hat{y}$" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "def mse(y_hat, y): return ((y_hat-y)**2).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above function in mathematical form is: $\\frac{\\sum_{i=1}^n{(\\hat{y_i}-y_i)}^2}{n}$\n", "\n", "Codal and Mathematical forms are both just notations of the same thing; but the code notation is executable – allowing you to experiment – while the math note. is abstract." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Suppose we believe `a = (-1.0,1.0)` then we can compute `y_hat` which is our *prediction* and then compute our error." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "a = tensor(-1.,1)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(6.9835)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_hat = x@a\n", "mse(y_hat, y)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3X+UXOV5H/Dvs6uRWGGilcI2hpXWEj1Uim0UyWwJp5vjWMIH2QaLjXAEtG7s2kV14jQBuwqrYxcktylL3WLqJD2JbFPbxTHIIBSBTIGwcn3gRDi71i9kwAiEQStqyRGrWGiRRqunf9x7d+/cvb/ve2fu3Pl+ztHZ3Zk7d15ml2feee7zPq+oKoiIqDzaGj0AIiIyi4GdiKhkGNiJiEqGgZ2IqGQY2ImISoaBnYioZBjYiYhKhoGdiKhkGNiJiEpmRiOe9MILL9SFCxc24qmJiJrWyMjIL1S1K+q4hgT2hQsXYnh4uBFPTUTUtETkZ3GOYyqGiKhkGNiJiEqGgZ2IqGQY2ImISoaBnYioZBjYiYhKpiHljkRErWDb7lF8+fEXcWRsHBd3dmD9qsXoX96d+/MysBMR5WDb7lFs2Lof49UJAMDo2Dg2bN0PALkHd6ZiiIhy8OXHX5wM6o7x6gS+/PiLuT83Z+xERDEkTascGRtPdLtJnLETEUVw0iqjY+NQTKVVtu0eDXzMxZ0diW43iYGdiEpj2+5R9A0OYdHADvQNDoUG3iTSpFXWr1qMjkp7zW0dlXasX7XYyJjCMBVDRE3JmxpZsaQLD42M5nKxMk1axXlOVsUQEcXgV3HynV2vQT3HObPqrMH04s4OjPoE8ai0Sv/y7roEci8jqRgReVVE9ovIHhFhP14iypVfasQb1B0mLlY2Mq2ShskZ+wpV/YXB8xER+UoSrE1crGxkWiUNpmKIqOkEpUa8Km1ibFYdlVZp1CpTP6aqYhTAEyIyIiLrDJ2TiMiXX2qk0iZob5Oa26rnFBu3HzBWHRMkTTlknkwF9j5VfR+ADwP4rIi833uAiKwTkWERGT527JihpyWiVtS/vBt3rrkM3Z0dEADdnR14x3kzMHFueqZ9bLyae5Bt5CpTP0ZSMap6xP56VEQeBnAFgB96jtkMYDMA9Pb2Bl3nICKKxZsaWTSwI/BYU9UxQRq5ytRP5hm7iJwvIhc43wO4GsBzWc9LRJRE1EXSpEE2yWKnRq4y9WMiFfNrAJ4Wkb0AfgRgh6r+HwPnJaIWlnQVqV/e3S1JkE2aMy9aOWTmVIyqvgLgNwyMhYhaQJzqkTQtb53bNz1yAG+eqtbclzTIhuXM/Z6/aOWQLHckorqJG7CTBlaHk3fPWnqYtoVAUeraGdiJqG6CAvamRw7UBMWsFyOzBtm0LQSKgt0diahuggLzm6eqNfnrRl+MLFrOPCkGdiKqm7DA7K75bnRg9auTv3PNZYVJtURhKoaI6mb9qsW45YE9vve5Z/NFuBhZpJx5UgzsRFQ3/cu7sXH7AYyNV
6fd553NN3NgbTSmYoiorjaufk9T56+bAWfsRJQ7b/nh9Zd3Y+cLxwpR811GDOxElEjSGnG/2vWHRkab6mJks2EqhohiS9OetmidD1sBZ+xEFFuaFaFRi42KtEFFWTCwE1FsaVaEhq3iDGsxABSn90qzYWAnotjSLLVfv2pxTfAGpqpggj4BbNx+AKfPnovVBIwz/umYYyei2NKsCA1bxRk00x8br/oG/M9v2VuTzy/alnRFwRk7EcWWZEVonJl03E2pHROqNTP3tF0gy46BnYgSibMiNG573qA0zXmVtmk91R3uwF20LemKgqkYIjIuboljUJrmjo9OX53q5gTuRneBLCrO2InIuCQz6bBPAJ/fshcTqtNudwJ32IXZVsYZOxEZZ2Im3b+8G/997W+EXqxt9va6eeGMnYiMMzWTjnOxll0gp2NgJ6JUwqpeTPZTZ+BOjoGdiBKLU/WSJSBz0VE2zLETUWJ5NvbioqPsGNiJKLE868fZDTI7Y6kYEWkHMAxgVFWvNXVeIjLDZHojTc+YuLjoKDuTOfY/BvA8gF8xeE4iSskdyOd0VPDWmbOoTlg14WFNteLIs348zzeNVmEkFSMi8wFcA+DrJs5HRNl489Rj49XJoO7Ikt5IUj++bfco+gaHsGhgB/oGhyJz5WkajVEtUzP2ewD8CYALgg4QkXUA1gFAT0+PoaclIj9+eWo/WdIbflUv3nTPiiVdeGhkNFb7Xfd5nf8GVsWkkzmwi8i1AI6q6oiIfCDoOFXdDGAzAPT29k5fI0xExsQN2CbTG34lkN/Z9Rq8/7PH6b7I2vVsTKRi+gCsFpFXAdwPYKWI3GfgvESUUpyAbTq94fcpIWgGxwuh+coc2FV1g6rOV9WFAG4EMKSqH888MiJKzS9PXWkTzJ1dAQC0i0zOnKNy3nFz5EmCNS+E5osrT4lKKChPDSBWn3RH3L7qQPxNMyrtwguhORP1aYmZt97eXh0eHq778xK1ur7BId/g293ZgWcGVmY63vsmEKSzo4I9d1ydcOQEACIyoqq9Ucdxxk7UQpIu/knaVx2Y+pQQNGU8Me6/MxKZw8BOVEJBq0yTLv5Jery7miVots/8ev7YK4aoZMKaaPldVBUAK5Z0+Z4ry2IhLjRqHM7YiRogbt+WNP1dwppoPTOwEsM/O15TX64AHhoZRe+75k07d5bFQlxo1Di8eEpUZ34XGTsq7dOW5PsdJ7ACcXdIkFw0sMM3vy0ADg1ek/gCKhVH3IunTMUQ1VnctrRhC37CepRH7TfK7onlx8BOVGdxA2tUoA1q4hWV28660XTSpl5Uf8yxE9VZUKVJ5+wK+gaHJvPRczoqGIsoDYxTdujNbWdpuZtkwRI1DmfsRHXmu9y/XXDy7bM1lSxvnTmLSpuEnitN6WCSlrte3N2oOXDGTlRnfjPqt06fnTY7r04o5s6uYPbMGRgdG5+8cOoImmXnudE08/PNgTN2ogboX96NZwZW4tDgNXhmYGXgasyxU1U8M7AS99ywDJ12Ay/AWpYfNMvOc1adNT9P9cHATlQAYQHTmYG/eWoq+J8+ey7wXHnOqrnoqDkwsBMVQFjATDoDjzurTlPdkiU/T/XDHDtRAYRVstz6wB7fxwTNwONUvWSpbuHuRsXHwE5UEEEBM00jLsB6kxgdG6/ZVMO5P+xTAIN282Mqhqjgkua1nf4yTiXNhN02ZHRsHLc8sAfLNj0RuCEGq1vKgTN2ooJL0kzLm2Lx6xkzNl6dVjrpYHVLOTCwEzWBuHltvxSLHwVi18VT82FgJyqwpG17k6RSnC6RbKlbPgzsRAWVpnIl7obSANv0lhkvnhIVVJoVpH4XWv0w7VJuDOxEKF4r2m27R1NVrvQv78b1l3ejXazmYe0i+PiVPbjnhmVcVNRCmIqhlle0VrTOeIKEVa5s2z2Kh0ZGJ0scJ1Qnt71j2qV1ZJ6xi8h5IvIjEdkrIgdEZJOJgRHVS9Fa0YZVtkSlUIr230KNYWLGf
hrASlU9KSIVAE+LyGOqusvAuYly417I46dRi3XCnjcqhcK2ugQYmLGr5aT9Y8X+V/8dsokS2LZ7FOsf3BtaQdKoxTpBz9vd2RGZGmJbXQIMXTwVkXYR2QPgKIAnVfVZE+clysumRw6gOhE8/2hk1UiW1rhsq0uAoYunqjoBYJmIdAJ4WETeq6rPuY8RkXUA1gFAT0+PiaclSs3d29yrO+VinaSLiYIkaSFg8rFUHqJqNmsiIncAeEtV/1vQMb29vTo8PGz0eYmSWDiwI/C+VwevSXw+b2UNYM2Ug3Lipt4EqLWIyIiq9kYdZ6IqpsueqUNEOgB8EMALWc9LlKfOjkqi26MkqUZx3gTcG1dv2Lq/4bXzVB4mUjEXAfiWiLTDeqPYoqqPGjgvUW42rn4P1n9vL6rnpj6xVtoEG1e/J9Zs2ntMnMqasCoc9kInkzIHdlXdB2C5gbEQ1U1QLhpA5GIlvwVNQZxqFL9UjRdLEskUrjylluXXCrdvcChyZ6G4rXHd1ShxHsOSRDKFgZ3IJWjWPDo2jkUDOyK7J7aL4JzqtBRO1GycJYlkEgM7kUtY4HYudAbtPgQA51RxyKeqJuy8acsriYKwuyORS5y2t2EFwkHplKCFQ/fcsAzPDKxkUCejOGOnluetcLn+8m7sfOEYjtjliEGSbC3HhUNUT8YXKMXBBUpUFFELi/oGh3xTKALgX13ZM/kGwEBN9RB3gRJn7NTSwhYW9S/vxvpVi3HrA3umzdwVwM4XjrHHORUSc+zU0qLa3PYv7w5Mx7DunIqKM3ZqaUHVKu6LoN0hx7DnCxURZ+zU0uK0uQ06ZsWSLvZ8oUJiYKeW1r+8G3euuSx0o+egY3a+cIzb0FEhMRVDLc9biugEZm9w96ZYbn1gj+/5mHunRuOMnVpe2ja63IaOioqBnVpekl7qbtyGjoqKqRhqeVElj0G4mpSKioGdWl6ckscgfrl3okZjYKfC8KsJB/KfEa9ftdi3rQBTKtSsGNipEPx2JVr/vb2AANUJnbzNu5uRCUypUNkwsFMh+F3AdO9H6jCxN2jQalEGcioLBnbKVdwl90lqv7PUift9MsjjUwBRI7HckXKTpD48Se13ljrxtKWNRM2EgZ1ykySI+tWEV9oElXapuS3rRc20pY1EzYSpGMpNkiAadAHT77YsKZMspY1EzYKBnXKTNIgGXcA0mftmaSO1gsypGBFZICI7ReR5ETkgIn9sYmDU/Iq45D5ON0eiZmdixn4WwOdV9ccicgGAERF5UlV/YuDc1MSS1IfXc8MKljZS2WUO7Kr6BoA37O9/KSLPA+gGwMBeAnEDblhteFQQzbsEkbscUasxmmMXkYUAlgN41ue+dQDWAUBPT0+2J9q3BXjqS8CJw8Cc+cClVwMvPTH181W3A0vXZnsOih1wswbmqA2l04zbCeSdsys4+fbZycVOrFunVmCs3FFE3gHgIQC3qOo/eu9X1c2q2quqvV1dXemfaN8W4JE/Ak68DkCtr8PfqP15683AXYusYym1uOWKaWvDv7htP/7phu/7XmAF0pUgemvn3zxVnbaClXXrVHZGAruIVGAF9e+o6lYT5wz01JeAaoz/4cePW28ATnDftwX4ynuBjZ3WVwb9SHHLFdPUhn9x237ct+s1TOj0tgGONCWIfm8yScdG1Owyp2JERAB8A8Dzqnp39iFFOHE4/rHVceuNALCCvPOGcOJ162eAKZsQccsVk5Q1OmmSoFm6I231TNyAzbp1KjMTM/Y+AP8awEoR2WP/+4iB8/qbMz/Z8ScO+8/y3UHfwVl9jRVLuiCe2/wCbtyyRneaJEyWEsQ4AbvRJZdEeTNRFfM0MO3///xcdXvt7DvKnPnBs3z37U7unrN6AFYQfmhkFO5EiQC4/vLpVS5xyxrjpEnaRfDMwMrJMSStZvFbgFRpF5w/cwZOjFdZFUMtoflWnjpB1lsVc+BhK6/uVumw3gie+pJ9cdXDPfsPm9UvXTu9Eqdkl
TfeIHrqzNlpQVgB7HzhmO/j45Q1xkmT3PSbCybHk6bShr3ViZoxsANWQPUG1WvvDg++3lm+E/QdYbP6oNn8a7umyiw75lr3jb/ZdIHfL4gGyXLRMSgXD1gz9Zt+cwH+c/9lALKVQHIBErW65gzsQfwCvnM7ED7jnjM/eFYfNJsfvhdwkhXuTwtOyeXWm62fO+YBH76rsIE+biUJkO2iY1CfFr98uskSSKJWU67AHiYo6Dv8cvfOrH7ruoAHBZfq1Rg/DvzNZ4Hd9wGvPg3oBCDtwOWftD5pNFjcYJn1omPcNMm23aMQ+L+6rGYhitY6gT1K2Kw+KEefxMQZ4ND/nfpZJ6yFVQDQc2VD8/dBKZLOjgrOnzXDaK46Tprky4+/6BvUBWA1C1EMoiELRPLS29urw8PDdX/e1Lw5dgAInFMmJUDlvOmpnsr5wIxZdcnZe3PsQHCKJO75sly8XDSwI/CVfXXwmsTjISoLERlR1d6o47iDUhxL1wIf/SowZwEAsb72fspK1WSm/qWb1bfsvH3+bRJMtrJNsh1ekKB0SzfTMESxcMaehbsKZ7Iq5nj4Y7KqdFhvMgW9ENs3OOSb1unu7JisT49i+hMEUVnEnbEzx55F0AXZfVuAx26bCvId84B3XlabY3fMPB8481b853TX1rs9+rmaKp1q+2z8qazDt05eUddabhN7irIWnSgbztjr6dHPASPfrK2K6bky2UpaAIAAG8dqz+tciHVx/2rfxAV4+fL/iH+++t+lHX0sJmbsROSPM/Yiuvbu4PJG9ww/irdfzsg3fQ8TV6OHefgl5v74TzDx49sgqjgqXXj9feuNB3ruKUrUeAzsReCkdLw5+7OnrYuobt4Vs4D1CSAGAdAOBQR4J45h3sgGnN7/XzCresLYylmmUYgaj6mYoovTo2bTvNjBPZG2mcC5M9b3BV89S9QKmIopi6gVs4CVq/fJsWfmBHXAShNtvdnqj1OA1bJEFIx17GVw7d14+V03YkKtC6bOv1wM39vyfeqJio4z9pL4vZ/fgNHTqyd/Xt32NO6Y8W3MaztpuFm++pdb+tb0N1+nS6Iy4Iy9JLx14tvP/RYuP7MZl7z918CarwFzFkABnEUbzikwhgswIZV0T+ZtcezdYHz8uGfV7Dpg4xzuSkVUJ5yxl0TovqNLrwGWroVg6hfeCdTOsmfOjr9QyltuGbnBuJ0XOvG61eXysds4myfKEQN7SaSqH/demPWmU06frL2ACviXWybZYHzizFS9fk3fegHaZwITp+3BswqHKC0G9gLI2g3RMWtG22Rgnzu7gjs++p7Q83ifd8WSX8fO01/FkbfHcfF5HVj/0cXob38mutwyaJOSRHQqqANTVTjOZiUAgz1RTAzsDZZ2b8+wcwDA29VziZ/3vl2vTd4/OY41fei/9bnwASTdYDyt8ePAtj+wvmdwJwrEi6cNFra3Z57niLMdXuxxeNsad8yz/gHWzyadq1qfIIgoEGfsDWaiG2Kac8Q9/+jYOPoGh6LTQ2GdLmvy9r+0gnMWUTn9OKt1iUrMyIxdRO4VkaMiEvGZnbyCNpVoE4m9OUXQOcL2B02yd2iazTImLV0L3Pqc1Y3ytkNA//+0Z/YZeKty3Lyll04lzl2LgI2dLLmklmCkV4yIvB/ASQDfVtX3Rh3PXjFT/PLjjribS0RtTOF3cRZA4PMGMd56193G2FsVE6StYr05BM3Av/Le6Au5bRVg1gW13TQLtLk4UZC4vWKMNQETkYUAHmVgT27b7lF8fsteTPj8LuIG06DKmrCgD8BTFdOFnS8c862HB6xs+aF67Tnq3awEiFcVs7ET2feiFaAyG6ieYiqHCoWBvckEbeCcNZim2fiiqTfLiDNjT6qtYn2acFoos+ySGqRwm1mLyDoRGRaR4WPHjtXraZtGmjx5lG27RwNn32EXT9evWoyOSnvNbU2zWcZVtxvaZNzlXLW2L/74ceDhzzBvT4VVt8CuqptVtVdVe7u6u
ur1tE3DdDB1UjBBwt4w+pd34841l6G7swMCa6beNBtJ+5VetqXsiRNGJzz9cG4G/lMXgz0VAssdC8L0zkNhdepx3jD6l3c3RyD3E9Uq4cxJq7WBad52CY/80dR4iOrIVFXMdwF8AMCFAH4O4A5VDdz5gTn2/AXl7AHgnhuWNW/QNmEy0BvOxfuZs8Aq96x5XtbXUzp13UFJVW8ycR4yJ6jbY3dnR2sHdcB/Rv/ILa48uqsqpmMu8PaJ9FsPOoupnPp6p+2CuwEaSy3JMKZiSipVt8dWFbX9oLf0snL+9E3GgziLqcJaG+uEtbXhPxwEjr/CzUooM25mXWKmukaSj0c/F73PbKXDupC7dK2h+noXlly2JG5mTc19AbTorr0b6LmydiY/83ygfZb/LNtIa2MXdrqkEJyxE9WDN8duivvirMPdqoH5+1LhjL3JMY1SMs6s2nQ1jrfTpTdF5OTvh7/B9E0L4Yy9gKKaelEJeGfVC38LOPyj5DN674x907yYFTyCyZw/A37T4Iy9iYVtnBEU2DnDbzLX3j09PZJ0IVVbZfr+s7HLMl0TOmcbwtd2MWVTEgzsBZR04wwT2+tRAfiVXda0NnYJmmVLe/qa++F7rQvCnLk3PQb2AgpaXBTU3yXNDJ+ahN/MPszln4wuwwyk1icGb2D3rpi99GrgpSe4grbAuOdpASVtCGZiez0qiWvvBi5ckv7x3ouxfjtSDX+j9uetN1vNz9j0rDA4Y6+juHnwpA3Bks7wqeT+8Fm7OuZeJF4U5d12MGzFrNv4cesN4LVdnM0XAKtiMkhywTJtpUuc52AVDUXy25HKy71S1pF4xayr2iYIq3BSq/sOSkmUIbAnDaZpdiVK8hysiqHE4nSbzGNHKi8G+tgY2HOWNFCn2fquqbeoo3LIa8VskDkLmL4JwTr2FJLMepNesEyTB+dFUWq4mhWzdn39+BiAcz4Hx0jDRHE2KHltF7DvfuCMp4smA38srIqxOWmP0bFxKKZqwbftHvU9PukepSuW+G8HGHR7mucgysXStdbq1o1jwG2HgDV/ZW89CKtuHrB+7v2Umf1mq+NW5Y03qANTgZ8VOKE4Y7clrQVP2u985wv+G3gH3Z7mOYjqIqx/fc+V+e9OVR23ngOYei5nYRZn9AAY2CclTXskLUlMk1YxvQ8qUe7cQT9OJU5azszdyf07q23dO1M5WjDYM7Db0uTAk/Q7T1trzp7q1LSCZvaRAT9Grl7a41/QbcGNxZljtyVd7Rlm2+5R9A0OYdHADvQNDmHb7lGj5ydqakvXWrn6jSesf2u+ZufsZSpX39Ye/PhKR/J+OO70TQtguaOLiVrwsNpzgGkVolj2bQEevSW4KiZVHl+sC8BNjHXshsUN+qw9J6qDNPX1frtNRT1HwZqfsY7dI8tsPElb3KCLoX7BnohS8u5IFdWuuNIxvXd9GO8bh9P8zFHwvH1L5NiT1qh7hZVCup+jb3Ao8JKP2McQkSGT9fUngDuOW197Pz1VW++Ys2B6D5wocZqfefP2+7ZYLRg2dlpfG1hrb2TGLiIfAvA/ALQD+LqqDpo4rylZ+5VHlSr65dW91B4Hc+pEOUravz6It31x1HF+M/wGdrvMPGMXkXYAfwHgwwDeDeAmEXl31vOalHVpftQKUL83jizPR0QN5m1fHHWc3wy/Om61Tnb3rq/TqlkTqZgrABxU1VdU9QyA+wFcZ+C8xmRdmh9Vqpj1DYKICuaq26PbI7jz9oEzfE9ytk5llyYCezcAd93RYfu2GiKyTkSGRWT42LHgZfR5yFpD3r+8G3euuQzdnR0QWBUu7ta5cQI2a9aJmsjStVZevqa+/tO1P7vz9nFn+ED8NE8GJnLs4nPbtGuIqroZwGbAKnc08LyxmViaH7YC1K+nS6VN8I7zZmDsVJU160TNKKwnjtdVt/uUXwasoE3yJpCSicB+GMAC18/zARwxcF6j8lyaz54uRC3O297YqXvf+9e1wT5p2WVKm
RcoicgMAD8FcBWAUQB/D+BfquqBoMc04wIlIqLE4uxSlUDdFiip6lkR+UMAj8Mqd7w3LKgTEbWMJOkcg4zUsavq9wF838S5iIgom5ZYeUpE1EoY2ImISoaBnYioZBjYiYhKphRte01skEFEVBZNP2P3a8l7ywN7sGzTE2yTS0QtqekDe1BnxbHxaqKe60REZdH0gT2ss6J3MwwiolbQ9IE9qrMie6ATUatp+sDu15LXjT3QiajVNH1VjFP9sumRA3jzVLXmPvZAJ6JW1PQzdsAK7rtvvxr33LAscDMMIqJW0fQzdrc8e64TETWLUszYiYhoCgM7EVHJMLATEZVMU+bY2RuGiChY0wV2pzeM00ZgdGwcG7buBwAGdyIiNGEqxq83DFsHEBFNabrAHtQigK0DiIgsTRfYg1oEsHUAEZGl6QK7X28Ytg4gIprSdBdPnQukrIohIvKXKbCLyO8C2Ajg1wFcoarDJgYVha0DiIiCZU3FPAdgDYAfGhgLEREZkGnGrqrPA4CImBkNERFl1nQXT4mIKFzkjF1E/hbAO33u+oKq/k3cJxKRdQDWAUBPT0/sARIRUTKRgV1VP2jiiVR1M4DNANDb26smzklERNMxFUNEVDKimn7yLCK/A+DPAHQBGAOwR1VXxXjcMQA/C7j7QgC/SD2ofHFsyRV1XADHllZRx1bUcQHmxvYuVe2KOihTYM+DiAyram+jx+GHY0uuqOMCOLa0ijq2oo4LqP/YmIohIioZBnYiopIpYmDf3OgBhODYkivquACOLa2ijq2o4wLqPLbC5diJiCibIs7YiYgog4YEdhH5XRE5ICLnRCTwSrGIfEhEXhSRgyIy4Lp9kYg8KyIvicgDIjLT4NjmiciT9rmfFJG5PsesEJE9rn9vi0i/fd83ReSQ675l9RqXfdyE67m3u25v9Gu2TET+zv697xORG1z3GX/Ngv52XPfPsl+Hg/brstB13wb79hdFJLJ81/C4PiciP7Ffo6dE5F2u+3x/t3Uc2ydF5JhrDP/Wdd8n7N//SyLyiQaM7Suucf1URMZc9+X2uonIvSJyVESeC7hfROSr9rj3icj7XPfl95qpat3/wWrzuxjADwD0BhzTDuBlAJcAmAlgL4B32/dtAXCj/f1fAvh9g2P7rwAG7O8HANwVcfw8AMcBzLZ//iaAj+XwmsUaF4CTAbc39DUD8M8AXGp/fzGANwB05vGahf3tuI75AwB/aX9/I4AH7O/fbR8/C8Ai+zztdRzXCtff0u874wr73dZxbJ8E8Oc+j50H4BX761z7+7n1HJvn+H8P4N46vW7vB/A+AM8F3P8RAI8BEABXAni2Hq9ZQ2bsqvq8qkbtPn0FgIOq+oqqngFwP4DrREQArATwoH3ctwD0GxzedfY54577YwAeU9VTBsfgJ+m4JhXhNVPVn6rqS/b3RwAchbWwLQ++fzshY34QwFX263QdgPtV9bSqHgJw0D5fXcalqjtdf0u7AMw39NyZxxZiFYAnVfW4qr4J4EkAH2rg2G4C8F2Dzx9IVX8Ia2IX5DoA31bLLgCdInIRcn7Nipxj7wbwuuvnw/ZtvwpgTFXPem4+8ptdAAADf0lEQVQ35ddU9Q0AsL/+k4jjb8T0P6I/tT92fUVEZtV5XOeJyLCI7HLSQyjYayYiV8Caeb3sutnkaxb0t+N7jP26nID1OsV5bJ7jcvs0rNmew+93a0rcsV1v/54eFJEFCR+b99hgp64WARhy3Zzn6xYlaOy5vma5bY0n2btC+jV515DbjYwt4XkuAnAZgMddN28A8P9gBa7NAG4D8KU6jqtHVY+IyCUAhkRkP4B/9Dmuka/Z/wbwCVU9Z9+c+jULehqf27z/vbn9fYWIfW4R+TiAXgC/7bp52u9WVV/2e3xOY3sEwHdV9bSIfAbWJ56VMR+b99gcNwJ4UFUnXLfl+bpFacTfWX6BXbN3hTwMYIHr5/kAjsDqt9ApIjPsmZZzu5GxicjPReQiVX3DD
kJHQ061FsDDqlp1nfsN+9vTIvK/APyHeo7LTnNAVV8RkR8AWA7gIRTgNRORXwGwA8AX7Y+lzrlTv2YBgv52/I45LCIzAMyB9ZE6zmPzHBdE5IOw3jB/W1VPO7cH/G5NBajIsanqP7h+/BqAu1yP/YDnsT8wNK5YY3O5EcBn3Tfk/LpFCRp7rq9ZkVMxfw/gUrGqOWbC+oVtV+vKw05YuW0A+ASA2H3hY9hunzPOuafl8uzA5uS1+2FtH1iXcYnIXCeNISIXAugD8JMivGb27/BhWPnG73nuM/2a+f7thIz5YwCG7NdpO4AbxaqaWQTgUgA/yjie2OMSkeUA/grAalU96rrd93draFxxx3aR68fVAJ63v38cwNX2GOcCuBq1n2JzH5s9vsWwLkT+neu2vF+3KNsB/J5dHXMlgBP2RCbf1yyvq8Vh/wD8Dqx3rNMAfg7gcfv2iwF833XcRwD8FNa76xdct18C63+2gwC+B2CWwbH9KoCnALxkf51n394L4Ouu4xYCGAXQ5nn8EID9sILTfQDeUa9xAfgX9nPvtb9+uiivGYCPA6gC2OP6tyyv18zvbwdWeme1/f159utw0H5dLnE99gv2414E8GHDf/tR4/pb+/8J5zXaHvW7rePY7gRwwB7DTgBLXI/9lP1aHgTwb+o9NvvnjQAGPY/L9XWDNbF7w/7bPgzrushnAHzGvl8A/IU97v1wVQHm+Zpx5SkRUckUORVDREQpMLATEZUMAzsRUckwsBMRlQwDOxFRyTCwExGVDAM7EVHJMLATEZXM/wcJVynJJEYYNwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(x[:,0],y)\n", "plt.scatter(x[:,0],y_hat);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So far we have specified the *model* (linear regression) and the *evaluation criteria* (or *loss function*). Now we need to handle *optimization*; that is, how do we find the best values for `a`? How do we find the best *fitting* linear regression." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Gradient Descent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We would like to find the values of `a` that minimize `mse_loss`.\n", "\n", "**Gradient descent** is an algorithm that minimizes functions. Given a function defined by a set of parameters, gradient descent starts with an initial set of parameter values and iteratively moves toward a set of parameter values that minimize the function. This iterative minimization is achieved by taking steps in the negative direction of the function gradient.\n", "\n", "Here is gradient descent implemented in [PyTorch](http://pytorch.org/)." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Parameter containing:\n", "tensor([-1., 1.], requires_grad=True)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = nn.Parameter(a); a" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "def update():\n", "    y_hat = x@a  # prediction\n", "    loss = mse(y_hat, y)  # MSE loss\n", "    if t % 10 == 0: print(loss)  # print the loss every 10 steps\n", "    loss.backward()  # compute the gradient of the loss w.r.t. a into a.grad\n", "    with torch.no_grad():  # don't track gradients during the update step\n", "        a.sub_(lr * a.grad)  # in place: subtract learning rate x gradient from coeffs a\n", "        a.grad.zero_()  # zero out the gradients, since backward() accumulates them" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch keeps track of how our loss, `mse`, was calculated, and lets us compute its derivative with respect to `a`.\n", "> So if you do a mathematical operation on a tensor in pytorch, you can call `backward` to calculate the derivative." 
] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(6.9835, grad_fn=)\n", "tensor(1.4813, grad_fn=)\n", "tensor(0.5689, grad_fn=)\n", "tensor(0.2572, grad_fn=)\n", "tensor(0.1468, grad_fn=)\n", "tensor(0.1077, grad_fn=)\n", "tensor(0.0939, grad_fn=)\n", "tensor(0.0889, grad_fn=)\n", "tensor(0.0872, grad_fn=)\n", "tensor(0.0866, grad_fn=)\n" ] } ], "source": [ "lr = 1e-1\n", "for t in range(100): update()" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3X98XHWd7/HXJ8mkHURTSrtC0xbQhxdXgaVL6LqWe3cBBV2xxKIVd110FSvioshjgXL1ltKHbkvZFexeeUhFVl29SpYftYDc3kpRt1zRJraWIrIiLNumuFAh3dsmbabJ5/5xZtLJ5JyZM5kzk8nk/Xw8apIzZ875chI/+ebz/Xy/X3N3RESkcTRNdANERCRZCuwiIg1GgV1EpMEosIuINBgFdhGRBqPALiLSYBTYRUQajAK7iEiDUWAXEWkwLRNx01mzZvnJJ588EbcWEZm0enp69rn77FLnTUhgP/nkk+nu7p6IW4uITFpm9nyc85SKERFpMArsIiINRoFdRKTBKLCLiDQYBXYRkQajwC4i0mAU2EVEGowCu4hIg1FgFxGplp1dcOtpsHJG8HFnV01uOyEzT0VEGt7OLnjgU5AZCL7evzv4GuCMpVW9tXrsIiLV8Miqo0E9JzMQHK8y9dhFRJKwsysI2vv3QNvcoIceZv+eqjdFgV1EpFJhaRcM8LHnts2tenOUihERqVRY2gUfE9aPNE+H81dUvTkK7CIilYpKrzjsGZ7FsBt7hmexPHM5G4YWVb05SsWIiFQqIqfe67M4Z3DdqGM/2fQ0nQvaq9oc9dhFROIoVpN+/gpIpUed3u+trD0ytqxxb19hyiZ56rGLiJSys4sj37uKlqFDwdf7dwdfQ1CTnqtLz6uKWXvwEjYeXjjmUnNmpMccS5oCu4g0jA3be7ll09Ps7Rtgzow011546vjTHnnli8NmtPjwqJdbhg7R//AKjskF9fwAD5y5vZf0fU8wkBkaOZZONXPthaeOrz1lUGAXkUmpMIif+8bZ3NvTOxJIe/sGuOG+JwDKD+4F5YtNHlK2CEwf+G3kJXL3TOwXTRnMIxpcTR0dHa7NrEVkvDZs7+WGgt5wRNU47TPSPLb8vPJucOtp0ROM8uwZnsXcVb8p79oVMLMed+8odV4ig6dm9m9m9oSZ7TAzRWwRqapbNj09KqhDeFCHcQ5Wxpgd2u+t3Nn6wfKvXQNJpmLOdfd9CV5PRCRUOcF6XIOVEeWLR7yJJpy9fjy3cSnnvGtZ+deuAeXYRWTSmTMjTW+M4J5qsvENVp6/YvQSAQSzRj9vV/CNAwtD8+WJDtxWKKnA7sD/MTMH7nD39QldV0RkjGsvPHVMjj3VZAwDQ8NHkzKZYWflxieBMg
dQQ8oXW85fwcozlrIy5PTCnH9FA7cJSGTw1MzmuPteM/s9YDNwlbv/uOCcZcAygPnz55/1/PPPV3xfEZm6CnvI/YNHeKU/E3puOtXM6iWnVy3ILlqzJfQviHEN3BYRd/A0kR67u+/NfnzRzO4HFgI/LjhnPbAegqqYJO4rIlNX54L2UYH6lOUPRZ47kBnilipO5Y/K+ddilmmYiqtizOxVZvbq3OfABcCuSq8rIlKOUoOk5QbZDdt7WbRmC6csf4hFa7awYXtv2feuxSzTMEmUO74W2GpmvwB+Bjzk7v87geuKyBRWTmCFIO+eTjVHvl5OkM3lzHv7BnCO5syj2hB271rNMg1TcSrG3Z8F/iCBtojIFBCnemQ8g5GdC9pp3/0g7T9fy4m+jyGaaGKYvT4rKE288MrYbQyrky+WzpnIWaZhVO4oIjUTN2CXG1gB2NnF2b/4HDAIBi0Ea7vMtX2sab6TluY/AOJtIj2enHlhzn8iadleEamZqIB90wNPjjo2rsHIh6+HocHQl1qGDpW1iXS95czLpcAuIjUTFZhf6c+Myl+XFVhz66QPvFz85mVsIl1vOfNyKbCLSM0U6/Hesunpkc9jBdadXXDzKXDfx2It2FXOJtKdC9pZveR02mekMYJ69GrWwSdNOXYRqZlrLzyVq+/eEfpafm++5GBkwbK6JaXSZW8iXU8583IpsItIzXQuaGflxifpGxg7Q7SwN180sD6yKn5QT8+Ed948ahOMRqfALiJVl1/i2JZOkWo2MkNHJ6AXzV/n7WRE29yg5x0nX942Lzh3CgX0HAV2ESlLuasYFpY49g1kSDUZxx2Toq8/E32NB6+Bnn+E/C3p9u8OUjDp46IHS1NpePe6KRnQcxTYRSS28UwcCitxzAw7x7S2sH3FBeE3+sZieO5H4a9lBqAlHQTwwnTMFEy7hFFgF5HYxjNxqFRNeuFfALe96decHRXUR276CixZPzZFM8UDeo4Cu4jENp6JQ1GbYsyZkQ79C2BOz9pgA9Ni2uYGQVyBPJQCu4jEVixIRwnbFCM3WHrLpqd5+9CPuK61izm2j70+izmU2GGzoHSxnnYuqheaoCQisY1nRmaxyT4d/7mZNak7mdu0jyaDuU37IjeldocDTGPb6TeN9NTLXYVxqlCPXURiK2cVw6I96Z1dcOsqbmvdPSbr0mQw7MHHHHf4l+E3c1nms6S3NbN6Xi+dC9rHt1jYFKDALiJliTMjs2j1TO/fQ/ddgBdNpe8ZnsUc+x17/XjWHlnKxuFzgNGBu952LqoXCuwikrjCnvTipq1cZ13M+V6J/HnWizabcw5/KfL1XOAeT85/KlCOXUQSlwu8i5u2smvaX/Gl1O1BHj3Om1Npdv/htbF2Q5rsqzBWiwK7iCRuzow0N7XcxW2p2znWDmOlyhdz2ubBu9dx9uKPs3rJ6cxIp8ackh+4J/sqjNVi7lFj0NXT0dHh3d3dNb+viNTGto13cFbPdaMGQIuzYMJRSF26yhmPMrMed+8odZ5y7CIyLqEBt/kxeGQVZ+/fXXqS0QiDjo9ETjaazMvnThQFdhEpW1jVy9b7b+ei1J3BNnRxRazAqF56ZRTYRaRsYfXjV/Pd+EG946Nw0RdDXxrPQmMymgK7iJRtb99AUMLYkrcUgMUoZWx9FVx0W9E1XjTpqHKJBXYzawa6gV53vyip64pIMhJLbzx4Dc9Mv4sm95Fql7m2j+FidRhlbHqhSUeVS7LH/mngKeA1CV5TRMapcNeig4NHRnYtGnd648FroPtrNMOYwdEmAy88PI5NLzTpqHKJ1LGb2VzgXcCdSVxPRCpTuDhW30Bm1FZ0cDS9UZaerxd92SDonWMjNekbhhaxaM0WTln+EIvWbCm5QJcmHVUuqR77bcB1wKujTjCzZcAygPnz5yd0WxEJE5anDhMrvZG/52jk2ouB/vSJvP3wOvYeGmDO9DTnPjube3vKGwgtZ6ExCVdxYDezi4AX3b3HzP406j
x3Xw+sh2CCUqX3FZFocfPRJdMbO7uCPUYLt6ALcaR5OisOXkLvYHBub98A337838f8KogzEKra9cokkYpZBCw2s38DvgucZ2bfSuC6IjJOcfLRsdIbj6yKFdRJvYrP2xXcM/jWUYejenAaCK2uigO7u9/g7nPd/WTgUmCLu3+w4paJyLiF5alTTcZxxwRrrzSbjfSci+a89+8pfiNrDmrSP7uXbxxYGLt9GgitLtWxizSgqDw1UNbkn/70CRwz8ELI8RM55vpfjToWVc1SKNVsGgitskRXd3T3H6qGXaQ+dC5o57Hl5/Hcmnfx2PLzSu44FGZt5v30e+uoY/3eytrM+8ecG/ZXQphXtbYof15l6rGLTCH566SvTH2T4zgAwMv9x8LOW8fUm3/jwEJebhrMzjA9upvRA4cXsrLg2oV/JUTl1/cPZBL8L5IwCuwiDShqlumcGWmWHfgylzX/YNQa6cfbAdhwZfBFXnCfMyPNxr5z2Dh4zqjrt0fkyPOrWRat2aKJRhNEG22I1LEN23vLmtyTe0/+5KTcyouHvzCfrYfeMyaojxjOBFUweSqZLKSJRhNHPXaRCRBn3ZbxrnIYtt/oF+wOpuWOFVsnvaAKppLJQppoNHG0g5JIjRUGbAh6soVbukWlMiBIhUQFyVOWP4QTBPQbW77JTDtQ3tZ0n9lVzn+O1JB2UBKpU3GXpS02iadY733OjDSrD36O/9r0ZPyADhz2Zqadv6LkedoEo/4psIvUWFTA7u0bYNGaLSMBsy2doq9IBUnU1PzvN/8Nr2n6Teyg7g4HmcbfpT7ByhKrMGoTjMlBg6ciNRZVFWIwasDz4OARUiV2gx7zS+LBa2g7GC+ou8Pvho/l05krOXv4m5z5rmUl31NuHbxMDAV2kRoLqxYxxq6rkhlyjp3eEllaCCG/JEosq5tzODWDm1JX0zG4np7XvH1Mfj+KNsGYHJSKEamxsGqRqEHSvv4M21dcEDngOqZ00Ess1dvcChd/mWlnLGUljJlkVIo2wZgcFNhFJkDhsrRxJvNMTzWNBPYZ6RQrF795bC/bmkODuzscJsX0i79c1m5Gha698NR4v2BkQikVI1IHik3myfXWX+k/OpB6+Mhw+IXO+vCYlI47/MrbOT/dNSqoj2fyU+eCdlYvOZ32GWmMoOwybhpHakc9dpE6UGwyz6I1W3j70I9YOS1vbRc/ln946HI6F9w0+kIXfZFnXzrISf/WRTPDDNHEt4fOY419jNV5vepKqlu0CUb90wQlkXr24DUMbbuLJnxMpcthb2baJV8JTa3kas17+wZoNmPIfdSkpqjUT/uMNI8tP69a/zVSobgTlBTYRerRzi544GrIHCx+XshM0fygHlZtM6NIfbwBz61517ibLdWlmacik1UZ+4wWru1SmGIJ67b1DWRCAz6ouqVRaPBUpN7E3WcUoG3uqC/DJhCFccauBabqlsahwC5SL3Z2wa2nwf7d8c5vSkHB2i7lTBRyUHVLg1IqRqQexEy/eLarbemZ8M6bxwycxt13FDRQ2sjUYxepByXSL+4w5MY3h97GOdPvh+ufC62GibvvqNIujU09dhEmYCnaB68J1nXxocjZohAE9F6fxdojS9k4HGxPZ0V65J0L2ul+/mW+89PdDLnTbMYH/mgeHSfN1FK7U4gCu0x5NVuKdmdX0DMvzKEXWd+l12dxzuC6UceKVa5s2N7LvT29DGXLmIfcubenl46TZirtMoVUnIoxs+lm9jMz+4WZPWlmN5V+l8jEy02pv/ruHdVfijaXQ487MAr0eytrj4xOt5RKoWhZXYFkeuyHgfPc/YCZpYCtZvawuz+ewLVFqiJstcRCiS5FG7eEsW0ew3172OvHj0q/5JSqXNGyugIJBHYPpq4eyH6Zyv6r/XRWkTLEqfdOdLJOwUSiUNYMn9nFfy0y3b9UakjL6gokVBVjZs1mtgN4Edjs7j8NOWeZmXWbWfdLL72UxG1Fxq1USeB4qkaKrp
ZYMJEo1FkfBoqv9FhKJe+VxpFIYHf3IXc/E5gLLDSz00LOWe/uHe7eMXv27CRuKzJuzUX2jhvPZJ0N23vZev/t3N3/MX4z7c+5u/9jbL3/9qPB/fwVkIroNVszdHwULvoiUNnSuFpWVyDhqhh37zOzHwLvAHaVOF1kwgwVWfxuPNUjOx5azypbzzE2CMBc28cqX8/ah1qCpXVzNeePrIL9e+hPn8DazPv5xoGFQflh+6l05l2vkqVxtayuVBzYzWw2kMkG9TTwNuDmilsmUkXtEbno9hnpWDXtG7b3suOh9Vw1eCczmw5wozNmWd1jbJDLB78FBIViG4YWccvhdfQeGsAOHR2Iqlp5pUxZSaRiTgQeNbOdwDaCHPuDCVxXpGqictHnvnE2N9z3BL19AzhHg25+vnzD9l4O3v9pbszcxvFNBzDGBvWcOU2/G3lP7rowtrpAJYmSpIoDu7vvdPcF7n6Gu5/m7quSaJhINUXloh/91Usl68B3PLSeD9jmyGCe71D6BCBeFY5KEiUpmnkqU1ZYLvozd+8IPTc/6F4++C2aYnSJjjRP55h3rhrz/igqSZSkaBEwkTxhwXVx01Z6pi3Db2zDV7bR3rSv9IXa5tFy8T+MDJqWCtoqSZQkKbCL5CnMvS9u2sotqTuYaQcwCzanKJaBGfQWWPLVYLu6vNUXw3L6ueuoJFGSplSMTHmFVTCXnNXOo796ib19A1yf6mKaxdiRyOEg0/i71CdYGbKcbi5oa4VFqQUFdpnSwlZ2PPTz77L5VfdyzPTfUmyzd3dwbGRdl83Nf8Lqd50eeb7qy6VWFNhlSstVqyxu2sp1LV202z4caMqOdRarfHnBZvG+6V8d6YGvVg9c6oQCu0xpe/sGuKnlLi5r/sFIEI9Rxcigt3BzZimPrdQa51J/FNhl6nrwGp6ZfhdN7iVr0vMzMq9wLCszl9HzmrfXfuclkRgU2GVqevAa6P4azRCri76XWSw6fHQno3SqmUuys1SrvvOSSJlU7ihTz84u6P5a/PNTafaedd24ZqmKTAT12GVqyW1RF1fbPDh/BWefsZTHFo9+Kc4sVZGJoMAujW9kE+k9YE1FN4/OOeLG9rNu5uzFH488R7sVSb1SKkYa26hNpL1oUHcP/r3sx3JN5hNc/cs3FL20diuSeqUeuzSecfTQ3eGbQ2/jxiMfGTlmJVIqmk0q9UqBXepa2eWEuR56JhuUYwT1AaZxfeajbBw+Z9TxOCkVzSaVeqTALnWjMIif+8bZ3NvTW1454SOrjgb1YqwZfBja5rLr9VexedtJMHz0l4BSKjKZKccudSF/h6HczkXffvzfyy8n3L+n5L2ONE9nZfNVnHLo2yw6vI7eeRdpA2hpKOqxS1XFTaWE7TAUtfxW0XLCtrnZgdIC2R56f/oEVhy8hHsGFwJH/wpYveT0cW1iLVKP1GOXqgnrhRfuH5pTTu130dz3+SsgVfB6Kg3v+Qqs7OPtfjv3DL511MuaVCSNRoFdqiasFx4VRKOCdeFs/5K57zOWwrvXBROLsODju9eNbHoR9QtEk4qkkSgVI1VTThC99sJTR627Atn1WPI2vYhdTnjG0lG7F+XTpCKZCtRjl6qJCpZhxzsXtIcOYH7+dU/x2LRP8dz0v+CxaZ+is/mxitqkSUUyFVTcYzezecA3gROAYWC9u3+p0uvK5BfVC48KoqNqwnd2wQMfhMzBoyfs3310nZeIHnkpmlQkU4EV2/or1gXMTgROdPefm9mrgR6g091/GfWejo4O7+7urui+MjmMa73ynV2w4UoYzoS/3jYv2CxaZIoxsx537yh1XsU9dnd/AXgh+/n/M7OngHYgMrDL5BE3MEedF3dm5obtvex4aD2XD36L9qZ9xZdIj1GrLjKVJTp4amYnAwuAnyZ5XZkYYRs9h838jHtesfscvP/TrLDNNMUZ9WmbW/J6hTNYyx6AFZnEKk7FjFzI7FjgR8AX3P2+kNeXAcsA5s+ff9bzzz
+fyH2lehat2RJaQdI+Iz1qMk/c8wr9ZN2HOft336OZYaD4xtFHGSxZH5ljL/wlEyadatbMUpmU4qZiEqmKMbMUcC/w7bCgDuDu6929w907Zs+encRtpcriliuOpzb8J+s+zFt+dz8tNoxZ3KAOdHyk6MBpWO18IU1IkkaXRFWMAV8DnnL3L1beJKkXcWu+y6kN37bxDub9/Bbe4i/FD+YA6ZnwzptLVsPEnWikCUnSyJLosS8C/hI4z8x2ZP/9WQLXlQkWt+Y77nnbNt7BaT2f4wTiBXUHWPJVWLkfrn8uVolj3IlGmpAkjaziwO7uW93d3P0Mdz8z++/7STROJt60lqM/IscdkwrNTUdNLio8b97PbyFtg7Hu68CzJ13Kou/P4pTlD7FozZbQNWYKhf2SKaQJSdLotKSAAKXXQgc4lBmOfH9kWWPebkavdR+7+Esed8DA2ubR/fqruGzbSQxk11aPW2kTNgFJVTEy1SRWFVMOTVCqL2GVJEb4srmlKl1GKdzNKII7DNHEtuMv5o8/9XVg/JU2Io2sZhOUZPJLbC30nJFeesi66AUGvJVdZ32esxd/nD/OOx4W1GPfX2SKU2CX5NZCh1i9dAfcjRdtFrvPupazF3981OsbtvdG/sWgQU+R0hTYJbJcsTC4Fh10LKOXbm3zsM/s4gSCleMK3bLp6dCgbqBBT5EYtGyvRJYr/sVb5sfbBzTXS48R1Emlg12Oioj6C8KJt0SByFSnHrtUvpTtI6tKD5AS9NQ5f0XJevSovyDalYYRiUWBXYAi5YpxlFhtsd9bWZu6kpWfuSnW5cpdx11ERlNgb1DjWgd9vNrmhqZh3KHXZ7H2yFIeOLyQlTEvp80wRCqjwN6AKl1Gt2znrxhTCdPvrSzPXM7G4XOA8tMoFf0FITLFKbA3oLC69IHMECs3PlleL3hnFzx8PQy8HHwdtRBX7utHVuH797DXj+fmzNKRoK40ikhtKbA3oKiqkr6BDH0DwXZzJXvxD14D3V8bfWzgZfjeJ4PPw4L7GUsxYNv2Xno2PY0pjSIyIRTYG1BUVUmh3LrkY4Luzi7oviv8TUODQRVMkcoWpVFEJpYCewPJDZj29g1EztwsNNK7L0y7FKM9R0XqmgJ7gygcMM0ulIgTDFz2Dx7hlf7MmPfNmZEOgvr3Phn0xuMoseeoiEwszTxtEFELeeVWQ7zx3W8eM7v0va3/l812Jdz3sfhBvbm15MxREZlY6rE3iFL7jhbWhn/o2J/xOb+TloFD8W/S+iq46LZYOxmJyMRRYG8QcfYd7Wx+jM5pq2D6HhhqAi++6fOImEsBiEh9UGCvA0nMEj33jbP59uP/Hr4aY9jAaJyg3pSCztsV0EUmGQX2CZbELNEN23u5t6d3VFA34JKz2ulsfizWLkY5nv2f/fZqfn3m/+BsBXWRSUeDpxMsapboLZuerugaDrTsugfuvyJ2UO/3Vj49eCWnHP5fnHnoDi7bdlKsDaRFpL4osE+wUoOe473GTS13sSJzW8mUyxFvYtiNPcOzRq3tAsEvmKvv3sGiNVsU4EUmkURSMWZ2F3AR8KK7n5bENaeKqEHPJjM2bO+NlY7JXWNx01aua+mi3fYBYFb8fYULdUWp+iJiIpKopHrsXwfekdC1ppSw3YsAhty54b4nYvWUr73wVL7V+rd8KXU7c5v2YVY8qDvQx6tjBfWcctNDIjJxEgns7v5jIMZcdCnUuaCd1UtOpzkkEscNpp29f8+ipl0le+gQpF6uHf5r/u7Mh9nc/CdltbWc9JCITBzl2OtA54J2hj18ZZdYwbTn68SI6Qw7XJO5gnsG38qjv3qJ1UtOH7Wn6Qeze5xGmaOt6UQmhZqVO5rZMmAZwPz582t120kjzgSjSDFq0ocd/mnobSOpl719A5GrMBaWYILWVBeZTGrWY3f39e7e4e4ds2fPrtVtJ42wXHvsYGpjc/QQbE037LBneBZXZ67kxiMfGXmt2C+MXHoovze/es
npGjgVmSQ0QalOVLTP51kfHrMphgP/MvxmLst8dszpcX5haE11kckrqXLH7wB/Cswysz3Aje7+teLvkkKhwXRnV7Cxxf49wXK5YWu2XPTF4GPP14O0jDXzT5lzWZHXQ8+n3rdIYzOPGLSrpo6ODu/u7q75fSeVqI0vUml497qS67csWrMlNGefW8ZXRCYfM+tx945S56kqph7t7ArWdwnbzSgzEPTgS6goZy8ik5py7PVkJO2yu/h5MbamqyhnLyKTmgJ7PShnv1GIvTWdBkBFpiYF9omWS7vEXIGRVFpb04lIUcqxT7RHVsUK6g6Qnhlr4FREpjb12GsprHSxRL7cHXp9FrdxKedccCWdZyi1IiLFKbDXys4u2HAlDGeCr/fvDr5OHxeZWy9cVvcnm56OzJknsb2eiDQGBfZaefj6o0E9ZzgDRw4HefO8dIw7vOzHctORy0Ytqxu1IFgS2+uJSONQjr1WoipeMgeDvHnbPMCgbR43pa7mrMH1Y9ZKj1rfJYnt9USkcajHXg2F5YvpmcXPP2PpqAHRM7f3ki5jdcUkttcTkcahHnvScrn0/B56sfr0kKCfv7oiQLPZSA88bEelqJ681k8XmZrUY69A/oDlh479Gdel7uaYgRciz3cfvWXdoLfwi99fztkh5+Zy43Fy59deeKrWTxeREeqxj1NuwLK3b4B3N23lusztRYM6BLXoe4ZnMezGnuFZ/E1mGVf/8g2R58fNnWv9dBHJpx57nnJKBm/Z9DRvH/oR17V20W77Yu03utdncc7gulHHrEgevJzcuZYPEJEcBfascksGO/5zM6tTd3KMDca6/iAtrD0ydsZosTx4RdvliciUpVRMVrklgze0/nPsoE56Jhvm//cx5YsA574xeptALb0rIuOhHntWybRHwXIAr+Wl4hcs2BDjS2u2AGPv8eivoq+jpXdFZDwU2LOKpj0KV2DcvxvDyC7NNVbbvDFb2I231ly5cxEpl1IxWUXTHqErMDpQMGKaSsOSr8Jndo1ZgVG15iJSK+qxZxVNe3wvagVGD3rnBRtNh1XXqNZcRGpFgT1PZ/NjdE5bBdP3wLS50LwCWBoE7bDt6trmBb3zPFHVNauXnM7qJacrXy4iVTdlAnvJGvWQPDoPfAqAba+/itN6Pkc6rwpmwFvZ9fqrxswajaquWbnxSXbceIECuYhU3ZTIsefPEnWO9qJHrbsSlkfPDMAjq7j6l2/g+szlo2aNXp+5fNSs0Q3be1m0ZkvoACxA30AmdJ0XEZGkJdJjN7N3AF8CmoE73X1NEtdNSrEa9ZEedNRORvv3sPfQAL2cw8bB0XXouVmjhemXYu1Qj11Eqq3iHruZNQNfBt4JvAn4gJm9qdLrJilWqWHb3PA3t80tWdES9oujnHaIiCQpiVTMQuAZd3/W3QeB7wIXJ3DdxMyZkWZx01a2tn6KZ6f9OVtbP8Xipq2jA/b5K4JyxXypNJy/ouQM0LgBW6WNIlILSaRi2oH8kpE9wB8VnmRmy4BlAPPnz0/gtvHd9qZfc1rPnSODn3NtHzen7mTXm04GzgtOytWdF242fcZSOrPXiRp8jZrclE+ljSJSK+YeMXsy7gXM3gdc6O6XZ7/+S2Chu18V9Z6Ojg7v7u6u6L5lufW02OWK4xGWY081GcdOb6GvP6PSRhFJhJn1uHtHqfOS6LHvAeblfT0X2JvAdZNTZGA0CVrTRUTqSRKBfRv5PfvGAAAIpklEQVTwBjM7BegFLgX+PIHrlq9goa6R9VoiJxhFDJiOg9Z0EZF6UfHgqbsfAf4a2AQ8BXS5+5OVXrdsuQlG+3cDfnSC0c6uogOjIiKNJpE6dnf/PvD9JK41Lju74P4rwAtKDrMTjEby6GG9eRGRBjP5lxTI9dQLg3pOLo9+xlIFchGZEiZnYM/PpVtTdFCHRPPoIiKTweQL7IWLdRUL6sqji8gUNPkWAQvd9CKENY/amk5EZKqYfIE9Ru15v7dyg3+SDUOLat
AgEZH6MvlSMRE16Ue8iSacvX48a48sZePwW9hw3xMAqi8XkSll8vXYQ2rS+72VazJX8LrD3+acwXVsHA6W180tzSsiMpVMvh57wWJdv2UWf5t530gwL6SlckVkqpl8gR1G1aQ/vr2Xzfc9AcPh1TFaKldEpprJGdjz5PLnNz3wJK/0Z0a9pqVyRWQqmnw59hCdC9rZvuICbnv/mbTPSGNA+4w0q5ecroFTEZlyJn2PPZ9WWBQRaZAeu4iIHKXALiLSYBTYRUQazKTMsW/Y3qtt6EREIky6wF64cXRv3wA3aOkAEZERky4Vc8ump0eCeo6WDhAROWrSBfaoJQK0dICISGDSBfaoJQK0dICISGDSBfZrLzyVdKp51DEtHSAictSkGzzNDZCqKkZEJFxFgd3M3gesBH4fWOju3Uk0qhQtHSAiEq3SVMwuYAnw4wTaIiIiCaiox+7uTwGYWTKtERGRitVs8NTMlplZt5l1v/TSS7W6rYjIlFOyx25mPwBOCHnps+7+vbg3cvf1wHqAjo4Oj91CEREpS8nA7u5vq0VDREQkGZOujl1ERIoz9/FnRczsPcA/ALOBPmCHu18Y430vAc9HvDwL2DfuRlWX2la+em0XqG3jVa9tq9d2QXJtO8ndZ5c6qaLAXg1m1u3uHRPdjjBqW/nqtV2gto1XvbatXtsFtW+bUjEiIg1GgV1EpMHUY2BfP9ENKEJtK1+9tgvUtvGq17bVa7ugxm2ruxy7iIhUph577CIiUoEJCexm9j4ze9LMhs0scqTYzN5hZk+b2TNmtjzv+Clm9lMz+7WZ3W1mrQm2baaZbc5ee7OZHRdyzrlmtiPv3yEz68y+9nUzey7vtTNr1a7seUN5996Yd3yin9mZZvaT7Pd9p5m9P++1xJ9Z1M9O3uvTss/hmexzOTnvtRuyx582s5Lluwm36xoz+2X2GT1iZiflvRb6va1h2z5sZi/lteHyvNc+lP3+/9rMPjQBbbs1r13/amZ9ea9V7bmZ2V1m9qKZ7Yp43cxsXbbdO83sD/Neq94zc/ea/yNY5vdU4IdAR8Q5zcBvgNcBrcAvgDdlX+sCLs1+/hXgEwm2bS2wPPv5cuDmEufPBF4Gjsl+/XXgvVV4ZrHaBRyIOD6hzwz4L8Absp/PAV4AZlTjmRX72ck750rgK9nPLwXuzn7+puz504BTstdprmG7zs37WfpErl3Fvrc1bNuHgf8Z8t6ZwLPZj8dlPz+ulm0rOP8q4K4aPbf/BvwhsCvi9T8DHgYMeAvw01o8swnpsbv7U+5eavfphcAz7v6suw8C3wUuNjMDzgPuyZ73DaAzweZdnL1m3Gu/F3jY3fsTbEOYcts1oh6embv/q7v/Ovv5XuBFgolt1RD6s1OkzfcA52ef08XAd939sLs/BzyTvV5N2uXuj+b9LD0OzE3o3hW3rYgLgc3u/rK7vwJsBt4xgW37APCdBO8fyd1/TNCxi3Ix8E0PPA7MMLMTqfIzq+ccezuwO+/rPdljxwN97n6k4HhSXuvuLwBkP/5eifMvZewP0Reyf3bdambTatyu6Rasovl4Lj1EnT0zM1tI0PP6Td7hJJ9Z1M9O6DnZ57Kf4DnFeW8125XvowS9vZyw721S4rbtkuz36R4zm1fme6vdNrKpq1OALXmHq/ncSolqe1WfWdW2xrPKV4UMW+TdixxPpG1lXudE4HRgU97hG4DfEgSu9cD1wKoatmu+u+81s9cBW8zsCeA/Q86byGf2T8CH3H04e3jczyzqNiHHCv97q/bzVUTsa5vZB4EO4E/yDo/53rr7b8LeX6W2PQB8x90Pm9kVBH/xnBfzvdVuW86lwD3uPpR3rJrPrZSJ+DmrXmD3yleF3APMy/t6LrCXYL2FGWbWku1p5Y4n0jYz+w8zO9HdX8gGoReLXGopcL+7Z/Ku/UL208Nm9o/A39SyXdk0B+7+rJn9EFgA3EsdPDMzew3wEPC57J+luWuP+5lFiPrZCTtnj5m1AG
0Ef1LHeW8124WZvY3gF+afuPvh3PGI721SAapk29z9d3lffhW4Oe+9f1rw3h8m1K5YbctzKfDJ/ANVfm6lRLW9qs+snlMx24A3WFDN0UrwDdvowcjDowS5bYAPAbHXhY9hY/aaca49JpeXDWy5vHYnwfaBNWmXmR2XS2OY2SxgEfDLenhm2e/h/QT5xn8ueC3pZxb6s1Okze8FtmSf00bgUguqZk4B3gD8rML2xG6XmS0A7gAWu/uLecdDv7cJtStu207M+3Ix8FT2803ABdk2HgdcwOi/Yqvetmz7TiUYiPxJ3rFqP7dSNgKXZatj3gLsz3ZkqvvMqjVaXOwf8B6C31iHgf8ANmWPzwG+n3fenwH/SvDb9bN5x19H8H+2Z4B/BqYl2LbjgUeAX2c/zswe7wDuzDvvZKAXaCp4/xbgCYLg9C3g2Fq1C3hr9t6/yH78aL08M+CDQAbYkffvzGo9s7CfHYL0zuLs59Ozz+GZ7HN5Xd57P5t939PAOxP+2S/Vrh9k/z+Re0YbS31va9i21cCT2TY8Crwx770fyT7LZ4C/qnXbsl+vBNYUvK+qz42gY/dC9md7D8G4yBXAFdnXDfhytt1PkFcFWM1nppmnIiINpp5TMSIiMg4K7CIiDUaBXUSkwSiwi4g0GAV2EZEGo8AuItJgFNhFRBqMAruISIP5/6ZMxNomeTh0AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(x[:,0],y)\n", "plt.scatter(x[:,0],x@a);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Animate it!" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "from matplotlib import animation, rc\n", "rc('animation', html='html5')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You may need to uncomment the following to install the necessary plugin the first time you run this:
(after you run the following commands, make sure to restart the kernel for this notebook)
If you are running in Colab, the installs are not needed; just change the cell above to use `html='jshtml'` instead of `html='html5'`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#! sudo add-apt-repository -y ppa:mc3man/trusty-media \n", "#! sudo apt-get update -y \n", "#! sudo apt-get install -y ffmpeg \n", "#! sudo apt-get install -y frei0r-plugins " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Instead of writing a loop, we call matplotlib's `animation.FuncAnimation` to run `animate` 100 times.\n", "\n", "Our `animate` function just calls the `update` we wrote above and updates the y data of our line (via `line.set_ydata`) before returning it." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = nn.Parameter(tensor(-1.,1))\n", "\n", "fig = plt.figure()\n", "plt.scatter(x[:,0], y, c='orange')\n", "line, = plt.plot(x[:,0], x@a)\n", "plt.close()\n", "\n", "def animate(i):\n", " update()\n", " line.set_ydata(x@a)\n", " return line,\n", "\n", "animation.FuncAnimation(fig, animate, np.arange(0, 100), interval=20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's essentially SGD visualized. The only difference between this and true SGD is mini-batches.\n", "\n", "___\n", "\n", "In practice, we don't calculate the loss on the whole dataset at once; instead we use *mini-batches*." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Vocab" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Learning rate\n", "- Epoch\n", "- Minibatch\n", "- SGD\n", "- Model / Architecture\n", "- Parameters\n", "- Loss function\n", "\n", "For classification problems, we use *cross entropy loss*, also known as *negative log likelihood loss*. This heavily penalizes confident predictions that are incorrect, and also penalizes correct predictions made with low confidence."
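] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A minimal sketch of that last point, assuming the loss for a single example is the negative log of the probability the model assigned to the correct class. The helper `cross_entropy` and the probabilities below are hypothetical, chosen only to illustrate how the loss stays small for a confident correct prediction, grows for an unconfident correct one, and becomes large for a confident incorrect one." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "def cross_entropy(p_correct):\n", "    # Loss for one example: negative log of the probability\n", "    # that the model assigned to the correct class.\n", "    return -math.log(p_correct)\n", "\n", "# Hypothetical probabilities assigned to the correct class\n", "cases = [('confident and correct', 0.99),\n", "         ('unconfident but correct', 0.51),\n", "         ('confident and incorrect', 0.01)]\n", "\n", "for label, p in cases:\n", "    print(f'{label}: p={p:.2f}, loss={cross_entropy(p):.3f}')"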
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (FastAI)", "language": "python", "name": "fastai" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 1 }