{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "WNixalo | 20181112 | fast.ai DL1v3 L2\n", "\n", "---" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "from fastai import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this part of the lecture we explain Stochastic Gradient Descent (SGD), which is an **optimization** method commonly used to train neural networks. We will illustrate the concepts with concrete examples." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Linear Regression problem" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The goal of linear regression is to fit a line to a set of points." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "n=100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create a column of numbers for the x's, and a column of 1's.\n", "\n", "Instead of treating $y = ax+b$ as a special case, we always add a second x value that is fixed at 1. This lets us compute every prediction with a single matrix-vector product: $y = a_1x_1 + a_2x_2$, where $x_2 = 1$.\n", "\n", "Ahh, interesting. So our biases are encoded into the weights then? Yes: $a_2$ multiplies the constant 1, so it plays the role of the intercept $b$." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.6318, 1.0000],\n", " [ 0.0486, 1.0000],\n", " [-0.5819, 1.0000],\n", " [-0.1692, 1.0000],\n", " [-0.3766, 1.0000]])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = torch.ones(n,2) # create an n x 2 tensor of 1's\n", "x[:,0].uniform_(-1.,1) # fill column 0 with uniform random numbers in [-1, 1)\n", "x[:5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create some coefficients: a1 is `3`, a2 is `2`.\n", "\n", "This creates a 'vector', i.e. a rank-1 tensor.\n", "\n", "`3` & `2` represent the coefficients: the slope (3) and intercept (2) of our line.\n", "\n", "Ref: Fast.ai DL1v3 Lesson 2 [[1:21:35](https://youtu.be/Egp4Zajhzog?t=4895)]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([3., 2.])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = tensor(3.,2); a" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generate data by creating a line via `x@a` and adding some random noise to it. Note that `torch.rand` draws uniform noise from [0, 1), so the noise has mean 0.5 and shifts the best-fitting intercept up to about 2.5.\n", "\n", "The column of 1s is just there to make the linear function convenient." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "y = x@a + torch.rand(n)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "[figure]" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(x[:,0], y);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You want to find **parameters** (weights) `a` that minimize the *error* between the points and the line `x@a`. Note that here `a` is unknown. For a regression problem the most common *error function* or *loss function* is the **mean squared error**. 
\n", "\n", "Now we're going to pretend we don't know that the coefficients (`a`) are 3 & 2, and figure them out.\n", "\n", "DL1v3 Lesson 2 @ [[1:26:39](https://youtu.be/Egp4Zajhzog?t=5199)]\n", "> if we can find a way to find those 2 parameters to fit that line to those 100 points, we can also fit arbitrary functions that convert from pixel values to probabilities.\n", ">\n", "> The techniques we're going to learn to find these 2 numbers, work equally well for the 50 million numbers in ResNet34.\n", "\n", "What machine learning calls *parameters*, statistics calls *coefficients*.\n", "\n", "A **regression** problem is one where the dependent variable is continuous. In math notation the actual value is $y$ and the prediction is $\\hat{y}$." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "def mse(y_hat, y): return ((y_hat-y)**2).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above function in mathematical form is: $\\frac{1}{n}\\sum_{i=1}^n(\\hat{y}_i-y_i)^2$\n", "\n", "The code and the math are just two notations for the same thing, but the code is executable, which lets you experiment with it, while the math notation is abstract." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Suppose we believe `a = (-1.0,1.0)`; then we can compute `y_hat`, which is our *prediction*, and then compute our error."
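, "\n", "\n", "As a quick standalone check of the ideas so far, here is a minimal sketch in plain PyTorch. It assumes only `torch` is installed, and the seed and the guess `(-1, 1)` are illustrative choices, not values from the lecture. It confirms that multiplying by the column of 1s reproduces $y = a_1x + a_2$. It also confirms that the hand-written `mse` agrees with PyTorch's built-in `torch.nn.functional.mse_loss`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)                # illustrative seed, for reproducibility only
n = 100
x = torch.ones(n, 2)
x[:, 0].uniform_(-1., 1)            # column 0: inputs in [-1, 1); column 1: constant 1s
a = torch.tensor([3., 2.])          # slope 3, intercept 2
y = x @ a + torch.rand(n)           # noisy targets

def mse(y_hat, y): return ((y_hat - y) ** 2).mean()

# the column of 1s turns the intercept into an ordinary weight
assert torch.allclose(x @ a, 3 * x[:, 0] + 2)

guess = torch.tensor([-1., 1.])     # a deliberately bad guess for a
y_hat = x @ guess                   # prediction
manual = mse(y_hat, y)
builtin = F.mse_loss(y_hat, y)      # built-in MSE, averages over elements by default
print(manual.item(), builtin.item())
```

The two printed numbers should agree up to floating-point rounding, since `F.mse_loss` also averages the squared differences by default."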
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "a = tensor(-1.,1)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(6.9835)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_hat = x@a\n", "mse(y_hat, y)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3X+UXOV5H/Dvs6uRWGGilcI2hpXWEj1Uim0UyWwJp5vjWMIH2QaLjXAEtG7s2kV14jQBuwqrYxcktylL3WLqJD2JbFPbxTHIIBSBTIGwcn3gRDi71i9kwAiEQStqyRGrWGiRRqunf9x7d+/cvb/ve2fu3Pl+ztHZ3Zk7d15ml2feee7zPq+oKoiIqDzaGj0AIiIyi4GdiKhkGNiJiEqGgZ2IqGQY2ImISoaBnYioZBjYiYhKhoGdiKhkGNiJiEpmRiOe9MILL9SFCxc24qmJiJrWyMjIL1S1K+q4hgT2hQsXYnh4uBFPTUTUtETkZ3GOYyqGiKhkGNiJiEqGgZ2IqGQY2ImISoaBnYioZBjYiYhKpiHljkRErWDb7lF8+fEXcWRsHBd3dmD9qsXoX96d+/MysBMR5WDb7lFs2Lof49UJAMDo2Dg2bN0PALkHd6ZiiIhy8OXHX5wM6o7x6gS+/PiLuT83Z+xERDEkTascGRtPdLtJnLETEUVw0iqjY+NQTKVVtu0eDXzMxZ0diW43iYGdiEpj2+5R9A0OYdHADvQNDoUG3iTSpFXWr1qMjkp7zW0dlXasX7XYyJjCMBVDRE3JmxpZsaQLD42M5nKxMk1axXlOVsUQEcXgV3HynV2vQT3HObPqrMH04s4OjPoE8ai0Sv/y7roEci8jqRgReVVE9ovIHhFhP14iypVfasQb1B0mLlY2Mq2ShskZ+wpV/YXB8xER+UoSrE1crGxkWiUNpmKIqOkEpUa8Km1ibFYdlVZp1CpTP6aqYhTAEyIyIiLrDJ2TiMiXX2qk0iZob5Oa26rnFBu3HzBWHRMkTTlknkwF9j5VfR+ADwP4rIi833uAiKwTkWERGT527JihpyWiVtS/vBt3rrkM3Z0dEADdnR14x3kzMHFueqZ9bLyae5Bt5CpTP0ZSMap6xP56VEQeBnAFgB96jtkMYDMA9Pb2Bl3nICKKxZsaWTSwI/BYU9UxQRq5ytRP5hm7iJwvIhc43wO4GsBzWc9LRJRE1EXSpEE2yWKnRq4y9WMiFfNrAJ4Wkb0AfgRgh6r+HwPnJaIWlnQVqV/e3S1JkE2aMy9aOWTmVIyqvgLgNwyMhYhaQJzqkTQtb53bNz1yAG+eqtbclzTIhuXM/Z6/aOWQLHckorqJG7CTBlaHk3fPWnqYtoVAUeraGdiJqG6CAvamRw7UBMWsFyOzBtm0LQSKgt0diahuggLzm6eqNfnrRl+MLFrOPCkGdiKqm7DA7K75bnRg9auTv3PNZYVJtURhKoaI6mb9qsW45YE9vve5Z/NFuBhZpJx5UgzsRFQ3/cu7sXH7AYyNV
6fd553NN3NgbTSmYoiorjaufk9T56+bAWfsRJQ7b/nh9Zd3Y+cLxwpR811GDOxElEjSGnG/2vWHRkab6mJks2EqhohiS9OetmidD1sBZ+xEFFuaFaFRi42KtEFFWTCwE1FsaVaEhq3iDGsxABSn90qzYWAnotjSLLVfv2pxTfAGpqpggj4BbNx+AKfPnovVBIwz/umYYyei2NKsCA1bxRk00x8br/oG/M9v2VuTzy/alnRFwRk7EcWWZEVonJl03E2pHROqNTP3tF0gy46BnYgSibMiNG573qA0zXmVtmk91R3uwF20LemKgqkYIjIuboljUJrmjo9OX53q5gTuRneBLCrO2InIuCQz6bBPAJ/fshcTqtNudwJ32IXZVsYZOxEZZ2Im3b+8G/997W+EXqxt9va6eeGMnYiMMzWTjnOxll0gp2NgJ6JUwqpeTPZTZ+BOjoGdiBKLU/WSJSBz0VE2zLETUWJ5NvbioqPsGNiJKLE868fZDTI7Y6kYEWkHMAxgVFWvNXVeIjLDZHojTc+YuLjoKDuTOfY/BvA8gF8xeE4iSskdyOd0VPDWmbOoTlg14WFNteLIs348zzeNVmEkFSMi8wFcA+DrJs5HRNl489Rj49XJoO7Ikt5IUj++bfco+gaHsGhgB/oGhyJz5WkajVEtUzP2ewD8CYALgg4QkXUA1gFAT0+PoaclIj9+eWo/WdIbflUv3nTPiiVdeGhkNFb7Xfd5nf8GVsWkkzmwi8i1AI6q6oiIfCDoOFXdDGAzAPT29k5fI0xExsQN2CbTG34lkN/Z9Rq8/7PH6b7I2vVsTKRi+gCsFpFXAdwPYKWI3GfgvESUUpyAbTq94fcpIWgGxwuh+coc2FV1g6rOV9WFAG4EMKSqH888MiJKzS9PXWkTzJ1dAQC0i0zOnKNy3nFz5EmCNS+E5osrT4lKKChPDSBWn3RH3L7qQPxNMyrtwguhORP1aYmZt97eXh0eHq778xK1ur7BId/g293ZgWcGVmY63vsmEKSzo4I9d1ydcOQEACIyoqq9Ucdxxk7UQpIu/knaVx2Y+pQQNGU8Me6/MxKZw8BOVEJBq0yTLv5Jery7miVots/8ev7YK4aoZMKaaPldVBUAK5Z0+Z4ry2IhLjRqHM7YiRogbt+WNP1dwppoPTOwEsM/O15TX64AHhoZRe+75k07d5bFQlxo1Di8eEpUZ34XGTsq7dOW5PsdJ7ACcXdIkFw0sMM3vy0ADg1ek/gCKhVH3IunTMUQ1VnctrRhC37CepRH7TfK7onlx8BOVGdxA2tUoA1q4hWV28660XTSpl5Uf8yxE9VZUKVJ5+wK+gaHJvPRczoqGIsoDYxTdujNbWdpuZtkwRI1DmfsRHXmu9y/XXDy7bM1lSxvnTmLSpuEnitN6WCSlrte3N2oOXDGTlRnfjPqt06fnTY7r04o5s6uYPbMGRgdG5+8cOoImmXnudE08/PNgTN2ogboX96NZwZW4tDgNXhmYGXgasyxU1U8M7AS99ywDJ12Ay/AWpYfNMvOc1adNT9P9cHATlQAYQHTmYG/eWoq+J8+ey7wXHnOqrnoqDkwsBMVQFjATDoDjzurTlPdkiU/T/XDHDtRAYRVstz6wB7fxwTNwONUvWSpbuHuRsXHwE5UEEEBM00jLsB6kxgdG6/ZVMO5P+xTAIN282Mqhqjgkua1nf4yTiXNhN02ZHRsHLc8sAfLNj0RuCEGq1vKgTN2ooJL0kzLm2Lx6xkzNl6dVjrpYHVLOTCwEzWBuHltvxSLHwVi18VT82FgJyqwpG17k6RSnC6RbKlbPgzsRAWVpnIl7obSANv0lhkvnhIVVJoVpH4XWv0w7VJuDOxEKF4r2m27R1NVrvQv78b1l3ejXazmYe0i+PiVPbjnhmVcVNRCmIqhlle0VrTOeIKEVa5s2z2Kh0ZGJ0scJ1Qnt71j2qV1ZJ6xi8h5IvIjEdkrIgdEZJOJgRHVS9Fa0YZVtkSlUIr230KNYWLGf
hrASlU9KSIVAE+LyGOqusvAuYly417I46dRi3XCnjcqhcK2ugQYmLGr5aT9Y8X+V/8dsokS2LZ7FOsf3BtaQdKoxTpBz9vd2RGZGmJbXQIMXTwVkXYR2QPgKIAnVfVZE+clysumRw6gOhE8/2hk1UiW1rhsq0uAoYunqjoBYJmIdAJ4WETeq6rPuY8RkXUA1gFAT0+PiaclSs3d29yrO+VinaSLiYIkaSFg8rFUHqJqNmsiIncAeEtV/1vQMb29vTo8PGz0eYmSWDiwI/C+VwevSXw+b2UNYM2Ug3Lipt4EqLWIyIiq9kYdZ6IqpsueqUNEOgB8EMALWc9LlKfOjkqi26MkqUZx3gTcG1dv2Lq/4bXzVB4mUjEXAfiWiLTDeqPYoqqPGjgvUW42rn4P1n9vL6rnpj6xVtoEG1e/J9Zs2ntMnMqasCoc9kInkzIHdlXdB2C5gbEQ1U1QLhpA5GIlvwVNQZxqFL9UjRdLEskUrjylluXXCrdvcChyZ6G4rXHd1ShxHsOSRDKFgZ3IJWjWPDo2jkUDOyK7J7aL4JzqtBRO1GycJYlkEgM7kUtY4HYudAbtPgQA51RxyKeqJuy8acsriYKwuyORS5y2t2EFwkHplKCFQ/fcsAzPDKxkUCejOGOnluetcLn+8m7sfOEYjtjliEGSbC3HhUNUT8YXKMXBBUpUFFELi/oGh3xTKALgX13ZM/kGwEBN9RB3gRJn7NTSwhYW9S/vxvpVi3HrA3umzdwVwM4XjrHHORUSc+zU0qLa3PYv7w5Mx7DunIqKM3ZqaUHVKu6LoN0hx7DnCxURZ+zU0uK0uQ06ZsWSLvZ8oUJiYKeW1r+8G3euuSx0o+egY3a+cIzb0FEhMRVDLc9biugEZm9w96ZYbn1gj+/5mHunRuOMnVpe2ja63IaOioqBnVpekl7qbtyGjoqKqRhqeVElj0G4mpSKioGdWl6ckscgfrl3okZjYKfC8KsJB/KfEa9ftdi3rQBTKtSsGNipEPx2JVr/vb2AANUJnbzNu5uRCUypUNkwsFMh+F3AdO9H6jCxN2jQalEGcioLBnbKVdwl90lqv7PUift9MsjjUwBRI7HckXKTpD48Se13ljrxtKWNRM2EgZ1ykySI+tWEV9oElXapuS3rRc20pY1EzYSpGMpNkiAadAHT77YsKZMspY1EzYKBnXKTNIgGXcA0mftmaSO1gsypGBFZICI7ReR5ETkgIn9sYmDU/Iq45D5ON0eiZmdixn4WwOdV9ccicgGAERF5UlV/YuDc1MSS1IfXc8MKljZS2WUO7Kr6BoA37O9/KSLPA+gGwMBeAnEDblhteFQQzbsEkbscUasxmmMXkYUAlgN41ue+dQDWAUBPT0+2J9q3BXjqS8CJw8Cc+cClVwMvPTH181W3A0vXZnsOih1wswbmqA2l04zbCeSdsys4+fbZycVOrFunVmCs3FFE3gHgIQC3qOo/eu9X1c2q2quqvV1dXemfaN8W4JE/Ak68DkCtr8PfqP15683AXYusYym1uOWKaWvDv7htP/7phu/7XmAF0pUgemvn3zxVnbaClXXrVHZGAruIVGAF9e+o6lYT5wz01JeAaoz/4cePW28ATnDftwX4ynuBjZ3WVwb9SHHLFdPUhn9x237ct+s1TOj0tgGONCWIfm8yScdG1Owyp2JERAB8A8Dzqnp39iFFOHE4/rHVceuNALCCvPOGcOJ162eAKZsQccsVk5Q1OmmSoFm6I231TNyAzbp1KjMTM/Y+AP8awEoR2WP/+4iB8/qbMz/Z8ScO+8/y3UHfwVl9jRVLuiCe2/wCbtyyRneaJEyWEsQ4AbvRJZdEeTNRFfM0MO3///xcdXvt7DvKnPnBs3z37U7unrN6AFYQfmhkFO5EiQC4/vLpVS5xyxrjpEnaRfDMwMrJMSStZvFbgFRpF5w/cwZOjFdZFUMtoflWnjpB1lsVc+BhK6/uVumw3gie+pJ9cdXDPfsPm9UvXTu9Eqdkl
TfeIHrqzNlpQVgB7HzhmO/j45Q1xkmT3PSbCybHk6bShr3ViZoxsANWQPUG1WvvDg++3lm+E/QdYbP6oNn8a7umyiw75lr3jb/ZdIHfL4gGyXLRMSgXD1gz9Zt+cwH+c/9lALKVQHIBErW65gzsQfwCvnM7ED7jnjM/eFYfNJsfvhdwkhXuTwtOyeXWm62fO+YBH76rsIE+biUJkO2iY1CfFr98uskSSKJWU67AHiYo6Dv8cvfOrH7ruoAHBZfq1Rg/DvzNZ4Hd9wGvPg3oBCDtwOWftD5pNFjcYJn1omPcNMm23aMQ+L+6rGYhitY6gT1K2Kw+KEefxMQZ4ND/nfpZJ6yFVQDQc2VD8/dBKZLOjgrOnzXDaK46Tprky4+/6BvUBWA1C1EMoiELRPLS29urw8PDdX/e1Lw5dgAInFMmJUDlvOmpnsr5wIxZdcnZe3PsQHCKJO75sly8XDSwI/CVfXXwmsTjISoLERlR1d6o47iDUhxL1wIf/SowZwEAsb72fspK1WSm/qWb1bfsvH3+bRJMtrJNsh1ekKB0SzfTMESxcMaehbsKZ7Iq5nj4Y7KqdFhvMgW9ENs3OOSb1unu7JisT49i+hMEUVnEnbEzx55F0AXZfVuAx26bCvId84B3XlabY3fMPB8481b853TX1rs9+rmaKp1q+2z8qazDt05eUddabhN7irIWnSgbztjr6dHPASPfrK2K6bky2UpaAIAAG8dqz+tciHVx/2rfxAV4+fL/iH+++t+lHX0sJmbsROSPM/Yiuvbu4PJG9ww/irdfzsg3fQ8TV6OHefgl5v74TzDx49sgqjgqXXj9feuNB3ruKUrUeAzsReCkdLw5+7OnrYuobt4Vs4D1CSAGAdAOBQR4J45h3sgGnN7/XzCresLYylmmUYgaj6mYoovTo2bTvNjBPZG2mcC5M9b3BV89S9QKmIopi6gVs4CVq/fJsWfmBHXAShNtvdnqj1OA1bJEFIx17GVw7d14+V03YkKtC6bOv1wM39vyfeqJio4z9pL4vZ/fgNHTqyd/Xt32NO6Y8W3MaztpuFm++pdb+tb0N1+nS6Iy4Iy9JLx14tvP/RYuP7MZl7z918CarwFzFkABnEUbzikwhgswIZV0T+ZtcezdYHz8uGfV7Dpg4xzuSkVUJ5yxl0TovqNLrwGWroVg6hfeCdTOsmfOjr9QyltuGbnBuJ0XOvG61eXysds4myfKEQN7SaSqH/demPWmU06frL2ACviXWybZYHzizFS9fk3fegHaZwITp+3BswqHKC0G9gLI2g3RMWtG22Rgnzu7gjs++p7Q83ifd8WSX8fO01/FkbfHcfF5HVj/0cXob38mutwyaJOSRHQqqANTVTjOZiUAgz1RTAzsDZZ2b8+wcwDA29VziZ/3vl2vTd4/OY41fei/9bnwASTdYDyt8ePAtj+wvmdwJwrEi6cNFra3Z57niLMdXuxxeNsad8yz/gHWzyadq1qfIIgoEGfsDWaiG2Kac8Q9/+jYOPoGh6LTQ2GdLmvy9r+0gnMWUTn9OKt1iUrMyIxdRO4VkaMiEvGZnbyCNpVoE4m9OUXQOcL2B02yd2iazTImLV0L3Pqc1Y3ytkNA//+0Z/YZeKty3Lyll04lzl2LgI2dLLmklmCkV4yIvB/ASQDfVtX3Rh3PXjFT/PLjjribS0RtTOF3cRZA4PMGMd56193G2FsVE6StYr05BM3Av/Le6Au5bRVg1gW13TQLtLk4UZC4vWKMNQETkYUAHmVgT27b7lF8fsteTPj8LuIG06DKmrCgD8BTFdOFnS8c862HB6xs+aF67Tnq3awEiFcVs7ET2feiFaAyG6ieYiqHCoWBvckEbeCcNZim2fiiqTfLiDNjT6qtYn2acFoos+ySGqRwm1mLyDoRGRaR4WPHjtXraZtGmjx5lG27RwNn32EXT9evWoyOSnvNbU2zWcZVtxvaZNzlXLW2L/74ceDhzzBvT4VVt8CuqptVtVdVe7u6u
ur1tE3DdDB1UjBBwt4w+pd34841l6G7swMCa6beNBtJ+5VetqXsiRNGJzz9cG4G/lMXgz0VAssdC8L0zkNhdepx3jD6l3c3RyD3E9Uq4cxJq7WBad52CY/80dR4iOrIVFXMdwF8AMCFAH4O4A5VDdz5gTn2/AXl7AHgnhuWNW/QNmEy0BvOxfuZs8Aq96x5XtbXUzp13UFJVW8ycR4yJ6jbY3dnR2sHdcB/Rv/ILa48uqsqpmMu8PaJ9FsPOoupnPp6p+2CuwEaSy3JMKZiSipVt8dWFbX9oLf0snL+9E3GgziLqcJaG+uEtbXhPxwEjr/CzUooM25mXWKmukaSj0c/F73PbKXDupC7dK2h+noXlly2JG5mTc19AbTorr0b6LmydiY/83ygfZb/LNtIa2MXdrqkEJyxE9WDN8duivvirMPdqoH5+1LhjL3JMY1SMs6s2nQ1jrfTpTdF5OTvh7/B9E0L4Yy9gKKaelEJeGfVC38LOPyj5DN674x907yYFTyCyZw/A37T4Iy9iYVtnBEU2DnDbzLX3j09PZJ0IVVbZfr+s7HLMl0TOmcbwtd2MWVTEgzsBZR04wwT2+tRAfiVXda0NnYJmmVLe/qa++F7rQvCnLk3PQb2AgpaXBTU3yXNDJ+ahN/MPszln4wuwwyk1icGb2D3rpi99GrgpSe4grbAuOdpASVtCGZiez0qiWvvBi5ckv7x3ouxfjtSDX+j9uetN1vNz9j0rDA4Y6+juHnwpA3Bks7wqeT+8Fm7OuZeJF4U5d12MGzFrNv4cesN4LVdnM0XAKtiMkhywTJtpUuc52AVDUXy25HKy71S1pF4xayr2iYIq3BSq/sOSkmUIbAnDaZpdiVK8hysiqHE4nSbzGNHKi8G+tgY2HOWNFCn2fquqbeoo3LIa8VskDkLmL4JwTr2FJLMepNesEyTB+dFUWq4mhWzdn39+BiAcz4Hx0jDRHE2KHltF7DvfuCMp4smA38srIqxOWmP0bFxKKZqwbftHvU9PukepSuW+G8HGHR7mucgysXStdbq1o1jwG2HgDV/ZW89CKtuHrB+7v2Umf1mq+NW5Y03qANTgZ8VOKE4Y7clrQVP2u985wv+G3gH3Z7mOYjqIqx/fc+V+e9OVR23ngOYei5nYRZn9AAY2CclTXskLUlMk1YxvQ8qUe7cQT9OJU5azszdyf07q23dO1M5WjDYM7Db0uTAk/Q7T1trzp7q1LSCZvaRAT9Grl7a41/QbcGNxZljtyVd7Rlm2+5R9A0OYdHADvQNDmHb7lGj5ydqakvXWrn6jSesf2u+ZufsZSpX39Ye/PhKR/J+OO70TQtguaOLiVrwsNpzgGkVolj2bQEevSW4KiZVHl+sC8BNjHXshsUN+qw9J6qDNPX1frtNRT1HwZqfsY7dI8tsPElb3KCLoX7BnohS8u5IFdWuuNIxvXd9GO8bh9P8zFHwvH1L5NiT1qh7hZVCup+jb3Ao8JKP2McQkSGT9fUngDuOW197Pz1VW++Ys2B6D5wocZqfefP2+7ZYLRg2dlpfG1hrb2TGLiIfAvA/ALQD+LqqDpo4rylZ+5VHlSr65dW91B4Hc+pEOUravz6It31x1HF+M/wGdrvMPGMXkXYAfwHgwwDeDeAmEXl31vOalHVpftQKUL83jizPR0QN5m1fHHWc3wy/Om61Tnb3rq/TqlkTqZgrABxU1VdU9QyA+wFcZ+C8xmRdmh9Vqpj1DYKICuaq26PbI7jz9oEzfE9ytk5llyYCezcAd93RYfu2GiKyTkSGRWT42LHgZfR5yFpD3r+8G3euuQzdnR0QWBUu7ta5cQI2a9aJmsjStVZevqa+/tO1P7vz9nFn+ED8NE8GJnLs4nPbtGuIqroZwGbAKnc08LyxmViaH7YC1K+nS6VN8I7zZmDsVJU160TNKKwnjtdVt/uUXwasoE3yJpCSicB+GMAC18/zARwxcF6j8lyaz54uRC3O297YqXvf+9e1wT5p2WVKm
RcoicgMAD8FcBWAUQB/D+BfquqBoMc04wIlIqLE4uxSlUDdFiip6lkR+UMAj8Mqd7w3LKgTEbWMJOkcg4zUsavq9wF838S5iIgom5ZYeUpE1EoY2ImISoaBnYioZBjYiYhKphRte01skEFEVBZNP2P3a8l7ywN7sGzTE2yTS0QtqekDe1BnxbHxaqKe60REZdH0gT2ss6J3MwwiolbQ9IE9qrMie6ATUatp+sDu15LXjT3QiajVNH1VjFP9sumRA3jzVLXmPvZAJ6JW1PQzdsAK7rtvvxr33LAscDMMIqJW0fQzdrc8e64TETWLUszYiYhoCgM7EVHJMLATEZVMU+bY2RuGiChY0wV2pzeM00ZgdGwcG7buBwAGdyIiNGEqxq83DFsHEBFNabrAHtQigK0DiIgsTRfYg1oEsHUAEZGl6QK7X28Ytg4gIprSdBdPnQukrIohIvKXKbCLyO8C2Ajg1wFcoarDJgYVha0DiIiCZU3FPAdgDYAfGhgLEREZkGnGrqrPA4CImBkNERFl1nQXT4mIKFzkjF1E/hbAO33u+oKq/k3cJxKRdQDWAUBPT0/sARIRUTKRgV1VP2jiiVR1M4DNANDb26smzklERNMxFUNEVDKimn7yLCK/A+DPAHQBGAOwR1VXxXjcMQA/C7j7QgC/SD2ofHFsyRV1XADHllZRx1bUcQHmxvYuVe2KOihTYM+DiAyram+jx+GHY0uuqOMCOLa0ijq2oo4LqP/YmIohIioZBnYiopIpYmDf3OgBhODYkivquACOLa2ijq2o4wLqPLbC5diJiCibIs7YiYgog4YEdhH5XRE5ICLnRCTwSrGIfEhEXhSRgyIy4Lp9kYg8KyIvicgDIjLT4NjmiciT9rmfFJG5PsesEJE9rn9vi0i/fd83ReSQ675l9RqXfdyE67m3u25v9Gu2TET+zv697xORG1z3GX/Ngv52XPfPsl+Hg/brstB13wb79hdFJLJ81/C4PiciP7Ffo6dE5F2u+3x/t3Uc2ydF5JhrDP/Wdd8n7N//SyLyiQaM7Suucf1URMZc9+X2uonIvSJyVESeC7hfROSr9rj3icj7XPfl95qpat3/wWrzuxjADwD0BhzTDuBlAJcAmAlgL4B32/dtAXCj/f1fAvh9g2P7rwAG7O8HANwVcfw8AMcBzLZ//iaAj+XwmsUaF4CTAbc39DUD8M8AXGp/fzGANwB05vGahf3tuI75AwB/aX9/I4AH7O/fbR8/C8Ai+zztdRzXCtff0u874wr73dZxbJ8E8Oc+j50H4BX761z7+7n1HJvn+H8P4N46vW7vB/A+AM8F3P8RAI8BEABXAni2Hq9ZQ2bsqvq8qkbtPn0FgIOq+oqqngFwP4DrREQArATwoH3ctwD0GxzedfY54577YwAeU9VTBsfgJ+m4JhXhNVPVn6rqS/b3RwAchbWwLQ++fzshY34QwFX263QdgPtV9bSqHgJw0D5fXcalqjtdf0u7AMw39NyZxxZiFYAnVfW4qr4J4EkAH2rg2G4C8F2Dzx9IVX8Ia2IX5DoA31bLLgCdInIRcn7Nipxj7wbwuuvnw/ZtvwpgTFXPem4+8ptdAAADf0lEQVQ35ddU9Q0AsL/+k4jjb8T0P6I/tT92fUVEZtV5XOeJyLCI7HLSQyjYayYiV8Caeb3sutnkaxb0t+N7jP26nID1OsV5bJ7jcvs0rNmew+93a0rcsV1v/54eFJEFCR+b99hgp64WARhy3Zzn6xYlaOy5vma5bY0n2btC+jV515DbjYwt4XkuAnAZgMddN28A8P9gBa7NAG4D8KU6jqtHVY+IyCUAhkRkP4B/9Dmuka/Z/wbwCVU9Z9+c+jULehqf27z/vbn9fYWIfW4R+TiAXgC/7bp52u9WVV/2e3xOY3sEwHdV9bSIfAbWJ56VMR+b99gcNwJ4UFUnXLfl+bpFacTfWX6BXbN3hTwMYIHr5/kAjsDqt9ApIjPsmZZzu5GxicjPReQiVX3DD
kJHQ061FsDDqlp1nfsN+9vTIvK/APyHeo7LTnNAVV8RkR8AWA7gIRTgNRORXwGwA8AX7Y+lzrlTv2YBgv52/I45LCIzAMyB9ZE6zmPzHBdE5IOw3jB/W1VPO7cH/G5NBajIsanqP7h+/BqAu1yP/YDnsT8wNK5YY3O5EcBn3Tfk/LpFCRp7rq9ZkVMxfw/gUrGqOWbC+oVtV+vKw05YuW0A+ASA2H3hY9hunzPOuafl8uzA5uS1+2FtH1iXcYnIXCeNISIXAugD8JMivGb27/BhWPnG73nuM/2a+f7thIz5YwCG7NdpO4AbxaqaWQTgUgA/yjie2OMSkeUA/grAalU96rrd93draFxxx3aR68fVAJ63v38cwNX2GOcCuBq1n2JzH5s9vsWwLkT+neu2vF+3KNsB/J5dHXMlgBP2RCbf1yyvq8Vh/wD8Dqx3rNMAfg7gcfv2iwF833XcRwD8FNa76xdct18C63+2gwC+B2CWwbH9KoCnALxkf51n394L4Ouu4xYCGAXQ5nn8EID9sILTfQDeUa9xAfgX9nPvtb9+uiivGYCPA6gC2OP6tyyv18zvbwdWeme1/f159utw0H5dLnE99gv2414E8GHDf/tR4/pb+/8J5zXaHvW7rePY7gRwwB7DTgBLXI/9lP1aHgTwb+o9NvvnjQAGPY/L9XWDNbF7w/7bPgzrushnAHzGvl8A/IU97v1wVQHm+Zpx5SkRUckUORVDREQpMLATEZUMAzsRUckwsBMRlQwDOxFRyTCwExGVDAM7EVHJMLATEZXM/wcJVynJJEYYNwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(x[:,0],y)\n", "plt.scatter(x[:,0],y_hat);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So far we have specified the *model* (linear regression) and the *evaluation criterion* (or *loss function*). Now we need to handle *optimization*; that is, how do we find the best values for `a`? In other words, how do we find the best-*fitting* linear regression?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Gradient Descent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We would like to find the values of `a` that minimize `mse`.\n", "\n", "**Gradient descent** is an algorithm that minimizes functions. Given a function defined by a set of parameters, gradient descent starts with an initial set of parameter values and iteratively moves toward a set of parameter values that minimize the function. Each step moves the parameters in the direction of the negative gradient, scaled by a learning rate $\\eta$ (`lr` in the code): $a \\leftarrow a - \\eta \\nabla_a \\text{loss}$.\n", "\n", "Here is gradient descent implemented in [PyTorch](http://pytorch.org/)."
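, "\n", "\n", "To see what `backward` computes in the cells below, note that for this model the gradient of the MSE has a closed form: `2/n * x.t() @ (x@a - y)`. It comes from differentiating the mean of the squared residuals with respect to `a`. The minimal sketch below uses plain PyTorch, with an illustrative seed. It first checks autograd's `a.grad` against that formula on one step. It then runs the same full-batch gradient descent loop as `update()`.

```python
import torch

torch.manual_seed(0)                          # illustrative seed
n, lr = 100, 1e-1
x = torch.ones(n, 2)
x[:, 0].uniform_(-1., 1)
y = x @ torch.tensor([3., 2.]) + torch.rand(n)

def mse(y_hat, y): return ((y_hat - y) ** 2).mean()

a = torch.nn.Parameter(torch.tensor([-1., 1.]))

# one gradient evaluation: autograd vs. the hand-derived formula
loss = mse(x @ a, y)
loss.backward()                               # fills a.grad
with torch.no_grad():
    manual_grad = 2 / n * x.t() @ (x @ a - y) # d/da of mean((x@a - y)**2)
first_grad = a.grad.clone()
print(first_grad, manual_grad)

# full-batch gradient descent: step against the gradient, 100 times
for t in range(100):
    a.grad.zero_()                            # clear the previous gradient
    loss = mse(x @ a, y)
    loss.backward()
    with torch.no_grad():
        a.sub_(lr * a.grad)
print(a.data)
```

The fitted slope should come out close to 3, and the intercept close to 2.5 rather than 2, because `torch.rand` adds noise with mean 0.5."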
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Parameter containing:\n", "tensor([-1., 1.], requires_grad=True)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = nn.Parameter(a); a" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "def update():\n", "    y_hat = x@a                   # prediction\n", "    loss = mse(y_hat, y)          # MSE (argument order matches mse's signature)\n", "    if t % 10 == 0: print(loss)   # printout; t is the global loop counter\n", "    loss.backward()               # calculate the gradient, stored in a.grad\n", "    with torch.no_grad():         # don't track gradients while updating the parameters\n", "        a.sub_(lr * a.grad)       # subtract learning rate x gradient from a, in place\n", "        a.grad.zero_()            # zero out the gradient so it doesn't accumulate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch keeps track of how our loss (`mse`) was calculated, which lets it calculate the derivative for us.\n", "> So if you do a mathematical operation on a tensor in PyTorch, you can call `backward` to calculate the derivative.\n", "\n", "This works here because `a` is an `nn.Parameter`, which has `requires_grad=True` (see its printout above). `backward` then stores the derivative of the loss with respect to `a` in `a.grad`." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(6.9835, grad_fn=)\n", "tensor(1.4813, grad_fn=)\n", "tensor(0.5689, grad_fn=)\n", "tensor(0.2572, grad_fn=)\n", "tensor(0.1468, grad_fn=)\n", "tensor(0.1077, grad_fn=)\n", "tensor(0.0939, grad_fn=)\n", "tensor(0.0889, grad_fn=)\n", "tensor(0.0872, grad_fn=)\n", "tensor(0.0866, grad_fn=)\n" ] } ], "source": [ "lr = 1e-1\n", "for t in range(100): update()" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "[figure]" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(x[:,0],y)\n", "plt.scatter(x[:,0],x@a.detach());" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Animate it!" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "from matplotlib import animation, rc\n", "rc('animation', html='html5')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You may need to uncomment the following to install the necessary plugin (ffmpeg) the first time you run this.\n", "\n", "After running the commands, make sure to restart the kernel for this notebook.\n", "\n", "If you are running in Colab, the installs are not needed; just change the cell above to use `html='jshtml'` instead of `html='html5'`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#! sudo add-apt-repository -y ppa:mc3man/trusty-media \n", "#! sudo apt-get update -y \n", "#! sudo apt-get install -y ffmpeg \n", "#! sudo apt-get install -y frei0r-plugins " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Instead of writing a loop, we call matplotlib's `animation.FuncAnimation` to run `animate` once for each of the 100 frames (`interval=20` is the delay between frames in milliseconds).\n", "\n", "Our `animate` function just calls the `update` we wrote above, updates the line's y-data with `line.set_ydata`, and returns the line." 
] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = nn.Parameter(tensor(-1.,1))\n", "\n", "fig = plt.figure()\n", "plt.scatter(x[:,0], y, c='orange')\n", "line, = plt.plot(x[:,0], x@a.detach())\n", "plt.close()\n", "\n", "def animate(i):\n", "    update()\n", "    line.set_ydata(x@a.detach())\n", "    return line,\n", "\n", "animation.FuncAnimation(fig, animate, np.arange(0, 100), interval=20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's essentially SGD visualized. Strictly speaking, each `update` above uses all 100 points, which makes it (full-batch) gradient descent; the only difference from SGD is mini-batches.\n", "\n", "___\n", "\n", "In practice, we don't calculate the loss on the whole dataset at once. Instead we use *mini-batches*, small random subsets of the data, and update the parameters after each one." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Vocab" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Learning rate\n", "- Epoch\n", "- Minibatch\n", "- SGD\n", "- Model / Architecture\n", "- Parameters\n", "- Loss function\n", "\n", "For classification problems, we use *cross entropy loss*, also known as *negative log likelihood loss*. It penalizes confident incorrect predictions heavily, and also penalizes correct predictions made with low confidence."
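, "\n", "\n", "The mini-batch idea described above can be sketched by changing only where the loss is computed. This is a minimal sketch in plain PyTorch. It assumes a batch size of 10, 10 epochs and a fresh random shuffle each epoch, all illustrative choices rather than values from the lecture. One epoch is one pass over all 100 points, so the loop makes the same number of updates (100) as the full-batch loop.

```python
import torch

torch.manual_seed(0)                          # illustrative seed
n, bs, lr = 100, 10, 1e-1                     # bs: mini-batch size (assumed)
x = torch.ones(n, 2)
x[:, 0].uniform_(-1., 1)
y = x @ torch.tensor([3., 2.]) + torch.rand(n)

def mse(y_hat, y): return ((y_hat - y) ** 2).mean()

a = torch.nn.Parameter(torch.tensor([-1., 1.]))
for epoch in range(10):                       # one epoch = one pass over all n points
    perm = torch.randperm(n)                  # reshuffle the data every epoch
    for i in range(0, n, bs):
        idx = perm[i:i + bs]                  # indices of this mini-batch
        loss = mse(x[idx] @ a, y[idx])        # loss on the mini-batch only
        loss.backward()
        with torch.no_grad():
            a.sub_(lr * a.grad)               # one update per mini-batch
            a.grad.zero_()
print(a.data)
```

Each step now uses only 10 points, so the updates are noisier. The parameters should still end up close to the full-batch solution: a slope near 3 and an intercept near 2.5."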
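, "\n", "\n", "To make the cross entropy description concrete, here is a minimal sketch for a single example. It assumes two classes with class 0 correct, and the probabilities are made up for illustration. The loss is minus the log of the probability assigned to the correct class. So it is small for a confident correct prediction, larger for an unconfident one, and large for a confident wrong one. The last lines check that PyTorch's `F.cross_entropy`, which takes raw scores (logits), matches `F.nll_loss` applied to `F.log_softmax`.

```python
import torch
import torch.nn.functional as F

def ce(probs, target):
    # cross entropy for one example: minus the log-probability of the correct class
    return -torch.log(probs[target])

target = 0                                     # index of the correct class
confident_right = torch.tensor([0.95, 0.05])
unconfident_right = torch.tensor([0.55, 0.45])
confident_wrong = torch.tensor([0.05, 0.95])
for name, p in [('confident right', confident_right),
                ('unconfident right', unconfident_right),
                ('confident wrong', confident_wrong)]:
    print(name, ce(p, target).item())

# F.cross_entropy = log_softmax followed by the negative log likelihood loss
logits = torch.tensor([[2.0, -1.0, 0.5]])      # one example, three classes
labels = torch.tensor([0])
builtin = F.cross_entropy(logits, labels)
by_hand = F.nll_loss(F.log_softmax(logits, dim=1), labels)
print(builtin.item(), by_hand.item())
```

For these made-up probabilities the losses are roughly 0.05, 0.60 and 3.0. That is the pattern described above: confident mistakes are punished most."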
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (FastAI)", "language": "python", "name": "fastai" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 1 }