{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import numpy as np\n", "import scipy as sp\n", "import matplotlib.pyplot as plt\n", "import random\n", "from scipy import stats\n", "from scipy.optimize import fmin" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gradient Descent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Gradient descent, also known as steepest descent, is an optimization algorithm for finding a local minimum of a function. Starting from an initial guess, the algorithm repeatedly \"steps\" in the direction of the negative gradient, which is the direction in which the function decreases most steeply. Gradient ascent is the same as gradient descent, except that it steps in the direction of the positive gradient and therefore finds local maxima instead of minima. The gradient descent algorithm can be outlined as follows:\n", "\n", "    1:   Choose initial guess $x_0$
\n", "    2:   for k = 0, 1, 2, ... do
\n", "    3:       $s_k$ = -$\\nabla f(x_k)$
\n", "    4:       choose $\\alpha_k$ to minimize $f(x_k+\\alpha_k s_k)$
\n", "    5:       $x_{k+1} = x_k + \\alpha_k s_k$
\n", "    6:   end for" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a simple example, let's find a local minimum for the function $f(x) = x^3-2x^2+2$" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "f = lambda x: x**3-2*x**2+2" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deVyVZf7/8dfnsIiACAi4oeKCC2puuGeZLbbbvu81VlNTzTRrM9PM9O1Xs3yrmVYza2qqqcZssdIcK/cd3BFBFhUQkEVZBYFz/f7gOF/GWA544D7L5/l48Ohwzs3hfQ6nt/e5znVftxhjUEop5f1sVgdQSinVNbTwlVLKR2jhK6WUj9DCV0opH6GFr5RSPkILXymlfESbhS8iQSKyVUR2iUiKiPyhmW26ichHIpIhIltEJK4zwiqllOo4Z/bwa4E5xphxwHjgYhGZdto29wLHjDHDgBeAP7k2plJKqTPVZuGbRpWObwMcX6cfrTUPeMdx+WPgfBERl6VUSil1xvyd2UhE/IBkYBjwijFmy2mb9AdyAIwx9SJSBvQCik+7n/nAfICQkJBJI0eOPLP0SillEWMgJb+MyJBA+vXs3mW/Nzk5udgYE92Rn3Wq8I0xDcB4EQkHPhWRMcaYve39ZcaYhcBCgMTERJOUlNTeu1BKKbew4/Axrn51I6/dOpFLxvbtst8rIoc6+rPtmqVjjDkOrAIuPu2mPGCAI4w/0BMo6WgopZRyd0kHjwEwKS7C4iTOc2aWTrRjzx4R6Q5cCOw/bbOlwJ2Oy9cB3xldlU0p5cW2HSwlrlcwMT2CrI7iNGeGdPoC7zjG8W3Av4wxX4rIU0CSMWYp8CbwrohkAKXATZ2WWCmlLGaMIenQMeaMjLE6Sru0WfjGmN3AhGauf7LJ5RrgetdGU0op95RVXEVp1Ukme9BwDuiRtkop1W5JB0sBSIyLtDhJ+2jhK6VUO207eIzIkECGRIVYHaVdtPCVUqqdkg6WkjgoAk87vlQLXyml2uFoRQ0HS6qZ7GHDOaCFr5RS7ZLsmH+f6GEf2IIWvlJKtcu2g8cICrAxul9Pq6O0mxa+Ukq1Q9KhUsYPCCfQ3/Pq0/MSK6WURapq60k5Uu6R4/egha+UUk7bcfg4DXbDpEGeN34PWvhKKeW0Ldkl+NnE4w64OkULXymlnLQ5q4Sx/XsS2s2pleXdjha+Uko54cTJBnbmHGfakF5WR+kwLXyllHLC9sPHqGswTBvimcM5oIWvlFJO2Zzl2eP3oIWvlFJO8fTxe9DCV0qpNnnD+D1o4SulVJu8YfwetPCVUqpN3jB+D1r4SinVJm8YvwctfKWUapW3jN+DFr5SSrXKW8bvQQtfKaVa5S3j96CFr5RSrfKW8XvQwldKqRZV1dZ7zfg9aOErpVSLth4spa7BcPawKKujuIQWvlJKtWDDgWIC/W0eecLy5mjhK6VUC9ZnFDM5LoKgAD+ro7hEm4UvIgNEZJWI7BORFBF5tJltZotImYjsdHw92TlxlVKqaxRV1LK/oIKZXjKcA+DMx871wOPGmO0i0gNIFpGVxph9p223zhhzuesjKqVU1
9uYWQzArGHRFidxnTb38I0x+caY7Y7LFUAq0L+zgymllJXWHygmPDiAhH5hVkdxmXaN4YtIHDAB2NLMzdNFZJeILBeR0S7IppRSljDGsCGjmBlDe+FnE6vjuIzThS8iocAS4DFjTPlpN28HBhljxgEvAZ+1cB/zRSRJRJKKioo6mlkppTpVVnEVR8pqvGr8HpwsfBEJoLHs3zfGfHL67caYcmNMpePyMiBARL73TBljFhpjEo0xidHR3jMuppTyLhsyGsfvvWX+/SnOzNIR4E0g1RjzfAvb9HFsh4hMcdxviSuDKqVUV1l/oJjYiO4MjAy2OopLOTNLZyZwO7BHRHY6rnsCGAhgjFkAXAc8KCL1wAngJmOM6YS8SinVqeob7GzKKuGysX1x7Md6jTYL3xizHmj1URtjXgZedlUopZSyyu68Mipq6r1u/B70SFullPova9OLEEELXymlvN3qtCLGxYYTGRJodRSX08JXSimH0qqT7Mo9zrnDvXMWoRa+Uko5rDtQhDEwe4QWvlJKebU1aUVEBAdwVmy41VE6hRa+UkoBdrth7YEiZsVHe9VyCk1p4SulFJBypJziypNeO5wDWvhKKQXAmvSjAMyK18JXSimvtjqtiLH9exLdo5vVUTqNFr5SyueVVdex/fAxr52OeYoWvlLK563PKMbuxdMxT9HCV0r5vNVpRwkL8mf8AO+cjnmKFr5SyqfZ7YZVaUc5d0QM/n7eXYne/eiUUqoNO3OPU1x5kgtGxVgdpdNp4SulfNo3+wrxswmzh2vhK6WUV/s29SiT4yLoGRxgdZROp4WvlPJZOaXVpBVWcMGo3lZH6RJa+Eopn/VNaiGAFr5SSnm7b1ILGRYTSlxUiNVRuoQWvlLKJ5XX1LElq5TzfWB2zila+Eopn7QmrYh6u/GZ4RzQwldK+ahvUwuJCA5g4sAIq6N0GS18pZTPqWuwsyqtiPNGxnjtyU6ao4WvlPI5W7JKKTtRx0UJfayO0qW08JVSPmf53ny6B/h5/XLIp9PCV0r5lAa7YUVKIeeNjKZ7oJ/VcbqUFr5SyqckHzpGcWUtF4/pa3WULqeFr5TyKV/vLSDQz8ackb4z//6UNgtfRAaIyCoR2SciKSLyaDPbiIi8KCIZIrJbRCZ2TlyllOo4YwwrUgqYFR9FaDd/q+N0OWf28OuBx40xCcA04CERSThtm0uAeMfXfOA1l6ZUSikX2J1bRt7xE1w8xrdm55zS5j9xxph8IN9xuUJEUoH+wL4mm80D/mGMMcBmEQkXkb6On1UWqKytZ29eGfvzy0krrKCgrIaiylqOVdVhjMEA3fxtRIQEEhkcyIDIYIbGhDI8JpSzYsN97sMs5RuW7y3A3yZcmOA7R9c21a73NCISB0wAtpx2U38gp8n3uY7r/qvwRWQ+je8AGDhwYPuSqjZlF1exbE8+a9KL2H7oGPV2A0BEcACxEcFEh3ZjeEwPbDZBgJp6O8erT3KkrIZNWSVUn2wAwN8mjI3tydTBvbgwIYYJAyKw+dDBKco7GWP4em8+04f2Ijw40Oo4lnC68EUkFFgCPGaMKe/ILzPGLAQWAiQmJpqO3If6bzV1DXy2I4/FybkkHzoGwJj+YfzgnCFMGRzJ6L5hRPfohkjrhW2MIb+shv0F5Ww7eIxt2aW8uT6LBWsyie7RjYtH9+G6SbGcFduzzftSyh2lFVZwsKSaH5wzxOoolnGq8EUkgMayf98Y80kzm+QBA5p8H+u4TnWSipo63t9ymEXrsimurCU+JpRfXTKSqyb0p3dYULvvT0ToF96dfuHdmTOy8e1ueU0dq/YfZUVKAYuTc3h38yES+oZx89SBXDuxP8GBvvehl/JcX+7Kxyb43NG1TbX5f6w07s69CaQaY55vYbOlwMMi8iEwFSjT8fvO0WA3/Csph/9dkUZJ1UlmxUfx4OzxTB/Sy+V73mFBAcwb35954/tTXlPH5zvy+OfWHH772V5eWJnO3TPiuGN6nE+cG
k55NmMMS3cdYcbQKKJ7dLM6jmWc2UWbCdwO7BGRnY7rngAGAhhjFgDLgEuBDKAauNv1UdWunOM88ekeUo6UMyUukjfvGsX4AeFd8rvDggK4fXoct00bRNKhY7y2OpPnVqazYE0m95w9mPnnDKFHkBa/ck+7css4XFrNw+cNszqKpZyZpbMeaHXX0TE75yFXhVL/ra7BzkvfZfDKqgyiQ7vx0s0TuPysvpaMpYsIk+MimXxXJKn55by8KoOXvsvgn1sO8+gF8dw8ZSABfno8n3IvS3ceIdDPxlwfnY55iv6f6eZySqu55tWNvPjtAeaN78eKH5/DFeP6ucUHp6P6hvHKLRP5/KGZDIsJ5cnPU5j717VszCy2OppS/9FgN3y5+wizR0TTs7tvvwvVwndj6w8Uc8XL6zlUUsVrt07k+RvGu+ULdtyAcD6cP40370ykvsFwyxtb+MlHOymurLU6mlJsyS7haEUtV47vZ3UUy+k0Cze1aF0WzyxLZVhMKAtvT3T7kyyLCOeP6s3MYVG8siqDBWsy+Xb/UX57eQLXTuzvFu9IlG/6YtcRQgL9OH+kbx5s1ZTu4bsZu93wzLJUnv4qlYsS+vDpD2e6fdk3FRTgx+MXjWD5o7MY0bsHP128i/vfTda9fWWJk/V2lu0p4MKE3nr0OFr4bqWuwc7PPt7NwrVZ3Dl9EK/eOpEQD13gaVhMDz6cP43fXDaK1elFzH1hLf9OKbA6lvIx6w4UUXaiTodzHLTw3UR9g51HP9zBku25/OTC4fz+ytEev5yBzSbcN2sIXzx8Nr3Dgpj/bjK/+WwPNXUNVkdTPuLznUcIDw7g7GG+dWarlmjhu4EGu+Gni3exbE8Bv7lsFI+cH+9VY94j+vTgs4dmMv+cIby3+TDXLdjIoZIqq2MpL1deU8e/9xVw2di+BPpr1YEWvuXsdsMTn+zhs51H+NncEdw3yzvX+Qj0t/HEpaNYdEciOaUnuPzF9Xy9Vw/GVp1n2e58aursXJ84oO2NfYQWvsWeW5nGR0k5/GjOMB7ygaMAL0jozZc/OpshMaE88N52nl2eSoNd19FTrvdxci7DYkIZF9vT6ihuQwvfQv/alsMrqzK5ecoAfnLhcKvjdJkBkcEsvn86t04dyOtrsrjvnW2U19RZHUt5keziKpIOHePaibFeNTx6prTwLbL+QDFPfLqHWfFRPDVvjM+9KAP9bfy/q8fy9FVjWHegmKte2UBWUaXVsZSX+GR7LjaBqyf0tzqKW9HCt8ChkioefD+ZYTGhvHrrRJ9ee+a2aYN4/76pHK+uY94rG1iTXmR1JOXh7HbDkuRcZsVH06dn+5cK92a+2zQWOXGygfvfTcbPJrxxR6KuMAlMHdKLpQ/PJDYimHve3saHWw9bHUl5sE1ZJRwpq+G6SbFWR3E7WvhdyBjDrz/dQ1phBX+9cTwDIoOtjuQ2YiOCWfzAdGYOi+KXn+zhuX+n0bgIq1Lt83FyLj2C/H32vLWt0cLvQh9szeGTHXk8dv5wZo+IsTqO2wnt5s+bdyZyY+IAXvoug8f/tYuT9XarYykPUl5Tx9d7C7hiXD+CAnQphdN55nH7HijjaCVPfZnCrPgofjTH+6dfdlSAn40/XjuW2IjuPLcynYLyGl67bZJbrhKq3M/nO/I4UdfAjTr3vlm6h98FTtbbeeyjHQQH+vPc9eM8fsmEziYi/Oj8eJ6/YRxbs0u58fVNHK2osTqWcnPGGN7fcpjR/cI4S+feN0sLvws8vzKdvXnl/PGascR04ATjvuqaibH8/e7JHC6t5voFm8gprbY6knJjO3KOs7+gglumDvS5ac7O0sLvZEkHS3l9bePBVReN9u3Tq3XErPho3nNM27xuwUbSCyusjqTc1D+3HCYk0I9543XufUu08DtRTV0Dv1iym349u/ObyxKsjuOxJg6M4KP7p2E3cMPrm9iZc9zqSMrNlJ2o48vdR5g3oT+hHrqkeFfQwu9Er6zKILOoimevG
eux69q7i5F9wljywAx6BPlz6xub2Zih581V/+fT7bnU1Nm5ZcpAq6O4NS38TpKaX85rqzO5ZmJ/zhmua3G7wsBewXz8wAxiI4K56+/bWLmv0OpIyg0YY/jn1sOMi+3JmP76YW1rtPA7QYPd8Mslu+nZPYDf6lCOS/UOC+Kj+6cxql8YD76XzPI9usSyr9ucVUp6YSW3Th1kdRS3p4XfCf655RC7cst48ooEIkICrY7jdcKDA3n33imMGxDOwx/s4ItdR6yOpCz09w3ZRAQH6GkMnaCF72LHqk7yv/9OZ8bQXlw5Tl+AnSUsKIB37pnCpEERPPrhDj7Znmt1JGWBnNJqVqYWcsvUgXpkrRO08F3s+ZXpVNbW87srRutc4E4W2s2ft++ezLQhvXh88S7+tS3H6kiqi72z8SB+Itw+Lc7qKB5BC9+F9h0p5/0th7h92iBG9OlhdRyfEBzoz1t3TebsYVH8fMlu3t9yyOpIqotU1dbzUVIOl4ztq8sgO6nNwheRt0TkqIjsbeH22SJSJiI7HV9Puj6m+zPG8PsvUujZPYAfX+A7Z69yB0EBfrxxRyJzRsbw60/38vaGbKsjqS6wZHsuFTX13D0zzuooHsOZPfy3gYvb2GadMWa84+upM4/leVakFLA1u5Sfzh1Bz2Bd6KurBQX4seC2SVyU0Jvff7GPReuyrI6kOpHdbnh740HGDQhn4sAIq+N4jDYL3xizFijtgiweq77Bzp9XpBEfE8pNk/XAD6sE+tt45daJXDq2D09/lcrCtZlWR1Kd5Nv9R8kqquIe3btvF1eN4U8XkV0islxERre0kYjMF5EkEUkqKvKeU9ktTs4lq6iKn80dgZ+uhGmpAD8bf7tpAped1Zdnlu3ntdVa+t7GGMOrqzOIjejOZWP7Wh3Ho7jieP/twCBjTKWIXAp8BsQ3t6ExZiGwECAxMdErTmd04mQDf/0mnUmDIvQMO24iwM/G324cj58If/p6P3ZjeOg8PQeBt9iaXcqOw8f5n3mj8ffh80F3xBkXvjGmvMnlZSLyqohEGWN8YrGTtzcepLC8lpdunqjTMN2Iv5+N528Yh03gLyvSsNsNPzq/2f0Q5WFeW5NJr5BArteTnLTbGRe+iPQBCo0xRkSm0DhMVHLGyTxAWXUdr63OYM7IGKYMjrQ6jjqNv5+N524Yj02E51am02AMj+kMKo+270g5q9OK+OlFw/VAqw5os/BF5ANgNhAlIrnA74AAAGPMAuA64EERqQdOADcZHzn79Jsbsimvqednc0dYHUW1wM8m/MVxlrG/fnMAu4EfXxCv78Y81II1mYQE+umBVh3UZuEbY25u4/aXgZddlshDlJ2o4+8bsrlkTB9G9Q2zOo5qhZ9N+PO1Z2ETePHbAxhj+MmFw7X0PczB4iq+3H2Ee88erFOfO0gXae+gdzYepKKmnof1hOQewWYT/njNWdhEeOm7DBrshp/NHaGl70Fe/O4Agf42fnDOEKujeCwt/A6oqKnjzfXZXJjQm9H9dP1tT2GzCc9cPRabTXh1dSZ2A7+4WEvfE2QWVfLZjjzumzWEmB66jEJHaeF3wD82HaLsRB2PzNFZH57GZhOenjcGmzSOB9uN4VeXjNTSd3MvfnuAoAA/7te9+zOihd9OVbX1LFqXxXkjohkbq3v3nshmE/5n3hhsIixcm0WD3fCby0Zp6bupA4UVLN11hAfOHUqv0G5Wx/FoWvjt9MHWwxyrrtM53R5ORPjDlaOxifDm+mzsxvDk5Qla+m7or98eIDjAj/mzdO/+TGnht0Ndg52/bzjIlMGRumCTFxARfndFAjYR3tqQjTHwuyu09N3J3rwyvtqdz8PnDdOzx7mAFn47LNuTT97xE/zhyhaXC1IeRkT47eWj8LPBG+uyabCbxj1/XRPJcsYYnlmWSkRwAPPP1b17V9DCd5IxhoVrsxgaHcKckTFWx1EuJCI8cekobDbh9TVZ2I1pHOPX0rfU6rQiNmaW8PsrE
ggL0nn3rqCF76RNmSWkHCnnj9eM1SLwQiLCLy8eiU2E11Y3zt75f1fp39oq9Q12nlmWyuCoEG6ZOsjqOF5DC99Jr6/NIiq0G1dN6G91FNVJRISfzx2Bnwgvr8rAbodn9R94SyxOzuXA0UoW3DaRQH9dEdNVtPCdkFZQwZp0XbDJF4gIj180HJtNePHbAzQYw5+uPUvPc9CFKmvreX5lOomDIpg7uo/VcbyKFr4T3tl0kG7+Nm7Vt5Y+QUT4yYXDsQmOBdcMf7lunJZ+F/nbN+kUV9byxh2JOmPKxbTw21B2oo5Pt+cxb3w/nRbmYx67YDg2EZ5fmY4x8L/Xa+l3trSCCt7acJCbJg9g/IBwq+N4HS38NixJzuVEXQN3TI+zOoqywCPnxzcusbwiDbsxPHf9OD3LUicxxvDbz/YSFuTPz+eOtDqOV9LCb4Xdbnh38yEmDgxnTH9dRsFXPXTeMETgz1+nYTfwwg1a+p3hk+15bD1Yyh+vGavvpjuJFn4r1mcUk11cxaM3jrc6irLYD2cPw0+EZ5fvx243/PWm8QRo6btMWXUdzy5PZcLAcG7QUxd2Gi38Vvxj00GiQgO5ZKzOFFBw/7lD8bMJT3+Vit0YXrx5gpa+i/zhyxSOVdfxzj16wFtn0ldrC3JKq/l2/1FumjyQbv46FVM1um/WEH57eQLL9xbw8D+3c7LebnUkj/dtaiGfbM/jodlD9fwSnUwLvwXvbzmMTYRbpg60OopyM/eePZjfXZHAipRC5r+bRPXJeqsjeayy6jqe+HQPI/v04GE9v0Sn08JvRl2DnY+Tc5gzMoZ+4d2tjqPc0N0zB/PsNWNZm17ELW9sobTqpNWRPNJTX+6juPIkf7lunB5R2wX0GW7Gt6lHKa48yU2T9cMj1bKbpwzktdsmkZpfznULNpJTWm11JI+ydNcRlmzP5aHZQ/VkQl1EC78Z/0rKIaZHN84dHm11FOXm5o7uw3v3TaW4opZrX9tIan651ZE8wuGSan79yR4mDgznET2ZUJfRwj9NQVkNq9OOcn1irM61Vk6ZHBfJxw/OwM8m3LBgE5syS6yO5NbqGuw88uEOEPjbTRP0/7MupM/0aZZsz8Vu0LnAql2G9+7Bkgdn0KdnEHe+tZUlyblWR3Jbf/56PztzjvPHa85iQGSw1XF8ihZ+E3a74aNtOUwbEsmgXiFWx1Eepl94dz5+YAaTB0fw+OJd/PnrxoO01P/5fGceb6zL5o7pg7jsrL5Wx/E5WvhNbM4u4XBpNTfqh7Wqg3oGB/D23VO4ecpAXl2dyQ/f367TNh325pXxiyW7mTI4kt9enmB1HJ+khd/Ev7bl0CPIn0vG6J6H6rgAPxvPXD2G31w2ihX7Crjx9c0UlNVYHctSRRW13P9uMhHBgbx660Q9QtkibT7rIvKWiBwVkb0t3C4i8qKIZIjIbhGZ6PqYna+8po7lewuYN76fnuREnTER4b5ZQ1h0RyJZRZVc/tJ6tmT55oe5VbX13PP2NkqrTvL67ZOICu1mdSSf5cw/s28DF7dy+yVAvONrPvDamcfqel/vKaC23s51k3Q4R7nO+aN68+lDMwkL8ueWRVtYtC4LY3xnXL+uwc6D729nX345r9w6gbNidY17K7VZ+MaYtUBpK5vMA/5hGm0GwkXE48ZEPt2Rx+CoEMbpASDKxYb37sHnD8/kglExPP1VKg//cweVtd4/rm+3G37x8W7WphfxzNVjmDOyt9WRfJ4rBtL6AzlNvs91XPc9IjJfRJJEJKmoqMgFv9o1jhw/websEq4a319PqaY6RY+gABbcNolfXTKS5XvzmffyelKOlFkdq9PY7YZfLNnNJzvyePzC4dw4Wdekcgdd+smJMWahMSbRGJMYHe0+R7Eu3XUEY2De+H5WR1FeTES4/9yhvHffVCpq6rn6lY0sWpfldVM3T5X94uRcHj0/nh/pkbRuwxWFnwc0HfiOdVznMT7bkceEgeHERence9X5ZgyN4uvHz
uHcEdE8/VUqd7y1lcJy75jFU1vfwGMf7fxP2f/4wuFWR1JNuKLwlwJ3OGbrTAPKjDH5LrjfLpGaX87+ggquntDsKJRSnSIyJJCFt0/imavHknSolIteWMvHybke/YFu2Yk67nxrK0t3HeHnF4/QsndDbZ7xSkQ+AGYDUSKSC/wOCAAwxiwAlgGXAhlANXB3Z4XtDJ/tzMPfJlw21uM+Z1YeThznW5g6JJKff7ybny7exec783jm6rEet+TAoZIqfvCPJLKLq/jrjeO5Sneg3JJYtUeRmJhokpKSLPndp9jthhl//I7R/cJ4867JlmZRvs1uN7y35RB/Wr4fu4HHLojnrplxHnG2tRUpBfx08S4EWHDbJGYMi7I6klcTkWRjTGJHftanD3fbnF1CQXmN7o0oy9lswh3T41j5k3OZOawXzy7fz0UvrGVFSoHbDvPU1DXw9Jf7uP/dZAZHhfDVI7O07N2cTxf+l7vzCQ7044JROj9YuYd+4d1ZdOdk3rlnCoF+Nu5/N5mb39jM9sPHrI72X5IOlnLpi+tYtD6b26YNZPED0z1uGMoXtTmG763qG+x8vbeA80f1pnug+79tVr7l3OHRzHx0Fh9sy+GFlelc8+pGZsVH8aM58UwZHGlZrqMVNbyw8gAfbjtMv57deffeKcyKd58p1qp1Plv4W7JLKa06yWVj+1gdRalm+fvZuH3aIK6Z0J/3txxi4dosbnh9E4mDIrhjRhwXj+7TZeeBLauu460N2byxLouT9XbumhHHTy8aQUg3n60Qj+Szf61TwzmzR8RYHUWpVoV082f+OUO5fVocH2w9zNsbD/LIBzuICu3GdZNiuWJcXxL6hnXKUeJZRZW8s/Egi5NzqT7ZwGVj+/KzuSP0mBUP5ZOFX99gZ0VK43COroypPEX3QD/uOXswd82IY82BIt7bdIg31mWxYE0mQ6JDmDu6DzOHRpEYF9Hh17UxhkMl1XyTWsgXu46wK7eMAD/hynH9uffswST0C3Pxo1JdyScLf3PWqeEcnXuvPI/NJpw3IobzRsRQWnWS5Xvz+XJXPm+szeK11ZkE+ts4q39PRvUNY2TfHgyICKZPzyCiQrsR6G/D3yY02A1lJ+ooO1HHoZJqsoor2Z9fwdbsUgocR/2O6R/Gry4ZydUT+hMTFmTxo1au4JOF/9WeI4QE+jF7hH7YpDxbZEggt04dxK1TB1FZW8/W7BI2ZJSwO/c4n+3Io2Kz86ty9gkLIjEugmlDenH2sCgdtvFCPlf4TWfn6HCO8iah3fyZM7L3f5YhNsaQd/wER47XUFheQ3FlLXUNduoaDH42oWf3AHp2DyA2ojuDo0LoERRg8SNQnc3nCn9TVgnHquv0BMrK64kIsRHBxEbo/HjVyOcOvFq2J5+QQD/OHa7DOUop3+JThd9gN6zcV8h5I2N0OEcp5XN8qvB3HD5GceVJ5o7Wg62UUr7Hpwr/3/sKCfATnZ2jlPJJPlP4xhhWpBQwY2iUzkZQSvkknyn8A0crOYIEQFgAAAlvSURBVFRSzUWjdWVMpZRv8pnC/3dKAQAX6lLISikf5TOFv3JfIRMGhush4kopn+UThZ9fdoJduWVclKCzc5RSvssnCv+bfYUAXJigwzlKKd/lE4X/732FDIkOYVhMqNVRlFLKMl5f+OU1dWzKLNG9e6WUz/P6wl+XXky93ejsHKWUz/P6wv9u/1F6dg9g/IBwq6MopZSlvLrw7XbDmvSjnDs8Gn8/r36oSinVJq9uwT15ZRRXnmTOSD1RuVJKeXXhf7f/KCLo2vdKKYWThS8iF4tImohkiMgvm7n9LhEpEpGdjq/7XB+1/ValHWXCgHAiQgKtjqKUUpZrs/BFxA94BbgESABuFpGEZjb9yBgz3vG1yMU5262oopbduWU6nKOUUg7O7OFPATKMMVnGmJPAh8C8zo115lanHQXgPC18pZQCnCv8/kBOk+9zHded7loR2S0iH4vIAJekOwOr0o7SO6wbC
X3DrI6ilFJuwVUf2n4BxBljzgJWAu80t5GIzBeRJBFJKioqctGv/r66Bjvr0os5b0QMItJpv0cppTyJM4WfBzTdY491XPcfxpgSY0yt49tFwKTm7sgYs9AYk2iMSYyO7ryZM0kHj1FRW6/DOUop1YQzhb8NiBeRwSISCNwELG26gYj0bfLtlUCq6yK235r0IgL8hJnDoqyMoZRSbsW/rQ2MMfUi8jCwAvAD3jLGpIjIU0CSMWYp8IiIXAnUA6XAXZ2YuU3rDhQxcWAEod3afHhKKeUznGpEY8wyYNlp1z3Z5PKvgF+5NlrHFFfWknKknJ/NHWF1FKWUcited6TthoxiAGbF63COUko15XWFvza9mIjgAEb362l1FKWUciteVfjGGNYdKGLmsCj8bDodUymlmvKqwk8vrORoRS3nxOtiaUopdTqvKvx1BxoP5jpbx++VUup7vKrw1x4oZlhMKP3Cu1sdRSml3I7XFH5NXQNbskp0do5SSrXAawo/6eAxauvtOn6vlFIt8JrCX5fRuJzC1CGRVkdRSim35D2Fn15M4qBIggN1OQWllGqOVxR+cWUt+/LLdXaOUkq1wisKf1NmCYCujqmUUq3wjsLPKqFHN3/G9NOzWymlVEu8o/AzS5gyOBJ/P694OEop1Sk8viELymrILq5i+tBeVkdRSim35vGFvymrcTlkLXyllGqdxxf+xowSwoMDGNVHx++VUqo1Hl/4m7JKmDo4Epsuh6yUUq3y6MLPKa0m99gJZgzV6ZhKKdUWjy78U/PvdfxeKaXa5tmFn1VCVGgg8TGhVkdRSim357GFb4xhU2YJ04b0QkTH75VSqi0eW/jZxVUUlNfocI5SSjnJYwt/U1bj+L1+YKuUUs7x3MLPLKFPWBBxvYKtjqKUUh7BIwvfGMOW7FKmDYnU8XullHKSRxb+wZJqiipqmTJYx++VUspZHln427JLAZgyOMLiJEop5TmcKnwRuVhE0kQkQ0R+2czt3UTkI8ftW0QkztVBm9p6sJTIkECGRuv8e6WUclabhS8ifsArwCVAAnCziCScttm9wDFjzDDgBeBPrg7a1NbsUhIHRej4vVJKtYMze/hTgAxjTJYx5iTwITDvtG3mAe84Ln8MnC+d1MaF5TUcLq1myuDIzrh7pZTyWv5ObNMfyGnyfS4wtaVtjDH1IlIG9AKKm24kIvOB+Y5va0Vkb0dCA/zgT/CDjv6wa0Rx2uPzMJrfWp6c35Ozg+fnH9HRH3Sm8F3GGLMQWAggIknGmMSu/P2upPmtpfmt48nZwTvyd/RnnRnSyQMGNPk+1nFds9uIiD/QEyjpaCillFKu50zhbwPiRWSwiAQCNwFLT9tmKXCn4/J1wHfGGOO6mEoppc5Um0M6jjH5h4EVgB/wljEmRUSeApKMMUuBN4F3RSQDKKXxH4W2LDyD3O5A81tL81vHk7ODD+cX3RFXSinf4JFH2iqllGo/LXyllPIRXVb4InK9iKSIiF1EWpwS1dYyDlYRkUgRWSkiBxz/bXYhHxFpEJGdjq/TP9zucu62LEZ7OJH9LhEpavJ832dFzpaIyFsicrSl402k0YuOx7dbRCZ2dcaWOJF9toiUNXnun+zqjK0RkQEiskpE9jl659FmtnHn59+Z/O3/GxhjuuQLGEXjAQOrgcQWtvEDMoEhQCCwC0joqoxt5P8z8EvH5V8Cf2phu0qrs7bn+QR+CCxwXL4J+Mjq3O3IfhfwstVZW3kM5wATgb0t3H4psBwQYBqwxerM7cg+G/jS6pyt5O8LTHRc7gGkN/P6cefn35n87f4bdNkevjEm1RiT1sZmzizjYJWmy0e8A1xlYRZnudWyGO3kzq8Fpxhj1tI4a60l84B/mEabgXAR6ds16VrnRHa3ZozJN8Zsd1yuAFJpXBGgKXd+/p3J327uNobf3DIOZ/wgXaS3MSbfcbkA6N3CdkEikiQim0XE6n8UnHk+/2tZDODUshhWc/a1cK3j7fjHIjKgm
dvdmTu/3p0xXUR2ichyERltdZiWOIYpJwBbTrvJI57/VvJDO/8GLl1aQUS+Afo0c9OvjTGfu/J3dYbW8jf9xhhjRKSl+ayDjDF5IjIE+E5E9hhjMl2dVQHwBfCBMaZWRO6n8Z3KHIsz+YrtNL7WK0XkUuAzIN7iTN8jIqHAEuAxY0y51Xnaq4387f4buLTwjTEXnOFdOLOMQ6dpLb+IFIpIX2NMvuNt39EW7iPP8d8sEVlN47/MVhV+e5bFyHWzZTHazG6MaZpzEY2fs3gSS1/vZ6Jp+RhjlonIqyISZYxxm0XJRCSAxrJ83xjzSTObuPXz31b+jvwN3G1Ix5llHKzSdPmIO4HvvWMRkQgR6ea4HAXMBPZ1WcLv8+RlMdrMftp465U0jnN6kqXAHY7ZItOAsibDhm5NRPqc+qxHRKbQ2CXusKMANM7AoXEFgFRjzPMtbOa2z78z+Tv0N+jCT52vpnGMrBYoBFY4ru8HLGuy3aU0fiKdSeNQkOWfmDty9QK+BQ4A3wCRjusTgUWOyzOAPTTOKNkD3OsGub/3fAJPAVc6LgcBi4EMYCswxOrM7cj+LJDieL5XASOtznxa/g+AfKDO8dq/F3gAeMBxu9B4cqFMx+ul2dlrbpr94SbP/WZghtWZT8t/NmCA3cBOx9elHvT8O5O/3X8DXVpBKaV8hLsN6SillOokWvhKKeUjtPCVUspHaOErpZSP0MJXSikfoYWvlFI+QgtfKaV8xP8HtIDZrHTvyNMAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x = np.linspace(-1,2.5,1000)\n", "plt.plot(x,f(x))\n", "plt.xlim([-1,2.5])\n", "plt.ylim([0,3])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see from plot above that our local minimum is gonna be near around 1.4 or 1.5 (on the x-axis), but let's pretend that we don't know that, so we set our starting point (arbitrarily, in this case) at $x_0 = 2$" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Local minimum occurs at: 1.3334253508453249\n", "Number of steps: 17\n" ] } ], "source": [ "x_old = 0\n", "x_new = 2 # The algorithm starts at x=2\n", "n_k = 0.1 # step size\n", "precision = 0.0001\n", "\n", "x_list, y_list = [x_new], [f(x_new)]\n", "\n", "# returns the value of the derivative of our function\n", "def f_prime(x):\n", " return 3*x**2-4*x\n", " \n", "while abs(x_new - x_old) > precision:\n", " x_old = x_new\n", " s_k = -f_prime(x_old)\n", " x_new = x_old + n_k * s_k\n", " x_list.append(x_new)\n", " y_list.append(f(x_new))\n", "print(\"Local minimum occurs at:\", x_new)\n", "print(\"Number of steps:\", len(x_list))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The figures below show the route that was taken to find the local minimum." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=[10,3])\n", "plt.subplot(1,2,1)\n", "plt.scatter(x_list,y_list,c=\"r\")\n", "plt.plot(x_list,y_list,c=\"r\")\n", "plt.plot(x,f(x), c=\"b\")\n", "plt.xlim([-1,2.5])\n", "plt.ylim([0,3])\n", "plt.title(\"Gradient descent\")\n", "plt.subplot(1,2,2)\n", "plt.scatter(x_list,y_list,c=\"r\")\n", "plt.plot(x_list,y_list,c=\"r\")\n", "plt.plot(x,f(x), c=\"b\")\n", "plt.xlim([1.2,2.1])\n", "plt.ylim([0,3])\n", "plt.title(\"Gradient descent (zoomed in)\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You'll notice that the step size (also called learning rate) in the implementation above is constant, unlike the algorithm in the pseudocode. Doing this makes it easier to implement the algorithm. However, it also presents some issues: If the step size is too small, then convergence will be very slow, but if we make it too large, then the method may fail to converge at all. 
\n", "\n", "A solution to this is to use adaptive step sizes, as the algorithm below does. At each iteration it calls SciPy's `fmin` (a Nelder-Mead simplex minimizer) to find the step size that minimizes $f(x_k+\\alpha_k s_k)$, as in step 4 of the pseudocode:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Local minimum occurs at 1.3333333284505209\n", "Number of steps: 4\n" ] } ], "source": [ "# f restricted to the search direction: fmin minimizes this over the step size n\n", "def f2(n,x,s):\n", "    x = x + n*s\n", "    return f(x)\n", "\n", "x_old = 0\n", "x_new = 2 # The algorithm starts at x=2\n", "precision = 0.0001\n", "\n", "x_list, y_list = [x_new], [f(x_new)]\n", "\n", "# returns the value of the derivative of our function\n", "def f_prime(x):\n", "    return 3*x**2-4*x\n", "\n", "while abs(x_new - x_old) > precision:\n", "    x_old = x_new\n", "    s_k = -f_prime(x_old)\n", "\n", "    # use scipy's fmin to find the step size that minimizes f along s_k;\n", "    # fmin returns a 1-element array, so take its only entry\n", "    n_k = fmin(f2,0.1,(x_old,s_k), full_output = False, disp = False)[0]\n", "\n", "    x_new = x_old + n_k * s_k\n", "    x_list.append(x_new)\n", "    y_list.append(f(x_new))\n", "\n", "print(\"Local minimum occurs at \", float(x_new))\n", "print(\"Number of steps:\", len(x_list))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With adaptive step sizes, the step count reported above drops from 17 to 4. The trade-off is that each iteration is more expensive, because it solves a one-dimensional minimization problem to choose the step size. The plots below show the path taken. The iterates land very close to the local minimum almost immediately, so the dots after the first two steps are hard to tell apart until we zoom in closely in the third panel:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=[15,3])\n", "plt.subplot(1,3,1)\n", "plt.scatter(x_list,y_list,c=\"r\")\n", "plt.plot(x_list,y_list,c=\"r\")\n", "plt.plot(x,f(x), c=\"b\")\n", "plt.xlim([-1,2.5])\n", "plt.title(\"Gradient descent\")\n", "plt.subplot(1,3,2)\n", "plt.scatter(x_list,y_list,c=\"r\")\n", "plt.plot(x_list,y_list,c=\"r\")\n", "plt.plot(x,f(x), c=\"b\")\n", "plt.xlim([1.2,2.1])\n", "plt.ylim([0,3])\n", "plt.title(\"zoomed in\")\n", "plt.subplot(1,3,3)\n", "plt.scatter(x_list,y_list,c=\"r\")\n", "plt.plot(x_list,y_list,c=\"r\")\n", "plt.plot(x,f(x), c=\"b\")\n", "plt.xlim([1.3333,1.3335])\n", "plt.ylim([0,3])\n", "plt.title(\"zoomed in more\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another way to adapt the step size is to choose a decay constant $d$ that shrinks the step size $\\eta$ (called `n_k` in the code below) over time:\n", "$\\eta(t+1) = \\eta(t) / (1+t \\times d)$." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Local minimum occurs at: 1.3308506740900838\n", "Number of steps: 6\n" ] } ], "source": [ "x_old = 0\n", "x_new = 2 # The algorithm starts at x=2\n", "n_k = 0.17 # initial step size\n", "precision = 0.0001\n", "t, d = 0, 1\n", "\n", "x_list, y_list = [x_new], [f(x_new)]\n", "\n", "# returns the value of the derivative of our function\n", "def f_prime(x):\n", "    return 3*x**2-4*x\n", "\n", "while abs(x_new - x_old) > precision:\n", "    x_old = x_new\n", "    s_k = -f_prime(x_old)\n", "    x_new = x_old + n_k * s_k\n", "    x_list.append(x_new)\n", "    y_list.append(f(x_new))\n", "    n_k = n_k / (1 + t * d)  # decay the step size\n", "    t += 1\n", "\n", "print(\"Local minimum occurs at:\", x_new)\n", "print(\"Number of steps:\", len(x_list))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that this run stops at about 1.3309, which is noticeably farther from the true minimizer $x = 4/3$ than the previous two runs. The loop stops once successive iterates are within `precision` of each other. Here that happens because the step size has shrunk: with $d = 1$ the steps used are 0.17, 0.17, 0.085, about 0.028 and about 0.0071. At that point the derivative is still about $-0.01$.\n", "\n", "Let's now consider a slightly more complicated example. 
Consider a simple linear regression that relates the chirp rate of striped ground crickets to the temperature. We have a data set of cricket chirp rates measured at various temperatures. First, we'll load the data set and plot it:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Load the data set: column 0 is chirps/sec, column 1 is temperature (F)\n", "data = np.loadtxt('SGD_data.txt', delimiter=',')\n", " \n", "# Plot the data\n", "plt.scatter(data[:, 0], data[:, 1], marker='o', c='b')\n", "plt.title('cricket chirps vs temperature')\n", "plt.xlabel('chirps/sec for striped ground crickets')\n", "plt.ylabel('temperature in degrees Fahrenheit')\n", "plt.xlim([13,21])\n", "plt.ylim([65,95])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our goal is to find the equation of the straight line $h_\\theta(x) = \\theta_0 + \\theta_1 x$ that best fits our data points. The function that we are trying to minimize in this case is:\n", "\n", "$J(\\theta_0,\\theta_1) = {1 \\over 2m} \\sum\\limits_{i=1}^m (h_\\theta(x_i)-y_i)^2$\n", "\n", "Since there are two parameters, the gradient has two components:\n", "\n", "$\\frac{\\partial}{\\partial \\theta_0} J(\\theta_0,\\theta_1) = \\frac{1}{m} \\sum\\limits_{i=1}^m (h_\\theta(x_i)-y_i)$\n", "\n", "$\\frac{\\partial}{\\partial \\theta_1} J(\\theta_0,\\theta_1) = \\frac{1}{m} \\sum\\limits_{i=1}^m ((h_\\theta(x_i)-y_i) \\cdot x_i)$\n", "\n", "Below, we define functions for $h$, $J$ and the gradient of $J$:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "h = lambda theta_0,theta_1,x: theta_0 + theta_1*x\n", "\n", "# cost: J = (1/2m) * sum of squared residuals\n", "def J(x,y,m,theta_0,theta_1):\n", "    returnValue = 0\n", "    for i in range(m):\n", "        returnValue += (h(theta_0,theta_1,x[i])-y[i])**2\n", "    returnValue = returnValue/(2*m)\n", "    return returnValue\n", "\n", "# gradient of J: [mean residual, mean of residual * x]\n", "def grad_J(x,y,m,theta_0,theta_1):\n", "    returnValue = np.array([0.,0.])\n", "    for i in range(m):\n", "        returnValue[0] += (h(theta_0,theta_1,x[i])-y[i])\n", "        returnValue[1] += (h(theta_0,theta_1,x[i])-y[i])*x[i]\n", "    returnValue = returnValue/(m)\n", "    return returnValue" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we'll load our data into the `x` and `y` variables:" 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "x = data[:, 0]\n", "y = data[:, 1]\n", "m = len(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we run batch gradient descent with a fixed step size (no adaptive step sizes in this example):" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Local minimum occurs where:\n", "theta_0 = 25.128552558595363\n", "theta_1 = 3.297264756251897\n", "This took 565859 steps to converge\n" ] } ], "source": [ "theta_old = np.array([0.,0.])\n", "theta_new = np.array([1.,1.]) # The algorithm starts at [1,1]\n", "n_k = 0.001 # step size\n", "precision = 0.001\n", "num_steps = 0\n", "s_k = float(\"inf\")\n", "\n", "# stop once the norm of the gradient drops below precision\n", "while np.linalg.norm(s_k) > precision:\n", "    num_steps += 1\n", "    theta_old = theta_new\n", "    s_k = -grad_J(x,y,m,theta_old[0],theta_old[1])\n", "    theta_new = theta_old + n_k * s_k\n", "\n", "print(\"Local minimum occurs where:\")\n", "print(\"theta_0 =\", theta_new[0])\n", "print(\"theta_1 =\", theta_new[1])\n", "print(\"This took\",num_steps,\"steps to converge\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For comparison, let's compute the exact least-squares values of $\\theta_0$ and $\\theta_1$ with `scipy.stats.linregress`:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Actual values for theta are:\n", "theta_0 = 25.232304983426026\n", "theta_1 = 3.2910945679475647\n" ] } ], "source": [ "actualvalues = sp.stats.linregress(x,y)\n", "print(\"Actual values for theta are:\")\n", "print(\"theta_0 =\", actualvalues.intercept)\n", "print(\"theta_1 =\", actualvalues.slope)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our values are close to the exact least-squares values, although the method was slow: it needed over 565,000 iterations. 
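The least-squares line also has a closed form: $\\theta_1 = \\mathrm{cov}(x, y) / \\mathrm{var}(x)$ and $\\theta_0 = \\bar{y} - \\theta_1 \\bar{x}$. The standalone sketch below computes this from the covariance matrix and cross-checks the result against `scipy.stats.linregress` and `numpy.polyfit`. `SGD_data.txt` is not reproduced here, so the sketch uses 15 synthetic points around a hypothetical line $y = 25 + 3.3x$. That data is an assumption for illustration only.

```python
# Closed-form least-squares line computed from the sample covariance matrix,
# cross-checked against scipy.stats.linregress and numpy.polyfit.
# Synthetic data (an assumption): SGD_data.txt is not reproduced here, so we
# draw 15 hypothetical points around y = 25 + 3.3x over the plotted x range.
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
x_demo = rng.uniform(14, 20, size=15)
y_demo = 25 + 3.3 * x_demo + rng.normal(0, 2, size=15)

# 2x2 matrix [[var(x), cov(x, y)], [cov(x, y), var(y)]]
c = np.cov(x_demo, y_demo, bias=True)
theta_1 = c[0, 1] / c[0, 0]                        # slope
theta_0 = y_demo.mean() - theta_1 * x_demo.mean()  # intercept

res = stats.linregress(x_demo, y_demo)
slope_pf, intercept_pf = np.polyfit(x_demo, y_demo, 1)  # highest power first

print('closed form:', theta_0, theta_1)
print('linregress: ', res.intercept, res.slope)
print('polyfit:    ', intercept_pf, slope_pf)
```

All three printed lines should agree to many decimal places, and none of them requires any iteration.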
The source code of [linregress](https://github.com/scipy/scipy/blob/master/scipy/stats/_stats_mstats_common.py) computes the fit in closed form from the covariance matrix of x and y rather than iterating, which is why it is so fast. Below is a plot of the line given by our gradient descent values of theta, drawn against the data:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "xx = np.linspace(0,21,1000)\n", "plt.scatter(data[:, 0], data[:, 1], marker='o', c='b')\n", "plt.plot(xx,h(theta_new[0],theta_new[1],xx))\n", "plt.xlim([13,21])\n", "plt.ylim([65,95])\n", "plt.title('cricket chirps vs temperature')\n", "plt.xlabel('chirps/sec for striped ground crickets')\n", "plt.ylabel('temperature in degrees Fahrenheit')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that the method above computes the gradient at every step, and each gradient evaluation loops over every data point. In the cricket example this is not a big deal, since there are only 15 data points. But imagine that we had 10 million data points. Every single step would then require a pass over all of them, which would make the method far less efficient.\n", "\n", "In machine learning, the algorithm above is often called batch gradient descent, to contrast it with mini-batch gradient descent (which we will not go into here) and stochastic gradient descent." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Stochastic gradient descent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we said above, in batch gradient descent we must look at every example in the entire training set on every step (in cases where a training set is used for gradient descent). This can be quite slow if the training set is large. In stochastic gradient descent, we update the parameters after looking at each individual example, so we can start making progress right away. Recall the linear regression example above. 
In that example, we calculated the two components of the gradient as follows:\n", "\n", "$\\frac{\\partial}{\\partial \\theta_0} J(\\theta_0,\\theta_1) = \\frac{1}{m} \\sum\\limits_{i=1}^m (h_\\theta(x_i)-y_i)$\n", "\n", "$\\frac{\\partial}{\\partial \\theta_1} J(\\theta_0,\\theta_1) = \\frac{1}{m} \\sum\\limits_{i=1}^m ((h_\\theta(x_i)-y_i) \\cdot x_i)$\n", "\n", "Here $h_\\theta(x) = \\theta_0 + \\theta_1 x$.\n", "\n", "Then we followed this algorithm (where $\\alpha$ is a fixed step size):\n", "\n", "    1:   Choose initial guess $x_0$
\n", "    2:   for k = 0, 1, 2, ... do
\n", "    3:       $s_k$ = -$\\nabla f(x_k)$
\n", "    4:       $x_{k+1} = x_k + \\alpha s_k$
\n", "    5:   end for\n", "\n", "When the sample data had 15 data points as in the example above, calculating the gradient was not very costly. But for very large data sets, this would not be the case. So instead, we consider a stochastic gradient descent algorithm for simple linear regression such as the following, where m is the size of the data set:\n", "\n", "    1:   Randomly shuffle the data set
\n", "    2:   for k = 0, 1, 2, ... do
\n", "    3:       for i = 1 to m do
\n", "    4:            $\\begin{bmatrix}\n", " \\theta_0 \\\\ \n", " \\theta_1 \\\\ \n", " \\end{bmatrix}=\\begin{bmatrix}\n", " \\theta_0 \\\\ \n", " \\theta_1 \\\\ \n", " \\end{bmatrix}-\\alpha\\begin{bmatrix}\n", " h_\\theta(x_i)-y_i \\\\ \n", " x_i(h_\\theta(x_i)-y_i) \\\\ \n", " \\end{bmatrix}$
\n", "    5:       end for
\n", "    6:   end for\n", "\n", "Typically, with stochastic gradient descent, you will run through the entire data set 1 to 10 times (see the value of k in line 2 of the pseudocode above). The exact number depends on how fast the parameters converge and how large the data set is.\n", "\n", "With batch gradient descent, we must go through the entire data set before we make any progress. With stochastic gradient descent, though, we can make progress right away and keep making progress as we go through the data set. Therefore, stochastic gradient descent is often preferred when dealing with large data sets.\n", "\n", "Unlike batch gradient descent, stochastic gradient descent tends to oscillate near a minimum rather than getting steadily closer, because each update is based on a single, noisy example. With a fixed step size, it may therefore never actually converge to the minimum. One way around this is to slowly decrease the step size $\\alpha$ as the algorithm runs; in the example below, however, we keep $\\alpha$ fixed.\n", "\n", "Let's look at another example that illustrates stochastic gradient descent for linear regression. Below, we create a set of 500,000 points around the line $y = 2x+17+\\epsilon$, where $x$ is drawn uniformly between 0 and 100 and $\\epsilon$ is Gaussian noise with standard deviation 10:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# points around y = 2x + 17, plus Gaussian noise with standard deviation 10\n", "f = lambda x: x*2+17+np.random.randn(len(x))*10\n", "\n", "x = np.random.random(500000)*100\n", "y = f(x) \n", "m = len(y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's randomly shuffle the data set. In this example, this step isn't strictly necessary, since the data is already in random order. 
However, that may not always be the case:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "from random import shuffle\n", "\n", "# shuffle a list of indices, then reorder x and y with the same permutation\n", "x_shuf = []\n", "y_shuf = []\n", "index_shuf = list(range(len(x)))\n", "shuffle(index_shuf)\n", "for i in index_shuf:\n", "    x_shuf.append(x[i])\n", "    y_shuf.append(y[i])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we'll set up our h function and our cost function, which we will use to check how the fit is improving." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "h = lambda theta_0,theta_1,x: theta_0 + theta_1*x\n", "cost = lambda theta_0,theta_1, x_i, y_i: 0.5*(h(theta_0,theta_1,x_i)-y_i)**2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we'll run our stochastic gradient descent algorithm. To see its progress, we'll take a cost measurement at every step. Every 10,000 steps, we'll compute the average cost over the last 10,000 steps and append it to our cost_list variable. 
We will run through the entire data set 10 times here:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Local minimum occurs where:\n", "theta_0 = 16.955415496837382\n", "theta_1 = 1.999169728242732\n" ] } ], "source": [ "theta_old = np.array([0.,0.])\n", "theta_new = np.array([1.,1.]) # The algorithm starts at [1,1]\n", "n_k = 0.000005 # step size\n", "\n", "iter_num = 0\n", "s_k = np.array([float(\"inf\"),float(\"inf\")])\n", "sum_cost = 0\n", "cost_list = []\n", "\n", "for j in range(10):  # 10 passes over the data set\n", "    for i in range(m):\n", "        iter_num += 1\n", "        theta_old = theta_new\n", "        # gradient of the cost for the single example (x_shuf[i], y_shuf[i])\n", "        s_k[0] = (h(theta_old[0],theta_old[1],x_shuf[i])-y_shuf[i])\n", "        s_k[1] = (h(theta_old[0],theta_old[1],x_shuf[i])-y_shuf[i])*x_shuf[i]\n", "        s_k = (-1)*s_k\n", "        theta_new = theta_old + n_k * s_k\n", "        sum_cost += cost(theta_old[0],theta_old[1],x_shuf[i],y_shuf[i])\n", "        if (i+1) % 10000 == 0:\n", "            cost_list.append(sum_cost/10000.0)\n", "            sum_cost = 0\n", "\n", "print(\"Local minimum occurs where:\")\n", "print(\"theta_0 =\", theta_new[0])\n", "print(\"theta_1 =\", theta_new[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, our values for $\\theta_0$ and $\\theta_1$ are close to their true values of 17 and 2.\n", "\n", "Now we plot the average cost against the number of iterations. The cost goes down quickly at first and then levels off as we go through more iterations:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "iterations = np.arange(len(cost_list))*10000\n", "plt.plot(iterations,cost_list)\n", "plt.xlabel(\"iterations\")\n", "plt.ylabel(\"avg cost\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 1 }