{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "**Important: This notebook will only work with fastai-0.7.x. Do not try to run any fastai-1.x code from this path in the repository because it will load fastai-0.7.x**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Table of Contents\n", "

1  Linear Regression problem
2  Gradient Descent
3  Gradient Descent - Classification
4  Gradient descent with numpy
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "from fastai.learner import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this part of the lecture we explain Stochastic Gradient Descent (SGD) which is an **optimization** method commonly used in neural networks. We will illustrate the concepts with concrete examples." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Linear Regression problem" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The goal of linear regression is to fit a line to a set of points." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Here we generate some fake data\n", "def lin(a,b,x): return a*x+b\n", "\n", "def gen_fake_data(n, a, b):\n", " x = s = np.random.uniform(0,1,n) \n", " y = lin(a,b,x) + 0.1 * np.random.normal(0,3,n)\n", " return x, y\n", "\n", "x, y = gen_fake_data(50, 3., 8.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEKCAYAAAAB0GKPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAF5dJREFUeJzt3X2QZXV95/H3x4EoikmAaRTRdiQhW1KkkrgNUeO6JBgl\nlAUVVyNmLcFSp3RFs2Y3G63dCpbZZDUPlexqEndUCt2NBKNRp9SNUj6EuAnWNKhxkHXBcRxGUFrH\nkJoSdQa++8e9uE1ze+b00Oece+59v6q6+j78bvf30MP93N/TOakqJEk6mof0XYAkaRgMDElSIwaG\nJKkRA0OS1IiBIUlqxMCQJDViYEiSGjEwJEmNGBiSpEaO67uAzbR169batm1b32VI0mDccMMN36yq\nhSZtZyowtm3bxvLyct9lSNJgJPlq07YOSUmSGmktMJJcmeTOJLtXPfa8JDcluTfJ0hFeuzfJF5J8\nLoldBkmaAm32MK4CLljz2G7gOcB1DV7/81X101W1brBIkrrT2hxGVV2XZNuax24GSNLWr5UktWRa\n5zAK+FiSG5Js77sYSdL0rpL6uaq6PcmpwLVJ/k9VTRzGGgfKdoDFxcUua5SkuTKVPYyqun38/U7g\n/cC5R2i7o6qWqmppYaHRUmJJ0jGYusBI8ogkj7zvNvBMRpPlkjRX9qwc5Jpd+9izcrDvUoAWh6SS\nXA2cB2xNsh+4AjgAvBlYAD6c5HNV9awkjwHeXlUXAo8C3j+eGD8OeHdV/XVbdUrSNNqzcpBnv/nT\nVEECH3rV0zhj4cRea2pzldQL1nnq/RPa3g5cOL69B/iptuqSpCHYtfcAVXD3oXs44fgt7Np7oPfA\nmLohKUkSnLPtZBI44fgtJKP7fZvWVVKSNNfOWDiRD73qaezae4Bztp3ce+8CDAxJmlpnLJw4FUFx\nH4ekJEmNGBiSpEYMDElSIwaGJKkRA0OS1IiBIUlqxMCQJDViYEiSGjEwJEmNGBiSpEYMDElSIwaG\nJKkRA0OS1IiBIUlqxMCQJDViYEiSGmktMJJcmeTOJLtXPfa8JDcluTfJ0hFee0GSLyW5Nclr26pR\nktRcmz2Mq4AL1jy2G3gOcN16L0qyBfgT4JeAs4AXJDmrpRolSQ21FhhVdR1wYM1jN1fVl47y0nOB\nW6tqT1V9H/gL4OKWypSkQdmzcpBrdu1jz8rBzn/3NF7T+3TgtlX39wM/u17jJNuB7QCLi4vtViZJ\nPdqzcpBnv/nTVEECH3rV0zq95vc0TnpnwmO1XuOq2lFVS1W1tLCw0GJZktSvXXsPUAV3H7qHqtH9\nLk1jYOwHHrfq/mOB23uqRZKmxjnbTiaBE47fQjK636VpHJLaBZyZ5AnA14BLgF/ttyRJ6t8ZCyfy\noVc9jV17D3DOtpM7HY6CFgMjydXAecDWJPuBKxhNgr8ZWAA+nORzVfWsJI8B3l5VF1bV4SSXAx8F\ntgBXVtVNbdUpScdqz8rBzt+8z1g4sfOguE+q1p0eGJylpaVaXl7uuwxJc6DvCejNkuSGqlp3X9xq\n0ziHIUlT70gT0H0ufW3TNM5hSNLUW28CelZ6HpMYGJLm0oOdf1hvAnp1z+OE47ewa+8BA0OShmqz\negGTJqD7XvraJgND0txpsxfQ99LXNhkYkuZO272APpe+tsnAkDR3ZrkX0CYDQ9JcmtVeQJvchyFJ\nA9X1fg97GJI0QH3s97CHIUkD1Mepzg0MSRqgPvZ7OCQlSQPUx0ovA0OSBqrrlV4OSUmSGjEwJEmN\nGBiSpEYMDEkzb1YvaNQ1J70lzbRZvqBR11rrYSS5MsmdSXaveuzkJNcmuWX8/aR1XntPks+Nv3a2\nVaOk2dfHBrdZ1eaQ1FXABWseey3w8ao6E/j4+P4kd1fVT4+/LmqxRkkzbpYvaNS11oakquq6JNvW\nPHwxcN749juBTwG/2VYNkuSpzDdP13MYj6qqOwCq6o4kp67T7mFJloHDwBur6gPr/cAk24HtAIuL\ni5tdr6QZ4KnMN8e0rpJarKol4FeBP07yY+s1rKodVbVUVUsLCwvdVShJc6brwPhGktMAxt/vnNSo\nqm4ff9/DaNjqZ7oqUJI0WdeBsRO4dHz7UuCDaxskOSnJQ8e3twI/B3yxswolSRO1uaz2auDvgX+W\nZH+SlwBvBH4xyS3AL47vk2QpydvHL30isJzk88AnGc1hGBiS1LM2V0m9YJ2nzp/Qdhl46fj23wE/\n2VZdkqRjM62T3pKkKWNgSJIaMTAkSY0YGJKkRgwMSVIjBoYkqREDQ9LM8sJJm8sLKEmaSV44afPZ\nw5DUmj4/4XvhpM1nD0NSK/r+hL/RCyftWTnoNTOOwsCQ1IrVn/BPOH4Lu/Ye6PSNeCMXTuo73IbC\nwJDUimm4NGrTCyf1HW5DYWBIasWQLo06DeE2BAaGpNYM5dKoQwq3PhkYksRwwq1PLquVJDViYEiS\nGjEwJEmNtBoYSa5McmeS3aseOznJtUluGX8/aZ3XXjpuc0uSS9usU1I7PJfTbGm7h3EVcMGax14L\nfLyqzgQ+Pr5/P0lOBq4AfhY4F7hivWCRNJ3u2wz3+p1f5Nlv/vQDQsMwGZ5WV0lV1XVJtq15+GLg\nvPHtdwKfAn5zTZtnAddW1QGAJNcyCp6rWypV0iZbbzPcnpWDfOQLd/CWT95KiDurB6SPZbWPqqo7\nAKrqjiSnTmhzOnDbqvv7x49JGohJm+Hu63UcuudeDt1TAO6sHpBp3YeRCY/VxIbJdmA7wOLiYps1\nSdqASZvhrtm1jyp+EBbHb4k7qwekj8D4RpLTxr2L04A7J7TZz/8ftgJ4LKOhqweoqh3ADoClpaWJ\noSKpH2s3w63udRTF5T//41z4k6fZuxiIPgJjJ3Ap8Mbx9w9OaPNR4HdXTXQ/E3hdN+VJm89TZ494\nCo5hazUwklzNqKewNcl+Riuf3gi8J8lLgH3A88Ztl4CXV9VLq+pAkt8Gdo1/1BvumwCXhsZTZ9+f\np+AYrrZXSb1gnafOn9B2GXjpqvtXAle2VJrUmSGeOtsekSaZ1klvaWYM7dTZ9oi0HgNDatnQxu2H\n2CNSNwwMqQOrx+2nfbhnaD0idcfAkDo0hOGeofWI1B0DQ+rQUIZ7XMmkSTy9udShroZ7PLGf2mAP\nQ+pQF8M9Qxj20jAZGFLH2h7uGcqwl4bHISlpBqwegmpz2KvNoS6H0aafPQxpCm1k6e2kIag2hr3a\nHOpyGG0YDAxpymz0zXPSENTzz1nc9DfcNoe6HEYbBoekpA1qe+hk9Ztn1ej+kXS18qrN3+NmwWE4\nag8jyeXAn1fVtzuoR5pqXQydbPTNs6uNdm3+HjcLDkOTIalHA7uS3Mjo7LEfrSovVKS51MXQybG8\neXa10a7N3+Nmwel31CGpqvpPwJnAO4DLgFuS/G6SH2u5NmnqdDV0csbCia3MQ0gPRqNJ76qqJF8H\nvg4cBk4C3pvk2qr6D20WKE0Th040z5rMYbya0aVUvwm8HfiNqjqU5CHALYCBobni0InmVZMexlbg\nOVX11dUPVtW9SZ7dTllSfzZ6+vH12k/7acyljTpqYFTVbx3huZs3txypXxtdBbVeezeiaRb1sg8j\nya8l2Z3kpiT/dsLz5yW5K8nnxl/rhpa0mTa6B2K99hv9OdIQdL7TO8nZwMuAc4HvA3+d5MNVdcua\npn9bVQ55qVMbXQW1XvtH//DDuOfee3nocQ9xI5pmRh+nBnkicH1VfQcgyd8Avwz8Xg+1SPez0VVQ\nk9rvWTnIK/78RpJQBX/2r5/kcJRmQh+BsRv4nSSnAHcDFwLLE9o9JcnngduBf19VN3VYo+bYRq+/\nvXbV1H3DUd87fC8nHL+Fr//TdzupW2pb54FRVTcneRNwLXAQ+DyjvR2r3Qg8vqoOJrkQ+ACjzYMP\nkGQ7sB1gcXGxtbo1f4514trzImlW9TLpXVXvqKonVdXTgQOM9nOsfv6fqurg+PZHgOOTbF3nZ+2o\nqqWqWlpYWGi9ds2PY524vm+Y6vUXneXqKM2UXk5vnuTUqrozySLwHOApa55/NPCN8Q7zcxkF27d6\nKFVz7MH0FNzcp1nU1/Uw3jeewzgEvLKqvp3k5QBV9VbgucArkhxmNM9xiSc81Fptb4zzNCDS/WWW\n3oeXlpZqeXnS/LlmjRvjpM2R5IaqWmrS1gsoaZDcGCd1z8DQILkSSeqe1/TWIDm/IHXPwNBguRJJ\n6pZDUpKkRgwMSVIjBoYGY8/KQa7ZtY89Kwf7LkWaS85haBDcdyH1zx6GBsF9F1L/DAwNgvsupP45\nJKVBcN/FZG2fT0tazcDQYLjv4v6c11HXHJKSWtTmyi7nddQ1exganKEMw7TdA3BeR10zMDQoQxqG\nWd0DOOH4Lezae2BTa3VeR10zMDQobb8Jb6YuegDO66hLBoYGZUjDMPYANGsMDA3K0N6E7QFolhgY\nGpy23oSHMpku9aWXwEjya8DLgABvq6o/XvN8gP8KXAh8B7isqm7svFDNjSFNpkt96XwfRpKzGYXF\nucBPAc9OcuaaZr8EnDn+2g78WadFau64p0E6uj427j0RuL6qvlNVh4G/AX55TZuLgXfVyPXAjyY5\nretCNT+GNJku9aWPIandwO8kOQW4m9Gw0/KaNqcDt626v3/82B1rf1iS7Yx6ISwuLrZRr+bA0CbT\npT50HhhVdXOSNwHXAgeBzwOH1zTLpJeu8/N2ADsAlpaWJraRmnBFk3RkvZxLqqreUVVPqqqnAweA\nW9Y02Q88btX9xwK3d1WfJOmBegmMJKeOvy8CzwGuXtNkJ/CijDwZuKuqHjAcJUnqTl/7MN43nsM4\nBLyyqr6d5OUAVfVW4COM5jZuZbSs9sU91SlJGuslMKrqX0x47K2rbhfwyk6LUiNubpPmlzu91Zib\n26T55gWU1Jib26T5ZmCoMTe3SfPNISkd0do5Cze3SfPLwNC6E9nrzVkYFNJ8MjDm3JEmsod0dTtJ\n7XMOY84daSLbOQtJq9nDmHNHCgXnLCStZmDMuaOFgnMWku5jYMhQkNSIcxiSpEYMDElSIwaGJKkR\nA0OS1IiBMUf2rBzkml372LNysO9SJA2Qq6TmxKQd3YB7LCQ1ZmAMwGZctGjtaT4+8oU7+NNPfdlr\nW0hqzMCYcpt10aK1O7qBTs8T5ZX6pOEzMKbcZp0AcO2OboA//dSXOzlPlFfqk2ZDL4GR5DXAS4EC\nvgC8uKq+u+r5y4DfB742fugtVfX2rut8MDbrE/VmngBw7Y7urs4T5VlvpdnQeWAkOR14NXBWVd2d\n5D3AJcBVa5peU1WXd13fZtjMT9RtngCwq1OCeNZbaTb0NSR1HHBCkkPAw4Hbe6qjFZv9iXro53ry\nrLfSbOg8MKrqa0n+ANgH3A18rKo+NqHpv0rydOD/Aq+pqtsm/bwk24HtAIuLiy1VvTF+on6goYee\nJEhVdfsLk5OA9wHPB/4R+EvgvVX1P1e1OQU4WFXfS/Jy4Feq6heO9rOXlpZqeXm5pco35kiXPfWT\ntqRpkeSGqlpq0raPIalnAF+pqhWAJH8FPBX4QWBU1bdWtX8b8KZOK9wEkz5Rb/ZqIcNHUpf6CIx9\nwJOTPJzRkNT5wP26BUlOq6o7xncvAm7utsR2bObcRtdLVQ0nSX3MYXwmyXuBG4HDwGeBHUneACxX\n1U7g1UkuGj9/ALis6zrb8GDmNta+YXe5VNV9FJKgp1VSVXUFcMWah39r1fOvA17XaVEdONbVQpPe\nsLucWHcfhSRwp3fnjmW10KQ37Oefs9jZUlVXfUkCA2MQ1nvD7mqpqvsoJIGBMQjT8IbtPgpJBsZA\n+IYtqW9ecU+S1IiBIUlqxMCQJDViYEiSGjEwJEmNGBhTZM/KQa7ZtY89Kwf7LkWSHsBltVPC8zVJ\nmnb2MKbE6tN/VI3uS9I0MTCmhOdrkjTtHJKaEtNw+g9JOhID4xi0dTEhT/8haZoZGBvk5LSkeeUc\nxljTJa1OTkuaV/Yw2Fiv4WiT0177WtKs6iUwkrwGeClQwBeAF1fVd1c9/1DgXcA/B74FPL+q9rZV\nz0YuQXqkyWmHqyTNss6HpJKcDrwaWKqqs4EtwCVrmr0E+HZV/TjwR8Cb2qxpo0taz1g4keefs/iA\nMHC4StIs62tI6jjghCSHgIcDt695/mLg9ePb7wXekiRVVW0Us1lLWt1LIWmWdR4YVfW1JH8A7APu\nBj5WVR9b0+x04LZx+8NJ7gJOAb7ZVl2bsaTVvRSSZlkfQ1InMepBPAF4DPCIJC9c22zCSyf2LpJs\nT7KcZHllZWVziz0G6w1XSdLQ9bGs9hnAV6pqpaoOAX8FPHVNm/3A4wCSHAf8CDBxQqCqdlTVUlUt\nLSwstFi2JM23PgJjH/DkJA9PEuB84OY1bXYCl45vPxf4RFvzF5KkZjoPjKr6DKOJ7BsZLal9CLAj\nyRuSXDRu9g7glCS3Ar8OvLbrOiVJ95dZ+uC+tLRUy8vLfZchSYOR5IaqWmrS1lODrOFV7yRpMk8N\nsoo7tSVpffYwVnGntiStz8BYxZ3akrQ+h6RWcae2JK3PwFjDq95J0mQOSUmSGjEwJEmNGBiSpEYM\nDElSIwaGJKkRA0OS1MhMnXwwyQrw1Q28ZCstXsVvSs3jMYPHPU/m8Zjh2I/78VXV6GJCMxUYG5Vk\nuelZGmfFPB4zeNx919GleTxm6Oa4HZKSJDViYEiSGpn3wNjRdwE9mMdjBo97nszjMUMHxz3XcxiS\npObmvYchSWpoLgIjyQVJvpTk1iSvnfD8Q5NcM37+M0m2dV/l5mpwzL+e5ItJ/iHJx5M8vo86N9vR\njntVu+cmqSSDX03T5JiT/Mr4731Tknd3XWMbGvwbX0zyySSfHf87v7CPOjdTkiuT3Jlk9zrPJ8l/\nG/83+YckT9rUAqpqpr+ALcCXgTOAHwI+D5y1ps2/Ad46vn0JcE3fdXdwzD8PPHx8+xVDP+amxz1u\n90jgOuB6YKnvujv4W58JfBY4aXz/1L7r7ui4dwCvGN8+C9jbd92bcNxPB54E7F7n+QuB/wUEeDLw\nmc38/fPQwzgXuLWq9lTV94G/AC5e0+Zi4J3j2+8Fzk+SDmvcbEc95qr6ZFV9Z3z3euCxHdfYhiZ/\na4DfBn4P+G6XxbWkyTG/DPiTqvo2QFXd2XGNbWhy3AX88Pj2jwC3d1hfK6rqOuBI146+GHhXjVwP\n/GiS0zbr989DYJwO3Lbq/v7xYxPbVNVh4C7glE6qa0eTY17tJYw+lQzdUY87yc8Aj6uqD3VZWIua\n/K1/AviJJP87yfVJLuisuvY0Oe7XAy9Msh/4CPCqbkrr1Ub/39+Qebji3qSewtqlYU3aDEnj40ny\nQmAJ+JetVtSNIx53kocAfwRc1lVBHWjytz6O0bDUeYx6kn+b5Oyq+seWa2tTk+N+AXBVVf1hkqcA\n/2N83Pe2X15vWn0vm4cexn7gcavuP5YHdk1/0CbJcYy6r0fq9k27JsdMkmcA/xG4qKq+11FtbTra\ncT8SOBv4VJK9jMZ4dw584rvpv+8PVtWhqvoK8CVGATJkTY77JcB7AKrq74GHMTrf0ixr9P/+sZqH\nwNgFnJnkCUl+iNGk9s41bXYCl45vPxf4RI1nkAbqqMc8Hpr574zCYhbGtOEox11Vd1XV1qraVlXb\nGM3dXFRVy/2Uuyma/Pv+AKNFDiTZymiIak+nVW6+Jse9DzgfIMkTGQXGSqdVdm8n8KLxaqknA3dV\n1R2b9cNnfkiqqg4nuRz4KKOVFVdW1U1J3gAsV9VO4B2Muqu3MupZXNJfxQ9ew2P+feBE4C/H8/v7\nquqi3oreBA2Pe6Y0POaPAs9M8kXgHuA3qupb/VX94DU87n8HvC3JaxgNy1w28A+CJLma0dDi1vHc\nzBXA8QBV9VZGczUXArcC3wFevKm/f+D//SRJHZmHISlJ0iYwMCRJjRgYkqRGDAxJUiMGhiSpEQND\nktSIgSFJasTAkFqS5JzxNQkeluQR42tRnN13XdKxcuOe1KIk/5nRKSlOAPZX1X/puSTpmBkYUovG\n5znaxejaG0+tqnt6Lkk6Zg5JSe06mdE5ux7JqKchDZY9DKlFSXYyuhrcE4DTqurynkuSjtnMn61W\n6kuSFwGHq+rdSbYAf5fkF6rqE33XJh0LexiSpEacw5AkNWJgSJIaMTAkSY0YGJKkRgwMSVIjBoYk\nqREDQ5LUiIEhSWrk/wGHmmYdLod/igAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.scatter(x,y, s=8); plt.xlabel(\"x\"); plt.ylabel(\"y\"); " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You want to find **parameters** (weights) $a$ and $b$ such that you minimize the *error* between the points and the line $a\\cdot x + b$. Note that here $a$ and $b$ are unknown. For a regression problem the most common *error function* or *loss function* is the **mean squared error**. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def mse(y_hat, y): return ((y_hat - y) ** 2).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Suppose we believe $a = 10$ and $b = 5$ then we can compute `y_hat` which is our *prediction* and then compute our error." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "4.1001300495563058" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_hat = lin(10,5,x)\n", "mse(y_hat, y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def mse_loss(a, b, x, y): return mse(lin(a,b,x), y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "4.1001300495563058" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mse_loss(10, 5, x, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So far we have specified the *model* (linear regression) and the *evaluation criteria* (or *loss function*). Now we need to handle *optimization*; that is, how do we find the best values for $a$ and $b$? How do we find the best *fitting* linear regression." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Gradient Descent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For a fixed dataset $x$ and $y$ `mse_loss(a,b)` is a function of $a$ and $b$. We would like to find the values of $a$ and $b$ that minimize that function.\n", "\n", "**Gradient descent** is an algorithm that minimizes functions. Given a function defined by a set of parameters, gradient descent starts with an initial set of parameter values and iteratively moves toward a set of parameter values that minimize the function. This iterative minimization is achieved by taking steps in the negative direction of the function gradient.\n", "\n", "Here is gradient descent implemented in [PyTorch](http://pytorch.org/)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((10000,), (10000,))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# generate some more data\n", "x, y = gen_fake_data(10000, 3., 8.)\n", "x.shape, y.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x,y = V(x),V(y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Variable containing:\n", " 1.00000e-02 *\n", " 2.9873\n", " [torch.FloatTensor of size 1], Variable containing:\n", " 0.1116\n", " [torch.FloatTensor of size 1])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create random weights a and b, and wrap them in Variables.\n", "a = V(np.random.randn(1), requires_grad=True)\n", "b = V(np.random.randn(1), requires_grad=True)\n", "a,b" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "89.19391632080078\n", "0.6885505318641663\n", "0.11982045322656631\n", "0.11007291823625565\n", "0.10528462380170822\n", "0.10161882638931274\n", "0.09879907965660095\n", "0.09662991762161255\n", "0.09496115148067474\n", "0.09367774426937103\n" ] } ], "source": [ "learning_rate = 1e-3\n", "for t in range(10000):\n", " # Forward pass: compute predicted y using operations on Variables\n", " loss = mse_loss(a,b,x,y)\n", " if t % 1000 == 0: print(loss.data[0])\n", " \n", " # Computes the gradient of loss with respect to all Variables with requires_grad=True.\n", " # After this call a.grad and b.grad will be Variables holding the gradient\n", " # of the loss with respect to a and b respectively\n", " loss.backward()\n", " \n", " # Update a and b using gradient descent; a.data and b.data are Tensors,\n", " # a.grad and b.grad are Variables and a.grad.data and b.grad.data are Tensors\n", " a.data -= learning_rate * a.grad.data\n", " b.data -= learning_rate * b.grad.data\n", " \n", " # Zero the gradients\n", " a.grad.data.zero_()\n", " b.grad.data.zero_() " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Nearly all of deep learning is powered by one very important algorithm: **stochastic gradient descent (SGD)**. SGD can be seeing as an approximation of **gradient descent** (GD). In GD you have to run through *all* the samples in your training set to do a single itaration. In SGD you use *only one* or *a subset* of training samples to do the update for a parameter in a particular iteration. The subset use in every iteration is called a **batch** or **minibatch**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Gradient Descent - Classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For a fixed dataset $x$ and $y$ `mse_loss(a,b)` is a function of $a$ and $b$. We would like to find the values of $a$ and $b$ that minimize that function.\n", "\n", "**Gradient descent** is an algorithm that minimizes functions. Given a function defined by a set of parameters, gradient descent starts with an initial set of parameter values and iteratively moves toward a set of parameter values that minimize the function. This iterative minimization is achieved by taking steps in the negative direction of the function gradient.\n", "\n", "Here is gradient descent implemented in [PyTorch](http://pytorch.org/)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def gen_fake_data2(n, a, b):\n", " x = s = np.random.uniform(0,1,n) \n", " y = lin(a,b,x) + 0.1 * np.random.normal(0,3,n)\n", " return x, np.where(y>10, 1, 0).astype(np.float32)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x,y = gen_fake_data2(10000, 3., 8.)\n", "x,y = V(x),V(y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def nll(y_hat, y):\n", " y_hat = torch.clamp(y_hat, 1e-5, 1-1e-5)\n", " return (y*y_hat.log() + (1-y)*(1-y_hat).log()).mean()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "a = V(np.random.randn(1), requires_grad=True)\n", "b = V(np.random.randn(1), requires_grad=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "ename": "RuntimeError", "evalue": "bool value of Variable objects containing non-empty torch.ByteTensor is ambiguous", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/numpy/core/fromnumeric.py\u001b[0m in \u001b[0;36m_wrapfunc\u001b[0;34m(obj, method, *args, **kwds)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/autograd/variable.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 65\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mobject\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getattribute__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mAttributeError\u001b[0m: 'Variable' object has no attribute 'clip'", "\nDuring handling of the above exception, another exception occurred:\n", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mlin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0my_hat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnll\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_hat\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m1000\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mto_np\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mto_np\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_hat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;36m0.5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36mnll\u001b[0;34m(y_hat, y)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mnll\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_hat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0my_hat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_hat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1e-5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1e-5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0my_hat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0my_hat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/numpy/core/fromnumeric.py\u001b[0m in \u001b[0;36mclip\u001b[0;34m(a, a_min, a_max, out)\u001b[0m\n\u001b[1;32m 1705\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1706\u001b[0m \"\"\"\n\u001b[0;32m-> 1707\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_wrapfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'clip'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_min\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_max\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1708\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1709\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/numpy/core/fromnumeric.py\u001b[0m in \u001b[0;36m_wrapfunc\u001b[0;34m(obj, method, *args, **kwds)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;31m# a downstream library like 'pandas'.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mAttributeError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 67\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_wrapit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 68\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/numpy/core/fromnumeric.py\u001b[0m in \u001b[0;36m_wrapit\u001b[0;34m(obj, method, *args, **kwds)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0mwrap\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mwrap\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmu\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/autograd/variable.py\u001b[0m in \u001b[0;36m__bool__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 122\u001b[0m raise RuntimeError(\"bool value of Variable objects containing non-empty \" +\n\u001b[0;32m--> 123\u001b[0;31m torch.typename(self.data) + \" is ambiguous\")\n\u001b[0m\u001b[1;32m 124\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0m__nonzero__\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m__bool__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mRuntimeError\u001b[0m: bool value of Variable objects containing non-empty torch.ByteTensor is ambiguous" ] } ], "source": [ "learning_rate = 1e-2\n", "for t in range(3000):\n", " p = (-lin(a,b,x)).exp()\n", " y_hat = 1/(1+p)\n", " loss = nll(y_hat,y)\n", " if t % 1000 == 0:\n", " print(loss.data[0], np.mean(to_np(y)==(to_np(y_hat)>0.5)))\n", "# print(y_hat)\n", " \n", " loss.backward()\n", " a.data -= learning_rate * a.grad.data\n", " b.data -= learning_rate * b.grad.data\n", " a.grad.data.zero_()\n", " b.grad.data.zero_() " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Nearly all of deep learning is powered by one very important algorithm: **stochastic gradient descent (SGD)**. SGD can be seeing as an approximation of **gradient descent** (GD). In GD you have to run through *all* the samples in your training set to do a single itaration. In SGD you use *only one* or *a subset* of training samples to do the update for a parameter in a particular iteration. The subset use in every iteration is called a **batch** or **minibatch**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Gradient descent with numpy" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from matplotlib import rcParams, animation, rc\n", "from ipywidgets import interact, interactive, fixed\n", "from ipywidgets.widgets import *\n", "rc('animation', html='html5')\n", "rcParams['figure.figsize'] = 3, 3" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x, y = gen_fake_data(50, 3., 8.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "65.167827371047636" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a_guess,b_guess = -1., 1.\n", "mse_loss(a_guess, b_guess, x, y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr=0.01\n", "def upd():\n", " global a_guess, b_guess\n", " y_pred = lin(a_guess, b_guess, x)\n", " dydb = 2 * (y_pred - y)\n", " dyda = x*dydb\n", " a_guess -= lr*dyda.mean()\n", " b_guess -= lr*dydb.mean()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fig = plt.figure(dpi=100, figsize=(5, 4))\n", "plt.scatter(x,y)\n", "line, = plt.plot(x,lin(a_guess,b_guess,x))\n", "plt.close()\n", "\n", "def animate(i):\n", " line.set_ydata(lin(a_guess,b_guess,x))\n", " for i in range(30): upd()\n", " return line,\n", "\n", "ani = animation.FuncAnimation(fig, animate, np.arange(0, 20), interval=100)\n", "ani" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 1 }