{ "cells": [ { "cell_type": "markdown", "metadata": { "toc": "true" }, "source": [ "# Gradient Descent Intro\n", "

" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# kept first: a future statement must precede other code if this cell is exported to a .py script\n", "from __future__ import print_function, division\n", "%matplotlib inline\n", "import math, sys, os\n", "import numpy as np\n", "from numpy.random import random\n", "from matplotlib import pyplot as plt, rcParams, animation, rc\n", "from ipywidgets import interact, interactive, fixed\n", "from ipywidgets.widgets import *\n", "\n", "# display configuration: HTML5 animations, small figures, 4-digit float output\n", "rc('animation', html='html5')\n", "rcParams['figure.figsize'] = 3, 3\n", "%precision 4\n", "np.set_printoptions(precision=4, linewidth=100)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def lin(a, b, x):\n", "    \"\"\"Evaluate the line a*x + b at x (a is the slope, b the intercept).\"\"\"\n", "    return a*x + b" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# true slope and intercept used to generate the data below\n", "a = 3.\n", "b = 8." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": true }, "outputs": [], "source": [ "n = 30\n", "np.random.seed(42)  # fix the draw so Restart & Run All reproduces the same data\n", "x = random(n)\n", "y = lin(a, b, x)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0.0132, 0.4765, 0.0034, 0.8314, 0.5044, 0.0817, 0.7193, 0.8595, 0.7664, 0.0203,\n", " 0.4296, 0.5188, 0.2627, 0.4388, 0.3663, 0.0792, 0.5146, 0.9705, 0.4546, 0.0884,\n", " 0.9854, 0.3523, 0.5519, 0.1516, 0.1063, 0.0296, 0.9777, 0.3856, 0.5225, 0.2138])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 8.0395, 9.4294, 8.0103, 10.4941, 9.5131, 8.245 , 10.1578, 10.5785, 10.2991,\n", " 8.061 , 9.2889, 9.5564, 8.7882, 9.3164, 9.099 , 8.2375, 9.5439, 10.9116,\n", " 9.3637, 8.2653, 10.9562, 9.0568, 9.6556, 8.4548, 8.3188, 8.0888, 10.9331,\n", " 9.1568, 9.5675, 8.6414])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y" ] }, { "cell_type": "code", "execution_count": 19, 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAANYAAADFCAYAAAAooQwbAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAEFdJREFUeJzt3X2MXNV5x/Hvz8uiLijKQgwBL2xNVGSVl/DSleNaJCJN\nWowhwUUpARE1RbSICtQUVZZIi1qIIkFF21Q0KNShKIkgEJTAhqQmJkmlEoWYxrAQ26VOXF490JgQ\nHErYCmM//WPukGG8d2Z27j3z+vtI1s7ce+7cs8M+3HOfOecZRQRmVq4lve6A2TByYJkl4MAyS8CB\nZZaAA8ssAQeWWQIOLLMEHFhmCTiwzBI4qNcdWMjSpUtj+fLlve6G2QEeeeSRn0XEEa3a9WVgLV++\nnC1btvS6G2YHkPRMO+08FDRLoGVgSbpN0m5J2+q2/YGk7ZL2S5ppcuwaSTsk7ZR0dVmdNut37Vyx\nvgCsadi2DTgfeDDvIEljwM3A2cAJwEWSTuism2aDpeU9VkQ8KGl5w7YnACQ1O3QlsDMinsza3gWc\nB/xnh301K93sXIUbN+3g+T3zLJucYP1ZK1h32lTh102ZvJgCnqt7vgt4T15jSZcBlwFMT08n7JZZ\n1TWzW7lj87PUViRW9szzyXu2AhQOrr5JXkTEhoiYiYiZI45omc0069jsXIVTr3uA2+uCqmZ+7z5u\n3LSj8DlSXrEqwLF1z4/Jtpn1zOxchU/es5X5vfty2zy/Z77weVJesX4IHC/pOEkHAxcC9yU8n1lL\nN27a0TSoAJZNThQ+Tzvp9juBHwArJO2SdKmk35e0C/ht4F8lbcraLpO0ESAi3gCuBDYBTwB3R8T2\nwj02K6DV1UjA+rNWFD5PO1nBi3J23btA2+eBtXXPNwIbO+6dWcmWTU5QyQkuARevmi4lK9g3yQuz\nblh/1gomxscO2H7YIeN85qOn8ul1J5dynr6cK2iWSu1qlOKzq3oOLBs5606bKj2QGnkoaJaAA8ss\nAQeWWQIOLLMEHFhmCTiwzBJwYJkl4MAyS8CBZZaAZ17YwEm1nL5MDiwbKI0LFctcTl8mDwVtoCy0\nULGs5fRlcmDZQMlbqFjGcvoytRwKSroNOBfYHREnZdsOB74CLAeeBi6IiJcXOPZp4H+BfcAbEZFb\n3NMsT/091RKJfdFYAqac5fRl6rRg59XAdyPieOC72fM874+IUx1U1onaPVVlzzwBCwbVxPhYKcvp\ny9QysCLiQeDnDZvPA76YPf4isK7kfpkB+cVfxiQETE1OcP35J/dV4gI6zwq+MyJeyB7/D/DOnHYB\nfEfSPuCfI2JD3gu6YKctJO/eaX8ET91wTpd7077CyYuICDig7mHNGRFxKtX67VdIel+T13HBTjtA\n3r1Tv91TNeo0sH4q6WiA7OfuhRpFRCX7uZtqVaeVHZ7PRtRCxV/68Z6qUaeBdR/w8ezxx4GvNzaQ\ndKikt9UeA79H9VtKzNq27rQprj//ZKYmJ/r6nqpRO+n2O4EzgaVZkc6/AW4A7pZ0KfAMcEHWdhlw\na0SspXrfdW/2jSQHAV+OiG+l+CVsuHWj+EvZihTs/MACbd8s2Jl9fc8phXpnNqA888IsAU/CtZ4Z\nhFnqnXJgWU8Myiz1TnkoaD0xKLPUO+XAsp4YlFnqnXJgWU8M6oyKdjmwrCcGdUZFu5y8sJ7o1tfp\n9IoDy5JqllIfxBkV7XJgWTLDnlJvxvdYlsywp9SbcWBZMsOeUm/GgWXJDHtKvRkHliUz7Cn1Zpy8\nsGSGPaXejAPLkhrmlHozLYeCkm6TtFvStrpth0v6t
qSfZD8Pyzl2jaQdknZKalZ70GyoJCvYKWkM\nuJlqhaYTgIsknVCot2YDImXBzpXAzoh4MiJeB+7KjjMbep1mBdsp2DkFPFf3fFe2bUGSLpO0RdKW\nF198scNumfWH1AU7F/M6LthpQyNlwc4KcGzd82OybWZDL1nBTuCHwPGSjpN0MHBhdpzZ0Gsn3X4n\n8ANghaRdWZHOG4DflfQT4IPZcyQtk7QRICLeAK4ENgFPAHdHxPY0v4ZZf0lWsDN7vhHY2HHvzAaU\n5wqaJeDAMkvAgWWWgAPLLAEHllkCDiyzBBxYZgk4sMwS8AriETbM30/Vaw6sETXKxTS7wUPBETXK\nxTS7wYE1oka5mGY3eCg4Ihrvp94+Mc6e+b0HtBuFYprd4MAaAdfMbuX2zc+++byyZ56xJWJ8idi7\n/1eLv0elmGY3eCg45GbnKm8Jqpp9+4ODD1rC1OQEAqYmJ7j+/JOduChJoSuWpE8AfwII+HxE/GPD\n/jOpri5+Ktt0T0R8qsg5bXGu+0b+2tJfvr6P7Z/6nS72ZnR0HFiSTqIaVCuB14FvSfpmROxsaPq9\niDi3QB+tQ7NzFV5+7cD7KEuvyFDwN4GHI+K1bBn+vwPnl9MtK0Or1PnkxHiXejJ6igTWNuC9kt4h\n6RCqS/KPXaDdakk/knS/pBPzXsx1BcvXLHW+BLj2w7n/OaygjoeCEfGEpL8FHgB+CTwG7Gto9igw\nHRGvSloLzALH57zeBmADwMzMTOE6haOqPq2+RGJfHPhWCviHj57qREVChbKCEfEvEfFbEfE+4GXg\nxw37X4mIV7PHG4FxSUuLnNPy1aYpVfbME7BgUE2Mj/EZB1VyRbOCR0bEbknTVO+vVjXsPwr4aUSE\npJVUA/mlIue0fAtNUwIYk9gf4Ym2XVT0A+KvSXoHsBe4IiL2SLocICJuAT4C/KmkN4B54MKsJLWV\npH7ol/fG7o/gqRvO6Wq/Rl2hwIqI9y6w7Za6x58FPlvkHJavcYZ6Hk9T6j7PvBhgeUO/ep6m1Bue\nKzigZucqVJqk0wW+p+ohB9YAumZ2K3csMP+vZmpygu9f7alKveSh4ICZnatwx+ZncxMVHvr1BwfW\ngLlx046m3/LnGer9wYE1YJpNU5qanHBQ9QkH1oDJS50LPATsIw6sAbP+rBVMjI+9ZZuAi1dN+2rV\nR5wV7FN5Nf9qweN6gP3NgdWHWtX8qw8w608eCvYh1/wbfL5i9ZHa8C9vRoVr/g0OB1afaGdCrSfT\nDg4HVh+YnavwF3c/vuDCxBrPqBgsDqweq12pmgXVlDN/A8eB1WOtln54Qu1gKpQVlPQJSdskbZf0\n5wvsl6SbJO3MKjWdXuR8w6hZQsLDv8HVcWA1FOw8BThX0m80NDubalWm44HLgM91er5hlZeQGJM8\noXaApS7YeR7wpajaDExKOrrAOYfOQlOUJsbH+PsLTnFQDbDUBTungOfqnu/Kth1gVAt2rjttiuvP\nP9lfTjBkUhfsXMzrjWzBTk9RGj5JC3YCFd56FTsm22Y21IpmBY/MftYKdn65ocl9wB9m2cFVwC8i\n4oUi5zQbBKkLdm6keu+1E3gNuKTg+cwGQuqCnQFcUeQcZoPIy0bMEnBgmSXguYIly1tSb6PFgVWi\nVkvqbXR4KFgiL6m3GgdWifJmqntJ/ehxYJUob6a6l9SPHgdWifJmqntN1ehx8qJELqZpNQ6sRWqV\nTvdMdQMH1qI4nW7t8j3WIjidbu1yYC2C0+nWLgfWIjidbu1yYC2C0+nWLicvFsHpdGtXocCSdBXw\nx0AAW4FLIuL/6vafCXwdeCrbdE9EfKrIOXvN6XRrR8eBJWkK+DPghIiYl3Q3cCHwhYam34uIczvv\notngKToUPAiYkLQXOAR4vniXesdrqawsHScvIqIC/B3wLPAC1QpMDyzQdHVWt/1+SSfmvV6vC3bW\nPvyt7Jkn+NWHv
7NzrtZmi1ekdvthVEtIHwcsAw6V9LGGZo8C0xHxbuCfgNm814uIDRExExEzRxxx\nRKfd6pg//LUyFUm3fxB4KiJejIi9wD3A6voGEfFKRLyaPd4IjEtaWuCcyfjDXytTkcB6Flgl6RBJ\nAj4APFHfQNJR2T4krczO91KBcybjD3+tTEXusR4Gvkp1uLc1e60Nki6vFe0EPgJsk/Q4cBNwYVZr\nsO/4w18rk/rx73xmZia2bNnS9fM6K2itSHokImZatRu5mRfNgscf/lpZRiqwvJ7KumWkJuE6pW7d\nMjKBNTtXoeKUunXJ0A8FZ+cq/OU9P+K1vftz2zilbmUb6sCanauw/quPs3dffubTKXVLYaiHgjdu\n2tE0qAB/kbYlMdSB1ereaWpywkFlSQx1YDW7dxJ4CGjJDHVgrT9rBeNjWnDfxaumfbWyZIY6eVEL\nnOu+sZ2XX9sLwOTEONd++EQHlSU11IEFnqZkvTEUgeXJs9ZvBjqwZucqXHvfdvbM731zm+f/WT8Y\n2OTFNbNbueorj70lqGo8/896bSCvWNfMbuX2zc82beP5f9ZLha5Ykq6StF3SNkl3Svq1hv2SdJOk\nnVmlptOLdbc6/LujRVCB5/9ZbxWp0lQr2DkTEScBY1QLdtY7Gzg++3cZ8LlOz1dz46YdtFrz7Pl/\n1mtF77FqBTsPYuGCnecBX4qqzcCkpKOLnLDVEO+wQ8Y9/896ruN7rIioSKoV7JwHHligYOcU8Fzd\n813ZthcaX0/SZVSvakxPT+eed9nkRO66qo+tmubT605exG9hlkbqgp1ta7dg50LVlISDyvpLkazg\nmwU7ASTVCnbeXtemAhxb9/yYbFvH/FU6NgiKBNabBTupDgU/ADTWLLsPuFLSXcB7qNZ3P2AYuFie\npmT9rsg91sOSagU73wDmyAp2ZvtvATYCa4GdwGvAJYV7bDYAXLDTbBHaLdg5sFOazPqZA8ssgb4c\nCkp6EXimRbOlwM+60J3F6sd+uU/ta9WvX4+Ill/g1peB1Q5JW9oZ63ZbP/bLfWpfWf3yUNAsAQeW\nWQKDHFgbet2BHP3YL/epfaX0a2Dvscz62SBfscz6lgPLLIG+DCxJayTtyJb0X73A/twl/62OTdin\ni7O+bJX0kKRT6vY9nW1/TFKpc7Xa6NeZkn6RnfsxSX/d7rEJ+7S+rj/bJO2TdHi2L8l7Jek2Sbsl\nbcvZX+7fVET01T+qS/z/G3gXcDDwOHBCQ5u1wP1Ul2KtAh5u99iEfVoNHJY9PrvWp+z508DSHr1X\nZwLf7OTYVH1qaP8h4N+68F69Dzgd2Jazv9S/qX68Yq0EdkbEkxHxOnAX1QWV9fKW/LdzbJI+RcRD\nEfFy9nQz1bVnqRX5fXv2XjW4CLizhPM2FREPAj9v0qTUv6l+DKy85fzttGnn2FR9qncp1f/71QTw\nHUmPZCUIytJuv1Znw5v7JZ24yGNT9YlsLd8a4Gt1m1O9V62U+jc1kHUF+5mk91MNrDPqNp8R1Roh\nRwLflvRf2f9Bu+FRYDoiXpW0FpilWjWrH3wI+H5E1F9JevlelaYfr1jtLOfPa1N6KYBF9AlJ7wZu\nBc6LiJdq2yOikv3cDdxLdXhRhpb9iohXIuLV7PFGYFzS0naOTdWnOhfSMAxM+F61Uu7fVNk3iSXc\nZB4EPEm1SE3tZvHEhjbn8NYbzf9o99iEfZqmulJ6dcP2Q4G31T1+CFjTxffqKH41EWAl1ZIK6uV7\nlbV7O9V7nkO78V5lr7mc/ORFqX9TPQ+knF9yLfBjqtmYv8q2XQ5cnj0WcHO2fyvVoqG5x3apT7cC\nLwOPZf+2ZNvflf3HeBzYXmaf2uzXldl5H6eaVFnd7Nhu9Cl7/kfAXQ3HJXuvqF4ZXwD2Ur1PujTl\n35SnNJkl0I/3WGYDz4FlloADyywBB5ZZAg4sswQcWGYJOLDMEvh/qGF7F6BJSVc
AAAAASUVORK5C\nYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.scatter(x,y)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def sse(y, y_pred):\n", "    \"\"\"Sum of squared errors between targets y and predictions y_pred.\"\"\"\n", "    return ((y - y_pred)**2).sum()\n", "\n", "def loss(y, a, b, x):\n", "    \"\"\"Sum of squared errors of the line a*x + b against targets y.\"\"\"\n", "    return sse(y, lin(a, b, x))\n", "\n", "def avg_loss(y, a, b, x):\n", "    \"\"\"Root-mean-square error of the line a*x + b against targets y.\"\"\"\n", "    # divide by the number of targets actually passed in, not the notebook-global n\n", "    return np.sqrt(loss(y, a, b, x) / np.size(y))" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "8.9867" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# deliberately poor starting guess for slope and intercept\n", "a_guess=-1.\n", "b_guess=1.\n", "avg_loss(y, a_guess, b_guess, x)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "collapsed": true }, "outputs": [], "source": [ "lr=0.01\n", "# per-point squared error: L = (y-(a*x+b))**2\n", "# dL/db = 2 (b + a x - y) = 2 (y_pred - y)\n", "# dL/da = 2 x (b + a x - y) = x * dL/db" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def upd():\n", "    \"\"\"Take one gradient-descent step, updating the globals a_guess and b_guess in place.\"\"\"\n", "    global a_guess, b_guess\n", "\n", "    # make a prediction using the current weights\n", "    y_pred = lin(a_guess, b_guess, x)\n", "\n", "    # calculate the derivative of the per-point squared error w.r.t. b and a\n", "    dldb = 2 * (y_pred - y)\n", "    dlda = x*dldb\n", "\n", "    # update our weights by moving in direction of steepest descent,\n", "    # using the gradient averaged over all points\n", "    a_guess -= lr*dlda.mean()\n", "    b_guess -= lr*dldb.mean()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fig = plt.figure(dpi=100, figsize=(5, 4))\n", "plt.scatter(x,y)\n", "line, = plt.plot(x,lin(a_guess,b_guess,x))\n", "plt.close()\n", "\n", "def animate(i):\n", "    \"\"\"Redraw the fitted line from the current guesses, then run 10 gradient steps.\"\"\"\n", "    line.set_ydata(lin(a_guess,b_guess,x))\n", "    for _ in range(10): upd()  # `_` so the frame index i is not shadowed\n", "    return line,\n", "\n", "ani = animation.FuncAnimation(fig, animate, np.arange(0, 40), interval=100)\n", "ani" ] 
}, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [conda root]", "language": "python", "name": "conda-root-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.1" }, "nav_menu": {}, "toc": { "navigate_menu": true, "number_sections": true, "sideBar": true, "threshold": 6, "toc_cell": true, "toc_section_display": "block", "toc_window_display": false }, "widgets": { "state": {}, "version": "1.1.2" } }, "nbformat": 4, "nbformat_minor": 1 }