{ "cells": [ { "cell_type": "markdown", "metadata": { "toc": "true" }, "source": [ "# Gradient Descent Intro\n", "\n", "Fit a line $y = ax + b$ to synthetic, noise-free data by gradient descent on the squared error, and animate how the fitted line evolves." ] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from __future__ import print_function, division\n", "%matplotlib inline\n", "import numpy as np\n", "from numpy.random import random\n", "from matplotlib import pyplot as plt, rcParams, animation, rc\n", "rc('animation', html='html5')\n", "rcParams['figure.figsize'] = 3, 3\n", "%precision 4\n", "np.set_printoptions(precision=4, linewidth=100)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def lin(a,b,x):\n", "    \"\"\"Linear model: return a*x + b (elementwise when x is a numpy array).\"\"\"\n", "    return a*x+b" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# true slope and intercept used to generate the data\n", "a=3.\n", "b=8." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "n=30\n", "np.random.seed(42)  # fixed seed so Restart & Run All regenerates the same x and y\n", "x = random(n)\n", "y = lin(a,b,x)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0.3046, 0.918 , 0.7925, 0.8476, 0.2508, 0.3504, 0.8326, 0.6875, 0.4449, 0.4687,\n", " 0.5901, 0.2757, 0.6629, 0.169 , 0.8677, 0.6612, 0.112 , 0.1669, 0.6226, 0.6174,\n", " 0.3871, 0.4724, 0.3242, 0.7871, 0.0157, 0.8589, 0.7008, 0.2942, 0.3166, 0.5847])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 8.9138, 10.7541, 10.3775, 10.5428, 8.7525, 9.0511, 10.4977, 10.0626, 9.3347,\n", " 9.4062, 9.7704, 8.827 , 9.9888, 8.507 , 10.603 , 9.9836, 8.336 , 8.5006,\n", " 9.8678, 9.8523, 9.1614, 9.4172, 8.9725, 10.3614, 8.0471, 10.5766, 10.1025,\n", " 8.8827, 8.9497, 9.7542])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y" ] }, { "cell_type": "code", "execution_count": 10, 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAANYAAADFCAYAAAAooQwbAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADyRJREFUeJzt3X+MHOV9x/H3x+dLe0ZVz8Hm15GriYRQIA4hPRnLJRFp\nqgBOKC5CkVFQK5TWKgK1jSJLICF+REhQuf0nTZvISlESpYAIMReTAiY/pCQKgvbMmR5W6tbl90KL\nAxyUcBVn8+0fO2vW69sft7vP7Ozu5yWtbnd2dufxig8z88zzfEcRgZl114peN8BsEDlYZgk4WGYJ\nOFhmCThYZgk4WGYJOFhmCThYZgk4WGYJrOx1A5ayZs2aWLduXa+bYXacvXv3/ioi1jZbr5DBWrdu\nHTMzM71uhtlxJD3Xyno+FDRLwMEyS8DBMkugkOdYZnmYni2xY88BXppf4LTxMbZfdBZbzpvoync7\nWDaUpmdL3LBrjoXFIwCU5he4YdccQFfC5UNBG0o79hw4GqqKhcUj7NhzoCvf72DZUHppfmFZy5fL\nwbKhdNr42LKWL5eDZUNp+0VnMTY6csyysdERtl90Vle+350XNpQqHRTuFTTrsi3nTXQtSLUcLBtY\nKa9TNeNg2UBKfZ2qGXde2EC6Zff+pNepmnGwbOBMz5aYX1hc8r1uXadqxsGygTI9W+JL9z5Z9/1u\nXadqxsGygVE5rzrS4H4E3bpO1UzTYEm6U9Irkp6qWvZ+ST+U9J/Z39V1PvuspDlJ+yR5SrAlU9lT\n1Z5XVVu9ajS3XsFW9ljfBC6uWXY98OOIOBP4cfa6nk9GxEcjYqq9Jpo11sqeamx0hJsvPSe3NjUN\nVkT8DHitZvFlwLey598CtnS5XWYtW2qkerURidsvX5/b3graP8c6OSJezp7/N3BynfUC+JGkvZK2\nNfpCSdskzUiaOXToUJvNsmHUqKdvbHSEv/3cubmGCrrQeRHlO9fV2wdfEBEfBS4BrpX0iQbfszMi\npiJiau3aptWlzI6q19PXiz1VRbvB+h9JpwJkf19ZaqWIKGV/XwHuBza0uT2zuuqNVO/Fnqqi3SFN\nu4E/Ae7I/n6/dgVJJwArIuJ/s+efBr7cbkPN6o39Sz1SvR1NgyXpbuBCYI2kF4GbKQfqXklfAJ4D\nPpetexrwjYjYTPm8635Jle3cFREPp/hH2OC7cXqOf3rs+aPnHLVj/1KOVG9H02BFxJV13vrUEuu+\nBGzOnj8NnNtR68wo76mqQ1VRGftXpEBVeOSFFd6OPQfq9o7lNfZvuRwsK7xG4clr7N9yOVhWePXC\nI/Ib+7dcDpYV3lLd6QI+v3GykOdX4BnE1geK2J3ejINlfaFo3enN+FDQLAEHyywBB8ssAQfLLAEH\nyywBB8ssAXe3W656WfY5Tw6W5abXZZ/z5ENBy0W98mR5ln3Ok4NlyTUrT1bUqR+dcLAsuWblyYo6\n9aMTDpYl16w8WVGnfnQidYnpiyUdkHRQUqNquTZgpmdL/N4dP+GM6/+ZFeW6J8fpZXmy1JKVmJY0\nAvw95ZqCZwNXSjq7o9ZaX5ieLbH9u09Sml8gYMlzq16XJ0stZYnpDcDBiHg6It4B7sk+ZwPult37\nWXz3+DApe0yMjw3snqqi3etYrZSYngBeqHr9InB+vS/MSlBvA5icnGyzWVYE9W76FsCzd3wm38b0\nSOoS08v5HpeYtoGRssR0CfhA1evTs2U24FavGl3W8kHUbrAqJaahTolp4F+BMyWdIel9wNbsczbg\nbr70HEZHju0JHB1Rrven6rVkJaYj4rCk64A9wAhwZ
0TsT/PPsF5pNKh2GAbb1qNocBe8XpmamoqZ\nGd9Ztehq66lDuRt9kHv8JO1t5e6kHnlhbWlWT33YOVjWln6sp54nB8va0o/11PPkiY7WktpOivFV\no7z+9vEXgotcTz1PDpY1tdRN30ZXiNERsXjkvQPCotdTz5ODZQ3V66RYfDcYHxvlhN9YObRd6o04\nWNZQo06KNxYW2Xfzp3NtT79w54U15E6K9jhY1lA/3vStCBwsa6gfb/pWBD7HsoY87q89DpY11W83\nfSsCHwqaJeBgmSXgYJkl4GCZJeBgmSXgYJkl0FGwJP2lpKck7Zf0V0u8f6GkNyTtyx43dbI9s37R\n9nUsSR8G/oxyxdt3gIcl/SAiDtas+vOI+GwHbTTrO53ssT4EPB4Rb0fEYeCnwOXdaZZZf+skWE8B\nH5d0oqRVwGaOLdBZsUnSv0l6SFLdwnKStkmakTRz6NChDppl1nttHwpGxC8l/TXwCPBrYB9Qe3ex\nJ4DJiHhL0mZgGjizzvftBHZCufxZu+2y9wzLjbSLqKPOi4j4x4j43Yj4BPA68B81778ZEW9lzx8E\nRiWt6WSb1prp2RLb73vvVjql+QW23/ck07Ou8p2HTnsFT8r+TlI+v7qr5v1TpPJdxyRtyLb3aifb\ntNbc+sD+Y+pRACweCW59wMWI89Dp6PbvSToRWASujYh5SX8OEBFfB64ArpF0GFgAtkYRS+8OoKUq\nKDVabt3VUbAi4uNLLPt61fOvAl/tZBtm/cjzsQZIdWeFWPqmZeNjw3MrnV5ysAbE9GyJG3bNsbBY\n2zH7ntEV4pY/HJ5b6fSSgzUgduw5sGSoRiTejXB3e84crAFRr0zZuxE8MyT3/S0Sj24fEPXKlLn2\nX284WANiqTJlY6Mjrv3XIz4UHBAuU1YsDtYAcZmy4vChoFkC3mP1AY9S7z8OVsHVXvgtzS9ww645\nAIerwHwoWHBLXfj1nemLz8EquHoXfn1n+mJzsArOF377k4NVcL7w25/ceVEgjXr/3CvYXxysgmjW\n++cg9RcfChaEe/8GS+oS05L0FUkHs9qCH+tke4PMvX+DJXWJ6Uso1xE8Ezgf+Fr21zj2nGqFxJEl\n6uy4968/pS4xfRnw7Sh7DBiXdGoH2xwYlXOqSt2/pULl3r/+lbrE9ATwQtXrF7Nlxxm2EtO37N5f\ndyq9gInxMW6/fL07LfpU6hLTy/m+oSkxPT1bYn5h6fp+nko/GJKWmAZKHLsXOz1bNrSmZ0t86d4n\n677vc6rB0NF1LEknRcQrVSWmN9asshu4TtI9lDst3oiIlzvZZr+ani1x6wP7m1ai9TnVYEhdYvpB\nyudeB4G3gas73F5faqXmH8DqVaM+pxoQqUtMB3BtJ9sYBPVq/lUbGx3h5ktdTHNQeEhTQpXrVKUm\nF3lHJPcADhgHK5FWD//GRkccqgHksYKJtHL4Nz426lANKO+xEmk0xm/CUz8GnoPVRa2M/ZsYH+MX\n1/9+D1pneXKwuqT2nMpj/4abg9Ulvo2OVXOwusS30bFq7hXsEldTsmoOVpe4mpJV86Fgl7iaklVz\nsLrI1ZSswoeCZgk4WGYJOFhmCThYZgk4WGYJOFhmCXRaTOaLwJ8CAcwBV0fE/1W9fyHwfeCZbNGu\niPhyJ9vMy43Tc9z9+AsciWBE4srzP8BtW9b3ulnWJzopMT0B/AVwdkQsSLoX2Ap8s2bVn0fEZ9tv\nYv5unJ7jO489f/T1kYijrx0ua0Wnh4IrgTFJK4FVwEudN6n37n78hWUtN6vVdrAiogT8DfA88DLl\nmoGPLLHqpuxOIw9JqluGqEglppeaS9VouVmttoMlaTXlmx6cAZwGnCDpqprVngAmI+IjwN8B0/W+\nLyJ2RsRUREytXbu23WZ1xYi0rOVmtTo5FPwD4JmIOBQRi8AuYFP1ChHxZkS8lT1/EBiVtKaDbebi\nyvNr7+3QeLlZr
U6C9TywUdIqSQI+BfyyegVJp2TvIWlDtr1XO9hmLm7bsp6rNk4e3UONSFy1cdId\nF9ayTu428rik+ygf7h0GZoGdNSWmrwCukXQYWAC2ZtVxC++2LesdJGubivjf+dTUVMzMzPS6GWbH\nkbQ3IqaarTeU87Gqy5R5QqKlMHTBqi1TVppf4IZdcwAOl3XN0I0VXKpM2cLiEXbsOdCjFtkgGrpg\n1StT1qgktNlyDV2wXKbM8jB0wXKZMsvD0HVeuEyZ5WHoggUuU2bpDd2hoFkeHCyzBBwsswQcLLME\nHCyzBBwsswQcLLME+vI6lqd9WNH1XbA87cP6Qd8dCnrah/WDjoIl6YuS9kt6StLdkn6z5n1J+oqk\ng1ltwY911lxP+7D+0EldwUqJ6amI+DAwQrnEdLVLgDOzxzbga+1ur8LTPqwfpC4xfRnw7Sh7DBiX\ndGonG/S0D+sHqUtMTwDVBc9fzJYdp9US01vOm+D2y9czMT6GgInxMW6/fL07LqxQOrnbSHWJ6Xng\nu5KuiojvtPN9EbET2Anl8meN1vW0Dyu6pCWmgRJQXZf59GyZ2UBLWmIa2A38cdY7uJHy4eLLHWzT\nrC+kLjH9ILAZOAi8DVzdcYvN+oBLTJstQ6slpgsZLEmHgOcarLIG+FVOzWlH0dsHxW9jUdv3OxHR\n9AZuhQxWM5JmWvm/Rq8UvX1Q/DYWvX3N9N1YQbN+4GCZJdCvwdrZ6wY0UfT2QfHbWPT2NdSX51hm\nRdeveyyzQnOwzBIoVLAkXSzpQDYx8vol3q87cbLZZ3Ns4+ezts1JelTSuVXvPZst3ycpyRXwFtp3\noaQ3sjbsk3RTq5/NsY3bq9r3lKQjkt6fvZf8N+yKiCjEg/JEyf8CPgi8D3gSOLtmnc3AQ4CAjcDj\nrX42xzZuAlZnzy+ptDF7/Sywpse/4YXAD9r5bF5trFn/UuAnef2G3XoUaY+1ATgYEU9HxDvAPZSn\npVSrN3Gylc/m0saIeDQiXs9ePkZ5RH9eOvkdCvMb1rgSuDtBO5IqUrBamRRZb52WJ1Tm0MZqX6C8\nh60I4EeS9kra1sP2bcoOVx+SdM4yP5tXG5G0CrgY+F7V4tS/YVf0XfmzfiHpk5SDdUHV4gsioiTp\nJOCHkv49In6Wc9OeACYj4i1Jm4FpyjVJiuhS4BcR8VrVsiL8hk0VaY/VyqTIeuvkNaGype1I+gjw\nDeCyiHi1sjzK5QyIiFeA+ykfFuXavoh4MyLeyp4/CIxKWtPKZ/NqY5Wt1BwG5vAbdkevT/KqTkpX\nAk9TnupfOak9p2adz3Bs58W/tPrZHNs4SXn+2aaa5ScAv1X1/FHg4h607xTeGxiwgfKEVRXpN8zW\n+23gNeCEPH/Dbj0KcygYEYclXQfsodxzdGdE7G9l4mS9z/aojTcBJwL/UJ5YzeEoj9I+Gbg/W7YS\nuCsiHu5B+64ArpF0GFgAtkb5v9Qi/YYAfwQ8EhG/rvp48t+wWzykySyBIp1jmQ0MB8ssAQfLLAEH\nyywBB8ssAQfLLAEHyyyB/we26e4GZgQSIAAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.scatter(x,y)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def sse(y,y_pred): return ((y-y_pred)**2).sum()\n", "def loss(y,a,b,x): return sse(y, lin(a,b,x))\n", "def avg_loss(y,a,b,x): return np.sqrt(loss(y,a,b,x)/n)" ] }, { 
"cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "9.1074" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a_guess=-1.\n", "b_guess=1.\n", "avg_loss(y, a_guess, b_guess, x)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "lr=0.01\n", "# d[(y-(a*x+b))**2,b] = 2 (b + a x - y) = 2 (y_pred - y)\n", "# d[(y-(a*x+b))**2,a] = 2 x (b + a x - y) = x * dy/db" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def upd():\n", " global a_guess, b_guess\n", " \n", " # make a prediction using the current weights\n", " y_pred = lin(a_guess, b_guess, x)\n", " \n", " # calculate the derivate of the loss\n", " dydb = 2 * (y_pred - y)\n", " dyda = x*dydb\n", " \n", " # update our weights by moving in direction of steepest descent\n", " a_guess -= lr*dyda.mean()\n", " b_guess -= lr*dydb.mean()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fig = plt.figure(dpi=100, figsize=(5, 4))\n", "plt.scatter(x,y)\n", "line, = plt.plot(x,lin(a_guess,b_guess,x))\n", "plt.close()\n", "\n", "def animate(i):\n", " line.set_ydata(lin(a_guess,b_guess,x))\n", " for i in range(10): upd()\n", " return line,\n", "\n", "ani = animation.FuncAnimation(fig, animate, np.arange(0, 40), interval=100)\n", "ani" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [conda root]", "language": "python", "name": "conda-root-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": 
"python", "pygments_lexer": "ipython3", "version": "3.6.1" }, "nav_menu": {}, "toc": { "navigate_menu": true, "number_sections": true, "sideBar": true, "threshold": 6, "toc_cell": true, "toc_section_display": "block", "toc_window_display": false }, "widgets": { "state": {}, "version": "1.1.2" } }, "nbformat": 4, "nbformat_minor": 1 }