{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Gradient descent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simple linear regression\n", "\n", "In a previous notebook, we solved the problem of simple linear regression: finding a straight line that best fits a data set with two variables.\n", "In that case, we were able to find the exact solution.\n", "In this notebook, we'll use a common technique to approximate that solution.\n", "\n", "Why would we want to approximate a solution when we can easily find an exact solution?\n", "Here we don't need to. The point is that the same technique also works in situations where we can't find an exact solution, or where computing one is impractical.\n", "The approximation technique is called gradient descent." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preliminaries" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# numpy efficiently deals with numerical multi-dimensional arrays.\n", "import numpy as np\n", "\n", "# matplotlib is a plotting library, and pyplot is its easy-to-use module.\n", "import matplotlib.pyplot as pl\n", "\n", "# This just sets the default plot size to be bigger.\n", "pl.rcParams['figure.figsize'] = (16.0, 8.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Simple linear regression model\n", "In simple linear regression, we have some data points $(x_i, y_i)$, and we assume that they lie close to a straight line, up to some random error.\n", "Straight lines in two dimensions are of the form $y = mx + c$, and to fit a line to our data points we must find appropriate values for $m$ and $c$.\n", "NumPy has a function called `polyfit` that finds the least-squares values of $m$ and $c$ for us." 
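] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For simple linear regression, the least-squares values of $m$ and $c$ also have a well-known closed form:\n", "\n", "$$\n", "m = \\frac{\\sum_i (x_i - \\bar{x})(y_i - \\bar{y})}{\\sum_i (x_i - \\bar{x})^2}, \\qquad c = \\bar{y} - m \\bar{x}\n", "$$\n", "\n", "where $\\bar{x}$ and $\\bar{y}$ are the means of the $x_i$ and the $y_i$.\n", "As a minimal sketch, assuming noise-free synthetic data on the line $y = 5x + 10$, the next cell computes these values directly and compares them with `np.polyfit`.\n", "The helper `closed_form_fit` and the `x_demo`/`y_demo` names are our own, chosen only for this illustration." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def closed_form_fit(x, y):\n", "    # Least-squares slope m and intercept c for the model y = m x + c.\n", "    x_bar, y_bar = np.mean(x), np.mean(y)\n", "    m = np.sum((x - x_bar) * (y - y_bar)) / np.sum((x - x_bar) ** 2)\n", "    c = y_bar - m * x_bar\n", "    return m, c\n", "\n", "# Hypothetical noise-free data on y = 5x + 10: both fits should give approximately m = 5, c = 10.\n", "x_demo = np.arange(1.0, 16.0, 1.0)\n", "y_demo = 5.0 * x_demo + 10.0\n", "\n", "m_exact, c_exact = closed_form_fit(x_demo, y_demo)\n", "m_poly, c_poly = np.polyfit(x_demo, y_demo, 1)\n", "print(\"closed form:\", m_exact, c_exact)\n", "print(\"polyfit:\", m_poly, c_poly)"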
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best fit is m = 5.404076 and c = 7.755549\n" ] } ], "source": [ "# Synthetic data: points on the line y = 5x + 10 plus normally distributed noise.\n", "w = np.arange(1.0, 16.0, 1.0)\n", "d = 5.0 * w + 10.0 + np.random.normal(0.0, 5.0, w.size)\n", "\n", "# Fit a degree-1 polynomial (a straight line) by least squares.\n", "m, c = np.polyfit(w, d, 1)\n", "print(\"Best fit is m = %f and c = %f\" % (m, c))\n", "\n", "# Plot the data and the best fit line.\n", "pl.plot(w, d, 'k.', label='Original data')\n", "pl.plot(w, m * w + c, 'b-', label='Best fit: $%0.1f x + %0.1f$' % (m, c))\n", "pl.legend()\n", "pl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gradient descent\n", "In gradient descent, we start with an initial guess for each parameter (it can be random) and iteratively improve those guesses.\n", "For instance, we might pick $1.0$ as our initial guess for $m$ and then use a loop to iteratively improve the value of $m$.\n", "To improve $m$, we first take the partial derivative of our cost function with respect to $m$, which tells us how the cost changes as $m$ changes, and then move $m$ a small step in the opposite direction." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Cost function\n", "Recall that our cost function for simple linear regression is:\n", "\n", "$$\n", "Cost(m, c) = \\sum_i (y_i - mx_i - c)^2\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Calculate the partial derivatives\n", "We calculate the partial derivative of $Cost$ with respect to $m$ using the chain rule, while treating $c$ as a constant.\n", "Note that the $x_i$ and $y_i$ values are all just constants.\n", "We'll also calculate the partial derivative with respect to $c$ here.\n", "\n", "$$\n", "\\begin{align}\n", "Cost(m, c) &= \\sum_i (y_i - m x_i - c)^2 \\\\[1cm]\n", "\\frac{\\partial Cost}{\\partial m} &= \\sum_i 2(y_i - m x_i - c)(-x_i) \\\\\n", " &= -2 \\sum_i x_i (y_i - m x_i - c) \\\\[0.5cm]\n", "\\frac{\\partial Cost}{\\partial c} &= \\sum_i 2(y_i - m x_i - c)(-1) \\\\\n", " &= -2 \\sum_i (y_i - m x_i - c)\n", "\\end{align}\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code the partial derivatives\n", "Now we code the partial derivatives up in Python.\n", "Here we create two functions, each taking four parameters.\n", "The first two parameters are NumPy arrays holding our $x_i$ and $y_i$ values.\n", "The last two are our current guesses for $m$ and $c$." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def grad_m(x, y, m, c):\n", "    # Partial derivative of the cost with respect to m.\n", "    return -2.0 * np.sum(x * (y - m * x - c))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def grad_c(x, y, m, c):\n", "    # Partial derivative of the cost with respect to c.\n", "    return -2.0 * np.sum(y - m * x - c)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Iterate\n", "Now we can run our gradient descent algorithm.\n", "On each step we update both parameters at once, using the values from the previous step:\n", "we replace $m$ with $m - \\eta \\, grad\\_m(x, y, m, c)$ and $c$ with $c - \\eta \\, grad\\_c(x, y, m, c)$.\n", "We stop once neither value changes by more than a small tolerance.\n", "\n", "What is $\\eta$? It is called the learning rate: a small positive number that controls the size of each step.\n", "If it is too large, the updates overshoot the minimum and can diverge; if it is too small, convergence is slow.\n", "\n", "The loop below prints $m$ and $c$ every 1000 iterations, and you can see them getting closer to the best-fit values found by `polyfit` above." 
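] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To get a feel for why $\\eta$ has to be small, here is a minimal sketch that runs a fixed number of these simultaneous updates for two learning rates.\n", "It assumes noise-free synthetic data on the line $y = 5x + 10$ with $x = 1, 2, \\dots, 15$, the same $x$ values as above. The helper names `sum_squared_error` and `run_gradient_descent` are our own.\n", "With this data, $\\eta = 0.0001$ should end up very close to $m = 5$, $c = 10$. By contrast, $\\eta = 0.001$ is too large: each step overshoots, and the cost grows instead of shrinking." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def sum_squared_error(x, y, m, c):\n", "    # The cost function from above: the sum of squared residuals.\n", "    return np.sum((y - m * x - c) ** 2)\n", "\n", "def run_gradient_descent(x, y, eta, steps, m=1.0, c=1.0):\n", "    # Run a fixed number of gradient descent steps, updating m and c together.\n", "    for _ in range(steps):\n", "        residual = y - m * x - c\n", "        dm = -2.0 * np.sum(x * residual)  # same formula as grad_m\n", "        dc = -2.0 * np.sum(residual)  # same formula as grad_c\n", "        m, c = m - eta * dm, c - eta * dc\n", "    return m, c\n", "\n", "x_demo = np.arange(1.0, 16.0, 1.0)\n", "y_demo = 5.0 * x_demo + 10.0\n", "\n", "m_small, c_small = run_gradient_descent(x_demo, y_demo, eta=0.0001, steps=20000)\n", "m_large, c_large = run_gradient_descent(x_demo, y_demo, eta=0.001, steps=100)\n", "\n", "print(\"eta = 0.0001:\", m_small, c_small, sum_squared_error(x_demo, y_demo, m_small, c_small))\n", "print(\"eta = 0.001:\", m_large, c_large, sum_squared_error(x_demo, y_demo, m_large, c_large))"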
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "m: 5.7149506777702488 c: 4.5518695120877757\n", "m: 5.5629397868829846 c: 6.1183980049248472\n", "m: 5.4852588608786546 c: 6.9189286986499470\n", "m: 5.4455621902673004 c: 7.3280175818245992\n", "m: 5.4252763139437015 c: 7.5370710455442040\n", "m: 5.4149097826482908 c: 7.6439019874021357\n", "m: 5.4096122559659570 c: 7.6984949605832087\n", "m: 5.4069051027099357 c: 7.7263931768575951\n", "m: 5.4055216875445105 c: 7.7406497821666767\n", "m: 5.4048147318094859 c: 7.7479352226744274\n", "m: 5.4044534617990276 c: 7.7516582438453998\n", "m: 5.4042688448364773 c: 7.7535607899748742\n" ] } ], "source": [ "eta = 0.0001  # learning rate\n", "m, c = 1.0, 1.0  # initial guesses\n", "delta = 0.0000001  # stopping tolerance\n", "\n", "# Start mold and cold away from m and c so that the loop runs at least once.\n", "mold, cold = m - 1.0, c - 1.0\n", "i = 0\n", "# Keep going until neither m nor c changes by more than delta.\n", "while abs(mold - m) > delta or abs(cold - c) > delta:\n", "    mold, cold = m, c\n", "    # Update both parameters using the previous values of m and c.\n", "    m = mold - eta * grad_m(w, d, mold, cold)\n", "    c = cold - eta * grad_c(w, d, mold, cold)\n", "\n", "    i += 1\n", "    if i % 1000 == 0:\n", "        print(\"m: %20.16f c: %20.16f\" % (m, c))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Newton's method for square roots\n", "Newton's method for square roots is a method for approximating the square root of a positive number $x$.\n", "It is Newton's general root-finding update $z_{i+1} = z_i - f(z_i)/f'(z_i)$ applied to $f(z) = z^2 - x$, whose positive root is $\\sqrt{x}$.\n", "We begin with an initial guess $z_0$ of the square root. It doesn't have to be particularly good, but it must not be zero, since we divide by $2z_i$.\n", "We then apply the following calculation repeatedly, to calculate $z_1$, then $z_2$, and so on:\n", "\n", "$$\n", "z_{i+1} = z_i - \\frac{z_i^2 - x}{2z_i}\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Coding the calculation\n", "We can create a function that calculates the next value of $z$ from the current value of $z$ as follows." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def next_z(x, z):\n", "    # One Newton step towards the square root of x, starting from the guess z.\n", "    return z - (z**2 - x) / (2 * z)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Calculating the square root of $x$\n", "Suppose we want to calculate the square root of $x$.\n", "We start with an initial guess $z_0$ for the square root; the code below uses $z_0 = 2.0$.\n", "We then apply the `next_z` function repeatedly until the value of $z$ stops changing.\n", "Let's create a function to do this.\n", "We'll define `next_z` inside the `newtsqrt` function (as a lambda) so that `newtsqrt` is self-contained." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3.75\n", "3.341666666666667\n", "3.316718620116376\n", "3.3166247916826186\n", "3.3166247903554\n" ] }, { "data": { "text/plain": [ "3.3166247903554" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def newtsqrt(x):\n", "    next_z = lambda x, z: z - (z**2 - x) / (2 * z)\n", "    z = 2.0  # initial guess\n", "    n = next_z(x, z)\n", "\n", "    # Iterate until the next value is exactly equal to the current one.\n", "    while z != n:\n", "        z, n = n, next_z(x, n)\n", "        print(z)\n", "\n", "    return z\n", "\n", "newtsqrt(11)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Comparison with the standard library\n", "We can compare the value returned by `newtsqrt` with the value calculated by the `sqrt` function in Python's standard `math` module." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3.3166247903554" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import math\n", "math.sqrt(11)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Being careful\n", "Because floating point numbers have limited precision, the `newtsqrt` function could get into an infinite loop. Its stopping condition requires two successive values to be exactly equal, which may never happen if the iterates keep alternating between nearby floating point values.\n", "For instance, calculating the square root of 10 gives an infinite loop." 
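] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One way to make sure the loop always ends is to combine a tolerance test with a cap on the number of iterations.\n", "The next cell is a minimal sketch of that idea and is our own variant. The name `newtsqrt_safe`, its `rel_tol` and `max_iter` parameters, and their default values are illustrative assumptions.\n", "It stops when two successive estimates agree to within a relative tolerance, and gives up after `max_iter` steps in any case." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "def newtsqrt_safe(x, z=2.0, rel_tol=1e-12, max_iter=100):\n", "    # Newton's method for the square root of x, assuming x > 0 and a nonzero guess z.\n", "    for _ in range(max_iter):\n", "        n = z - (z**2 - x) / (2 * z)\n", "        # Stop when successive estimates are relatively close,\n", "        # instead of requiring them to be exactly equal.\n", "        if abs(n - z) <= rel_tol * abs(n):\n", "            return n\n", "        z = n\n", "    # The iteration cap guarantees termination even if the tolerance is never met.\n", "    return z\n", "\n", "print(newtsqrt_safe(10), math.sqrt(10))"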
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Uncommenting the next line results in an infinite loop.\n", "# newtsqrt(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To counteract this problem, write the loop condition so that the loop stops once successive values are close, rather than exactly equal:\n", "```python\n", "abs(z - n) > 0.001\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Gradient descent for square roots\n", "Newton's method for square roots is efficient, but we can also use gradient descent to approximate the square root of a positive real number $x$.\n", "Here, we use the following cost function:\n", "\n", "$$\n", "Cost(z \\mid x) = (x - z^2)^2\n", "$$\n", "\n", "The cost is never negative, and it is zero exactly when $z^2 = x$, that is, when $z = \\pm\\sqrt{x}$. So a $z$ that minimises the cost is a square root of $x$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example value\n", "Let's use it to calculate the square root of 20, i.e. $x = 20$.\n", "Then the cost function is:\n", "\n", "$$\n", "Cost(z \\mid x=20) = (20 - z^2)^2\n", "$$\n", "\n", "Our goal is to find the $z$ that minimises this." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plotting the cost function\n", "Let's plot the cost function.\n", "Since $4^2 = 16 < 20 < 25 = 5^2$, we know the positive square root lies between $4$ and $5$, so we'll let $z$ range over $0$ to $10$." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Evaluate the cost at 1000 evenly spaced values of z between 0 and 10.\n", "z = np.linspace(0.0, 10.0, 1000)\n", "cost = (20.0 - z**2)**2\n", "\n", "pl.plot(z, cost, 'k-', label='$(20-z^2)^2$')\n", "pl.legend()\n", "pl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The derivative\n", "The curve appears to have its lowest point at about $z = 4.5$.\n", "Let's take the derivative of the cost function with respect to $z$, using the chain rule.\n", "\n", "$$\n", "\\begin{align}\n", "Cost(z) &= (20 - z^2)^2 \\\\\n", "\\Rightarrow \\frac{\\partial Cost}{\\partial z} &= 2(20 - z^2)(-2z) \\\\\n", " &= 4z(z^2 - 20) \\\\\n", " &= 4z^3 - 80z\n", "\\end{align}\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The derivative tells us the slope of the tangent to the curve at any point on the cost function.\n", "What does that mean?\n", "If we pick a value of $z$, e.g. $8.0$, the derivative tells us that the line through the point $(8.0, (20.0 - (8.0)^2)^2)$ with slope $4(8.0)^3 - 80(8.0)$ is tangent to the graph above: it touches the curve at that point.\n", "\n", "When you simplify, the point $(8.0, (20.0 - (8.0)^2)^2)$ becomes $(8, 1936)$.\n", "The slope is $4(8.0)^3 - 80(8.0)$, which when simplified becomes $1408$.\n", "So, the claim is that the line with slope $1408$ going through the point $(8, 1936)$ touches the graph.\n", "To calculate the equation of the line, we'll use the point-slope form $y - y_1 = s(z - z_1)$, where $s$ is the slope and $(z_1, y_1)$ is the point:\n", "\n", "$$\n", "\\begin{align}\n", "y - 1936 &= 1408(z - 8) \\\\\n", "\\Rightarrow y &= 1408z - 11264 + 1936 \\\\\n", "\\Rightarrow y &= 1408z - 9328\n", "\\end{align}\n", "$$\n", "\n", "Let's plot that line and the cost function together." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "z = np.linspace(0.0, 10.0, 1000)\n", "cost = (20.0 - z**2)**2\n", "# The tangent line to the cost curve at z = 8.\n", "tangent = 1408 * z - 9328\n", "\n", "pl.plot(z, cost, 'k-', label='$(20-z^2)^2$')\n", "pl.plot(z, tangent, 'b-', label='$1408z - 9328$')\n", "pl.legend()\n", "pl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Why do we care about the slope?\n", "It's a bit hard to see, but the blue line touches the curve at $z = 8$.\n", "We care about this because the sign of the slope tells us in which direction to change $z$ in order to reduce the cost.\n", "Here the slope is positive, so near $z = 8$, if we increase $z$ the cost goes up.\n", "If we decrease it, the cost goes down." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gradient descent\n", "Let's use the gradient descent algorithm to calculate the best $z$.\n", "We'll start with the guess $z=8$, and then use the derivative to move $z$ ever so slightly in the direction that decreases the cost.\n", "By ever so slightly, we mean $0.001$ times the slope, i.e. a learning rate of $\\eta = 0.001$:\n", "\n", "$$\n", "\\begin{align}\n", "z_{i+1} &= z_i - \\eta \\frac{\\partial Cost}{\\partial z} \\\\\n", " &= z_i - (0.001) (4 z_i^3 - 80 z_i)\n", "\\end{align}\n", "$$\n", "\n", "So, for our initial guess $z_0 = 8.0$ we get:\n", "\n", "$$\n", "\\begin{align}\n", "z_1 &= 8.0 - (0.001) (4 (8.0)^3 - 80 (8.0)) \\\\\n", " &= 8.0 - 1.408 = 6.592\n", "\\end{align}\n", "$$\n", "\n", "Let's code it up." 
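] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a first check, here is a minimal sketch of a single update step, assuming $x = 20$ and $\\eta = 0.001$ as above. The name `gd_step` is our own.\n", "It writes the derivative as $4z(z^2 - x)$, which equals $4z^3 - 80z$ when $x = 20$. It should reproduce the hand calculation $z_1 = 6.592$ up to floating point rounding." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def gd_step(z, x=20.0, eta=0.001):\n", "    # One gradient descent step on Cost(z | x) = (x - z^2)^2,\n", "    # whose derivative with respect to z is 4z(z^2 - x).\n", "    return z - eta * 4.0 * z * (z**2 - x)\n", "\n", "z1 = gd_step(8.0)\n", "print(z1)"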
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Current: 8.00000000\tNext: 6.59200000\n", "Current: 6.59200000\tNext: 5.97355269\n", "Current: 5.97355269\tNext: 5.59881186\n", "Current: 5.59881186\tNext: 5.34469983\n", "Current: 5.34469983\tNext: 5.16157297\n", "Current: 5.16157297\tNext: 5.02444369\n", "Current: 5.02444369\tNext: 4.91903017\n", "Current: 4.91903017\tNext: 4.83645229\n", "Current: 4.83645229\tNext: 4.77084541\n", "Current: 4.77084541\tNext: 4.71815685\n", "Current: 4.71815685\tNext: 4.67548576\n", "Current: 4.67548576\tNext: 4.64069702\n", "Current: 4.64069702\tNext: 4.61218330\n", "Current: 4.61218330\tNext: 4.58871218\n", "Current: 4.58871218\tNext: 4.56932433\n", "Current: 4.56932433\tNext: 4.55326362\n", "Current: 4.55326362\tNext: 4.53992784\n", "Current: 4.53992784\tNext: 4.52883326\n", "Current: 4.52883326\tNext: 4.51958845\n", "Current: 4.51958845\tNext: 4.51187478\n", "Current: 4.51187478\tNext: 4.50543157\n", "Current: 4.50543157\tNext: 4.50004463\n", "Current: 4.50004463\tNext: 4.49553736\n", "Current: 4.49553736\tNext: 4.49176369\n", "Current: 4.49176369\tNext: 4.48860255\n", "Current: 4.48860255\tNext: 4.48595333\n", "Current: 4.48595333\tNext: 4.48373229\n", "Current: 4.48373229\tNext: 4.48186965\n", "Current: 4.48186965\tNext: 4.48030717\n", "Current: 4.48030717\tNext: 4.47899619\n", "Current: 4.47899619\tNext: 4.47789603\n", "Square root: 4.477896027970839 \tSquared: 20.05155283731702\n" ] } ], "source": [ "def next_z(z, x, eta=0.001):\n", "    # One gradient descent step on Cost(z | x) = (x - z^2)^2, whose derivative\n", "    # is 4z^3 - 4xz (this equals 4z^3 - 80z when x = 20).\n", "    return z - eta * (4.0 * z**3 - 4.0 * x * z)\n", "\n", "def sqrt_grad_desc(x, z, verbose=False):\n", "    # Take steps until a step would move z by no more than 0.001.\n", "    n = next_z(z, x)\n", "    while abs(z - n) > 0.001:\n", "        if verbose:\n", "            print(\"Current: %14.8f\\tNext: %14.8f\" % (z, n))\n", "        z, n = n, next_z(n, x)\n", "    return z\n", "\n", "ans = sqrt_grad_desc(20.0, 8.0, verbose=True)\n", "print(\"Square root:\", ans, \"\\tSquared:\", ans**2)" ] }, { "cell_type": "markdown", "metadata": 
{}, "source": [ "### A question\n", "Let's try some other initial guesses: 4.0, 1.0 and -1.0.\n", "Can you explain the value returned when the initial guess is -1.0?" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "With initial guess 4.00: 4.465951\n", "With initial guess 1.00: 4.465963\n", "With initial guess -1.00: -4.465963\n" ] } ], "source": [ "print(\"With initial guess %6.2f: %10.6f\" % (4.0, sqrt_grad_desc(20.0, 4.0, False)))\n", "print(\"With initial guess %6.2f: %10.6f\" % (1.0, sqrt_grad_desc(20.0, 1.0, False)))\n", "print(\"With initial guess %6.2f: %10.6f\" % (-1.0, sqrt_grad_desc(20.0, -1.0, False)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### End" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 2 }