{
 "metadata": {
  "name": ""
 },
 "nbformat": 3,
 "nbformat_minor": 0,
 "worksheets": [
  {
   "cells": [
    {
     "cell_type": "heading",
     "level": 2,
     "metadata": {},
     "source": [
      "Linear Fitting using Gradient Descent Algorithm"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Fit the line $y = mx + b$ to noisy synthetic data by running gradient descent on the mean squared error.\n",
      "\n",
      "**Note on the step size:** the $x$ values are centred around 50, so the error surface is very steep in the $m$ direction. ",
      "A learning rate of 0.02 makes the iteration diverge (the parameters grow without bound and the error overflows to `inf`), ",
      "so a much smaller rate is configured below. Stored outputs were cleared; use *Restart & Run All* to regenerate them."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import random\n",
      "\n",
      "import numpy as np\n",
      "import pandas as pd\n",
      "import matplotlib.pyplot as plt"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "SEED = 100\n",
      "NUM_POINTS = 100\n",
      "# Gradient descent on this data is only stable for rates below roughly\n",
      "# 1 / mean(x**2) ~ 4e-4 (x is drawn around 50); 0.02 diverges.\n",
      "LEARNING_RATE = 0.0001\n",
      "NUM_ITERATIONS = 100\n",
      "REPORT_EVERY = 10\n",
      "\n",
      "# Seed both generators explicitly; the data below comes from numpy.random.\n",
      "random.seed(SEED)\n",
      "np.random.seed(SEED)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "### Error function and gradient step"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# y = mx + b\n",
      "# m is slope, b is y-intercept\n",
      "def computeErrorForLineGivenPoints(b, m, points):\n",
      "    \"\"\"Return the mean squared error of the line y = m*x + b over points.\n",
      "\n",
      "    points is an (N, 2) array-like: column 0 holds x, column 1 holds y.\n",
      "    Raises ZeroDivisionError when points is empty.\n",
      "    \"\"\"\n",
      "    n = len(points)\n",
      "    if n == 0:\n",
      "        raise ZeroDivisionError('mean squared error of zero points is undefined')\n",
      "    pts = np.asarray(points, dtype=float)\n",
      "    residuals = pts[:, 1] - (m * pts[:, 0] + b)\n",
      "    return float(np.sum(residuals ** 2)) / n\n",
      "\n",
      "\n",
      "def stepGradient(b_current, m_current, points, learningRate):\n",
      "    \"\"\"Take one gradient-descent step on the mean squared error.\n",
      "\n",
      "    points is an (N, 2) array-like: column 0 holds x, column 1 holds y.\n",
      "    Returns [new_b, new_m]. With no points there is no gradient, so the\n",
      "    current parameters are returned unchanged.\n",
      "    \"\"\"\n",
      "    n = len(points)\n",
      "    if n == 0:\n",
      "        return [b_current, m_current]\n",
      "    pts = np.asarray(points, dtype=float)\n",
      "    x = pts[:, 0]\n",
      "    y = pts[:, 1]\n",
      "    residuals = y - (m_current * x + b_current)\n",
      "    # Partial derivatives of mean((y - (m*x + b))**2) with respect to b and m.\n",
      "    b_gradient = -2.0 * np.sum(residuals) / n\n",
      "    m_gradient = -2.0 * np.sum(x * residuals) / n\n",
      "    new_b = b_current - learningRate * b_gradient\n",
      "    new_m = m_current - learningRate * m_gradient\n",
      "    return [new_b, new_m]"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "### Synthetic data\n",
      "\n",
      "The points follow $y = 1.5x$ plus Gaussian noise."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "x = np.random.randn(NUM_POINTS) * 10 + 50\n",
      "y = x * 1.5 + np.random.randn(NUM_POINTS) * 10\n",
      "points = np.array([x, y]).transpose()\n",
      "assert points.shape == (NUM_POINTS, 2)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "### Fit\n",
      "\n",
      "Each reported line shows `b`, `m` and the current mean squared error."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "b = 0\n",
      "m = 0\n",
      "for i in range(NUM_ITERATIONS):\n",
      "    b, m = stepGradient(b, m, points, LEARNING_RATE)\n",
      "    if i % REPORT_EVERY == 0:\n",
      "        print(\"%s %s %s\" % (b, m, computeErrorForLineGivenPoints(b, m, points)))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "### Fitted line against the data"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "fig, ax = plt.subplots()\n",
      "ax.scatter(x, y, label='data')\n",
      "line_x = np.arange(0, 100)\n",
      "ax.plot(line_x, m * line_x + b, color='red', label='fitted line')\n",
      "ax.set_xlabel('x')\n",
      "ax.set_ylabel('y')\n",
      "ax.set_title('Linear fit by gradient descent')\n",
      "ax.legend()\n",
      "plt.show()"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    }
   ],
   "metadata": {}
  }
 ]
}