{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Simple Linear Regression with NumPy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In school, students are taught to draw lines like the following.\n", "\n", "$$ y = 2 x + 1$$\n", "\n", "They're taught to pick two values for $x$ and calculate the corresponding values for $y$ using the equation.\n", "Then they draw a set of axes, plot the points, and then draw a line extending through the two dots on their axes." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Import matplotlib.\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Draw some axes.\n", "plt.plot([-1, 10], [0, 0], 'k-')\n", "plt.plot([0, 0], [-1, 10], 'k-')\n", "\n", "# Plot the red, blue and green lines.\n", "plt.plot([1, 1], [-1, 3], 'b:')\n", "plt.plot([-1, 1], [3, 3], 'r:')\n", "\n", "# Plot the two points (1,3) and (2,5).\n", "plt.plot([1, 2], [3, 5], 'ko')\n", "# Join them with an (extending) green lines.\n", "plt.plot([-1, 10], [-1, 21], 'g-')\n", "\n", "# Set some reasonable plot limits.\n", "plt.xlim([-1, 10])\n", "plt.ylim([-1, 10])\n", "\n", "# Show the plot.\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Simple linear regression is about the opposite problem - what if you have some points and are looking for the equation?\n", "It's easy when the points are perfectly on a line already, but usually real-world data has some noise.\n", "The data might still look roughly linear, but aren't exactly so.\n", "\n", "***" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example (contrived and simulated)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![weights.png](https://github.com/ianmcloughlin/images/raw/master/weights.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Scenario\n", "Suppose you are trying to weigh your suitcase to avoid an airline's extra charges.\n", "You don't have a weighing scales, but you do have a spring and some gym-style weights of masses 7KG, 14KG and 21KG.\n", "You attach the spring to the wall hook, and mark where the bottom of it hangs.\n", "You then hang the 7KG weight on the end and mark where the bottom of the spring is.\n", "You repeat this with the 14KG weight and the 21KG weight.\n", "Finally, you place your case hanging on the spring, and the spring hangs down halfway between the 7KG mark and the 14KG mark.\n", "Is your case over the 10KG limit set by the airline?" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Hypothesis\n", "When you look at the marks on the wall, it seems that the 0KG, 7KG, 14KG and 21KG marks are evenly spaced.\n", "You wonder if that means your case weighs 10.5KG.\n", "That is, you wonder if there is a *linear* relationship between the distance the spring's hook is from its resting position, and the mass on the end of it." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Experiment\n", "You decide to experiment.\n", "You buy some new weights - a 1KG, a 2KG, a 3KG, all the way up to 20KG.\n", "You place them each in turn on the spring and measure the distance the spring moves from the resting position.\n", "You tabulate the data and plot them." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Analysis\n", "Here we'll import the Python libraries we need for our investigations below." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# numpy efficiently deals with numerical multi-dimensional arrays.\n", "import numpy as np\n", "\n", "# matplotlib is a plotting library, and pyplot is its easy-to-use module.\n", "import matplotlib.pyplot as plt\n", "\n", "# This just sets the default plot size to be bigger.\n", "plt.rcParams['figure.figsize'] = (8, 6)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Ignore the next couple of lines where I fake up some data. I'll use the fact that I faked the data to explain some results later. Just pretend that `w` is an array containing the weight values and `d` is an array containing the corresponding distance measurements." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "w = np.arange(0.0, 21.0, 1.0)\n", "d = 5.0 * w + 10.0 + np.random.normal(0.0, 5.0, w.size)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.,\n", " 13., 14., 15., 16., 17., 18., 19., 20.])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's have a look at w.\n", "w" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 15.99410275, 10.9911427 , 10.53445534, 15.97058549,\n", " 27.44753565, 39.51180328, 43.06251898, 47.47682646,\n", " 52.27808676, 56.67322568, 59.16156245, 56.69603504,\n", " 67.15387974, 80.96017265, 84.50699528, 87.26011678,\n", " 90.77740446, 100.07004239, 107.93739931, 105.51085343,\n", " 118.07859596])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's have a look at d.\n", "d" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "Let's have a look at the data from our experiment." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Create the plot.\n", "\n", "plt.plot(w, d, 'k.')\n", "\n", "# Set some properties for the plot.\n", "plt.xlabel('Weight (KG)')\n", "plt.ylabel('Distance (CM)')\n", "\n", "# Show the plot.\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Model\n", "It looks like the data might indeed be linear.\n", "The points don't exactly fit on a straight line, but they are not far off it.\n", "We might put that down to some other factors, such as the air density, or errors, such as in our tape measure.\n", "Then we can go ahead and see what would be the best line to fit the data. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Straight lines\n", "All straight lines can be expressed in the form $y = mx + c$.\n", "The number $m$ is the slope of the line.\n", "The slope is how much $y$ increases by when $x$ is increased by 1.0.\n", "The number $c$ is the y-intercept of the line.\n", "It's the value of $y$ when $x$ is 0." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Fitting the model\n", "To fit a straight line to the data, we just must pick values for $m$ and $c$.\n", "These are called the parameters of the model, and we want to pick the best values possible for the parameters.\n", "That is, the best parameter values *given* the data observed.\n", "Below we show various lines plotted over the data, with different values for $m$ and $c$." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot w versus d with black dots.\n", "plt.plot(w, d, 'k.', label=\"Data\")\n", "\n", "# Overlay some lines on the plot.\n", "x = np.arange(0.0, 21.0, 1.0)\n", "plt.plot(x, 5.0 * x + 10.0, 'r-', label=r\"$5x + 10$\")\n", "plt.plot(x, 6.0 * x + 5.0, 'g-', label=r\"$6x + 5$\")\n", "plt.plot(x, 5.0 * x + 15.0, 'b-', label=r\"$5x + 15$\")\n", "\n", "# Add a legend.\n", "plt.legend()\n", "\n", "# Add axis labels.\n", "plt.xlabel('Weight (KG)')\n", "plt.ylabel('Distance (CM)')\n", "\n", "# Show the plot.\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Calculating the cost\n", "You can see that each of these lines roughly fits the data.\n", "Which one is best, and is there another line that is better than all three?\n", "Is there a \"best\" line?\n", "\n", "It depends how you define the word best.\n", "Luckily, everyone seems to have settled on what the best means.\n", "The best line is the one that minimises the following calculated value.\n", "\n", "$$ \\sum_i (y_i - mx_i - c)^2 $$\n", "\n", "Here $(x_i, y_i)$ is the $i^{th}$ point in the data set and $\\sum_i$ means to sum over all points. 
\n", "The values of $m$ and $c$ are to be determined.\n", "We usually denote the above as $Cost(m, c)$.\n", "\n", "Where does the above calculation come from?\n", "It's easy to explain the part in the brackets $(y_i - mx_i - c)$.\n", "The value in the data set that corresponds to $x_i$ is $y_i$.\n", "These are the measured values.\n", "The value $m x_i + c$ is what the model says $y_i$ should have been.\n", "The difference between the value that was observed ($y_i$) and the value that the model gives ($m x_i + c$) is $y_i - mx_i - c$.\n", "\n", "Why square that value?\n", "Well, note that the value could be positive or negative, and you sum over all of these values.\n", "If we allow the values to be positive or negative, then the positives could cancel the negatives.\n", "So, the natural thing to do is to take the absolute value $\\mid y_i - m x_i - c \\mid$.\n", "Well, it turns out that absolute values are a pain to deal with (the absolute value function cannot be differentiated at zero), so the quantity is squared instead, as the square of a real number is never negative.\n", "There are pros and cons to using the square instead of the absolute value, but the square is the usual choice.\n", "This is usually called *least squares* fitting." 
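, "\n", "\n", "As a minimal sketch of why the differences are squared, the following uses three made-up points and a made-up candidate line. None of these numbers come from the spring experiment, and the names `x`, `y` and `residuals` are purely illustrative.\n", "\n", "```python\n", "import numpy as np\n", "\n", "# Hypothetical points and candidate parameters, for illustration only.\n", "x = np.array([0.0, 1.0, 2.0])\n", "y = np.array([1.0, 3.5, 4.5])\n", "m, c = 2.0, 1.0\n", "\n", "# Observed value minus the value the model gives.\n", "residuals = y - (m * x + c)\n", "\n", "# Summing the raw differences lets the positive and negative ones cancel.\n", "print(np.sum(residuals))     # prints 0.0\n", "\n", "# Squaring first means every miss adds to the cost.\n", "print(np.sum(residuals**2))  # prints 0.5\n", "```\n", "\n", "The line misses two of the three points, yet the plain sum of the differences reports no error at all; the sum of squares does not have that blind spot."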
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cost with m = 5.00 and c = 10.00: 567.04\n", "Cost with m = 6.00 and c = 5.00: 1073.18\n", "Cost with m = 5.00 and c = 15.00: 911.51\n" ] } ], "source": [ "# Calculate the cost of the lines above for the data above.\n", "cost = lambda m,c: np.sum([(d[i] - m * w[i] - c)**2 for i in range(w.size)])\n", "\n", "print(\"Cost with m = %5.2f and c = %5.2f: %8.2f\" % (5.0, 10.0, cost(5.0, 10.0)))\n", "print(\"Cost with m = %5.2f and c = %5.2f: %8.2f\" % (6.0, 5.0, cost(6.0, 5.0)))\n", "print(\"Cost with m = %5.2f and c = %5.2f: %8.2f\" % (5.0, 15.0, cost(5.0, 15.0)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Minimising the cost\n", "We want to calculate values of $m$ and $c$ that give the lowest value for the cost value above.\n", "For our given data set we can plot the cost value/function.\n", "Recall that the cost is:\n", "\n", "$$ Cost(m, c) = \\sum_i (y_i - mx_i - c)^2 $$\n", "\n", "This is a function of two variables, $m$ and $c$, so a plot of it is three dimensional.\n", "See the **Advanced** section below for the plot.\n", "\n", "In the case of fitting a two-dimensional line to a few data points, we can easily calculate exactly the best values of $m$ and $c$.\n", "Some of the details are discussed in the **Advanced** section, as they involve calculus, but the resulting code is straight-forward.\n", "We first calculate the mean (average) values of our $x$ values and that of our $y$ values.\n", "Then we subtract the mean of $x$ from each of the $x$ values, and the mean of $y$ from each of the $y$ values.\n", "Then we take the *dot product* of the new $x$ values and the new $y$ values and divide it by the dot product of the new $x$ values with themselves.\n", "That gives us $m$, and we use $m$ to calculate $c$.\n", "\n", "Remember that in our dataset $x$ is called $w$ (for weight) and $y$ is called $d$ (for 
distance).\n", "We calculate $m$ and $c$ below." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "m is 5.395020 and c is 6.909486.\n" ] } ], "source": [ "# Calculate the best values for m and c.\n", "\n", "# First calculate the means (a.k.a. averages) of w and d.\n", "w_avg = np.mean(w)\n", "d_avg = np.mean(d)\n", "\n", "# Subtract means from w and d.\n", "w_zero = w - w_avg\n", "d_zero = d - d_avg\n", "\n", "# The best m is found by the following calculation.\n", "m = np.sum(w_zero * d_zero) / np.sum(w_zero * w_zero)\n", "# Use m from above to calculate the best c.\n", "c = d_avg - m * w_avg\n", "\n", "print(\"m is %8.6f and c is %6.6f.\" % (m, c))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that numpy has a function that will perform this calculation for us, called polyfit.\n", "Its third argument is the degree of the polynomial to fit, so a degree of 1 gives a straight line, and the result lists the slope $m$ first and then the intercept $c$." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([5.39501974, 6.90948551])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.polyfit(w, d, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Best fit line\n", "So, the best values for $m$ and $c$ given our data and using least squares fitting are about $5.40$ for $m$ and about $6.91$ for $c$.\n", "We plot this line on top of the data below." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the best fit line.\n", "plt.plot(w, d, 'k.', label='Original data')\n", "plt.plot(w, m * w + c, 'b-', label='Best fit line')\n", "\n", "# Add axis labels and a legend.\n", "plt.xlabel('Weight (KG)')\n", "plt.ylabel('Distance (CM)')\n", "plt.legend()\n", "\n", "# Show the plot.\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the $Cost$ of the best $m$ and best $c$ is not zero in this case, although it is lower than the cost of each of the three lines tried earlier." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cost with m = 5.40 and c = 6.91: 431.37\n" ] } ], "source": [ "print(\"Cost with m = %5.2f and c = %5.2f: %8.2f\" % (m, c, cost(m, c)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Summary\n", "In this notebook we:\n", "1. Investigated the data.\n", "2. Picked a model.\n", "3. Picked a cost function.\n", "4. Estimated the model parameter values that minimised our cost function." ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "### Advanced\n", "In the following sections we cover some of the more advanced concepts involved in fitting the line." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Simulating data\n", "Earlier in the notebook we glossed over something important: we didn't actually do the weighing and measuring - we faked the data.\n", "A better term for this is *simulation*, which is an important tool in research, especially when testing methods such as simple linear regression.\n", "\n", "We ran the following two commands to do this:\n", "\n", "```python\n", "w = np.arange(0.0, 21.0, 1.0)\n", "d = 5.0 * w + 10.0 + np.random.normal(0.0, 5.0, w.size)\n", "```\n", "\n", "The first command creates a numpy array containing the values from 0.0 up to 21.0 (including 0.0 but not including 21.0) in steps of 1.0." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.,\n", " 13., 14., 15., 16., 17., 18., 19., 20.])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ " np.arange(0.0, 21.0, 1.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The second command is more complex.\n", "First it takes the values in the `w` array, multiplies each by 5.0 and then adds 10.0." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 10., 15., 20., 25., 30., 35., 40., 45., 50., 55., 60.,\n", " 65., 70., 75., 80., 85., 90., 95., 100., 105., 110.])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "5.0 * w + 10.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It then adds an array of the same length containing random values.\n", "The values are taken from what is called the normal distribution with mean 0.0 and standard deviation 5.0." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 2.77924728, -7.36852008, -1.37782612, 9.98075174, 4.21995499,\n", " -2.60231604, 9.32075467, 2.83510953, -9.14550022, 6.52177982,\n", " 11.11216945, 6.69576981, 6.84188875, -3.55449951, 6.14654406,\n", " 0.97382891, -2.89697444, 3.99446255, -1.07846657, -9.81654276,\n", " -3.02725231])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.random.normal(0.0, 5.0, w.size)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The normal distribution follows a bell shaped curve.\n", "The curve is centred on the mean (0.0 in this case) and its general width is determined by the standard deviation (5.0 in this case)." 
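, "\n", "\n", "A rough way to see this numerically is to draw a large number of values and summarise them. The sketch below assumes NumPy 1.17 or later for `np.random.default_rng`; the seed, the sample size and the variable names are arbitrary choices made here for illustration.\n", "\n", "```python\n", "import numpy as np\n", "\n", "# A seeded generator, so the sketch gives the same numbers on each run.\n", "rng = np.random.default_rng(0)\n", "\n", "# Many draws from the normal distribution with mean 0.0 and standard deviation 5.0.\n", "samples = rng.normal(0.0, 5.0, 100000)\n", "\n", "# The sample mean and standard deviation should land close to 0.0 and 5.0.\n", "print(samples.mean(), samples.std())\n", "\n", "# Fraction of draws within one standard deviation of the mean (roughly two thirds).\n", "print(np.mean(np.abs(samples) < 5.0))\n", "```"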
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the normal distribution.\n", "# The normalising constant is 1 / sqrt(2 * pi * s**2).\n", "normpdf = lambda mu, s, x: (1.0 / np.sqrt(2.0 * np.pi * s**2)) * np.exp(-((x - mu)**2)/(2 * s**2))\n", "\n", "x = np.linspace(-20.0, 20.0, 100)\n", "y = normpdf(0.0, 5.0, x)\n", "plt.plot(x, y)\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The idea here is to add a little bit of randomness to the measurements of the distance.\n", "The random values are centred around 0.0, with a greater than 99% chance they're within the range -15.0 to 15.0 (three standard deviations either side of the mean).\n", "The normal distribution is used because of the [Central Limit Theorem](https://en.wikipedia.org/wiki/Central_limit_theorem) which basically states that when a bunch of independent random effects add together the outcome looks roughly like the normal distribution. (Don't quote me on that!)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Plotting the cost function\n", "We can plot the cost function for a given set of data points.\n", "Recall that the cost function involves two variables: $m$ and $c$, and that it looks like this:\n", "\n", "$$ Cost(m,c) = \\sum_i (y_i - mx_i - c)^2 $$\n", "\n", "To plot a function of two variables we need a 3D plot.\n", "It can be difficult to get the viewing angle right in 3D plots, but below you can just about make out that there is a low point on the graph near the best-fit values found earlier, $(m, c) \\approx (5.4, 6.9)$. " ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# This code is a little bit involved - don't worry about it.\n", "# Just look at the plot below.\n", "\n", "# Importing Axes3D registers the 3D projection on older versions of matplotlib.\n", "from mpl_toolkits.mplot3d import Axes3D\n", "\n", "# Ask pyplot for a 3D set of axes.\n", "# (gca(projection='3d') no longer works in recent versions of matplotlib.)\n", "ax = plt.figure().add_subplot(projection='3d')\n", "\n", "# Make data.\n", "mvals = np.linspace(4.5, 5.5, 100)\n", "cvals = np.linspace(0.0, 20.0, 100)\n", "\n", "# Fill the grid.\n", "mvals, cvals = np.meshgrid(mvals, cvals)\n", "\n", "# Flatten the meshes for convenience.\n", "mflat = np.ravel(mvals)\n", "cflat = np.ravel(cvals)\n", "\n", "# Calculate the cost of each point on the grid.\n", "C = [np.sum([(d[i] - m * w[i] - c)**2 for i in range(w.size)]) for m, c in zip(mflat, cflat)]\n", "C = np.array(C).reshape(mvals.shape)\n", "\n", "# Plot the surface.\n", "surf = ax.plot_surface(mvals, cvals, C)\n", "\n", "# Set the axis labels.\n", "ax.set_xlabel('$m$', fontsize=16)\n", "ax.set_ylabel('$c$', fontsize=16)\n", "ax.set_zlabel('$Cost$', fontsize=16)\n", "\n", "# Show the plot.\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Coefficient of determination\n", "Earlier we used a cost function to determine the best line to fit the data.\n", "Usually the data do not perfectly fit on the best fit line, and so the cost is greater than 0.\n", "A quantity closely related to the cost is the *coefficient of determination*, also known as the *R-squared* value.\n", "The purpose of the R-squared value is to measure how much of the variance in $y$ is determined by $x$.\n", "\n", "For instance, in our example the main thing that affects the distance the spring is hanging down is the weight on the end.\n", "It's not the only thing that affects it though.\n", "The room temperature and density of the air at the time of measurement probably affect it a little.\n", "The age of the spring, and how many times it has been stretched previously, probably also have a small 
effect.\n", "There are probably lots of unknown factors affecting the measurement.\n", "\n", "The R-squared value estimates how much of the change in the $y$ value is due to the change in the $x$ value, compared to all of the other factors affecting the $y$ value.\n", "It is calculated as follows:\n", "\n", "$$ R^2 = 1 - \\frac{\\sum_i (y_i - m x_i - c)^2}{\\sum_i (y_i - \\bar{y})^2} $$\n", "\n", "Note that sometimes the [*Pearson correlation coefficient*](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient) is used instead of the R-squared value.\n", "In simple linear regression, you can just square the Pearson coefficient to get the R-squared value." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The R-squared value is 0.9811\n" ] } ], "source": [ "# Calculate the R-squared value for our data set.\n", "rsq = 1.0 - (np.sum((d - m * w - c)**2)/np.sum((d - d_avg)**2))\n", "\n", "print(\"The R-squared value is %6.4f\" % rsq)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9811159849088361" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The same value using numpy.\n", "np.corrcoef(w, d)[0][1]**2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### The minimisation calculations\n", "Earlier we used the following calculation to calculate $m$ and $c$ for the line of best fit.\n", "The code was:\n", "\n", "```python\n", "w_zero = w - np.mean(w)\n", "d_zero = d - np.mean(d)\n", "\n", "m = np.sum(w_zero * d_zero) / np.sum(w_zero * w_zero)\n", "c = np.mean(d) - m * np.mean(w)\n", "```\n", "\n", "In mathematical notation we write this as:\n", "\n", "$$ m = \\frac{\\sum_i (x_i - \\bar{x}) (y_i - \\bar{y})}{\\sum_i (x_i - \\bar{x})^2} \\qquad \\textrm{and} \\qquad c = \\bar{y} - m \\bar{x} $$\n", "\n", "where $\\bar{x}$ is the mean of $x$ and $\\bar{y}$ that of $y$.\n", "\n", "Where 
did these equations come from?\n", "They were derived using calculus.\n", "We'll give a brief overview of it here, but feel free to skip this section if it's not for you.\n", "If you can understand the first part, where we calculate the partial derivatives, then great!\n", "\n", "The calculations look complex, but if you know basic differentiation, including the chain rule, you can easily derive them.\n", "First, we differentiate the cost function with respect to $m$ while treating $c$ as a constant; this is called a partial derivative.\n", "We write this as $\\frac{\\partial Cost}{\\partial m}$, using $\\partial$ as opposed to $d$ to signify that we are treating the other variable as a constant.\n", "We then do the same with respect to $c$ while treating $m$ as a constant.\n", "We set both equal to zero, and then solve them as two simultaneous equations in two variables.\n", "\n", "###### Calculate the partial derivatives\n", "$$\n", "\\begin{align}\n", "Cost(m, c) &= \\sum_i (y_i - mx_i - c)^2 \\\\[1cm]\n", "\\frac{\\partial Cost}{\\partial m} &= \\sum 2(y_i - m x_i -c)(-x_i) \\\\\n", " &= -2 \\sum x_i (y_i - m x_i -c) \\\\[0.5cm]\n", "\\frac{\\partial Cost}{\\partial c} & = \\sum 2(y_i - m x_i -c)(-1) \\\\\n", " & = -2 \\sum (y_i - m x_i -c) \\\\\n", "\\end{align}\n", "$$\n", "\n", "###### Set to zero\n", "$$\n", "\\begin{align}\n", "& \\frac{\\partial Cost}{\\partial m} = 0 \\\\[0.2cm]\n", "& \\Rightarrow -2 \\sum x_i (y_i - m x_i -c) = 0 \\\\\n", "& \\Rightarrow \\sum (x_i y_i - m x_i x_i - x_i c) = 0 \\\\\n", "& \\Rightarrow \\sum x_i y_i - \\sum_i m x_i x_i - \\sum x_i c = 0 \\\\\n", "& \\Rightarrow m \\sum x_i x_i = \\sum x_i y_i - c \\sum x_i \\\\[0.2cm]\n", "& \\Rightarrow m = \\frac{\\sum x_i y_i - c \\sum x_i}{\\sum x_i x_i} \\\\[0.5cm]\n", "& \\frac{\\partial Cost}{\\partial c} = 0 \\\\[0.2cm]\n", "& \\Rightarrow -2 \\sum (y_i - m x_i - c) = 0 \\\\\n", "& \\Rightarrow \\sum y_i - \\sum_i m x_i - \\sum c = 0 \\\\\n", "& \\Rightarrow \\sum y_i - m \\sum_i 
x_i = c \\sum 1 \\\\\n", "& \\Rightarrow c = \\frac{\\sum y_i - m \\sum x_i}{\\sum 1} \\\\\n", "& \\Rightarrow c = \\frac{\\sum y_i}{\\sum 1} - m \\frac{\\sum x_i}{\\sum 1} \\\\[0.2cm]\n", "& \\Rightarrow c = \\bar{y} - m \\bar{x} \\\\\n", "\\end{align}\n", "$$\n", "\n", "###### Solve the simultaneous equations\n", "Here we let $n$ be the length of $x$, which is also the length of $y$.\n", "\n", "$$\n", "\\begin{align}\n", "& m = \\frac{\\sum_i x_i y_i - c \\sum_i x_i}{\\sum_i x_i x_i} \\\\[0.2cm]\n", "& \\Rightarrow m = \\frac{\\sum x_i y_i - (\\bar{y} - m \\bar{x}) \\sum x_i}{\\sum x_i x_i} \\\\\n", "& \\Rightarrow m \\sum x_i x_i = \\sum x_i y_i - \\bar{y} \\sum x_i + m \\bar{x} \\sum x_i \\\\\n", "& \\Rightarrow m \\sum x_i x_i - m \\bar{x} \\sum x_i = \\sum x_i y_i - \\bar{y} \\sum x_i \\\\[0.3cm]\n", "& \\Rightarrow m = \\frac{\\sum x_i y_i - \\bar{y} \\sum x_i}{\\sum x_i x_i - \\bar{x} \\sum x_i} \\\\[0.2cm]\n", "& \\Rightarrow m = \\frac{\\sum (x_i y_i) - n \\bar{y} \\bar{x}}{\\sum (x_i x_i) - n \\bar{x} \\bar{x}} \\\\\n", "& \\Rightarrow m = \\frac{\\sum (x_i y_i) - n \\bar{y} \\bar{x} - n \\bar{y} \\bar{x} + n \\bar{y} \\bar{x}}{\\sum (x_i x_i) - n \\bar{x} \\bar{x} - n \\bar{x} \\bar{x} + n \\bar{x} \\bar{x}} \\\\\n", "& \\Rightarrow m = \\frac{\\sum (x_i y_i) - \\sum y_i \\bar{x} - \\sum \\bar{y} x_i + n \\bar{y} \\bar{x}}{\\sum (x_i x_i) - \\sum x_i \\bar{x} - \\sum \\bar{x} x_i + n \\bar{x} \\bar{x}} \\\\\n", "& \\Rightarrow m = \\frac{\\sum_i (x_i - \\bar{x}) (y_i - \\bar{y})}{\\sum_i (x_i - \\bar{x})^2} \\\\\n", "\\end{align}\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "#### Using sklearn neural networks\n", "\n", "***" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 15.99410275, 15.99136402],\n", " [ 10.9911427 , 10.99025661],\n", " [ 10.53445534, 9.83175482],\n", " [ 15.97058549, 18.21173347],\n", " [ 27.44753565, 26.74674379],\n", " [ 39.51180328, 35.26744107],\n", " [ 43.06251898, 40.53315812],\n", " [ 47.47682646, 45.79886552],\n", " [ 52.27808676, 51.06457291],\n", " [ 56.67322568, 56.33028031],\n", " [ 59.16156245, 61.5959877 ],\n", " [ 56.69603504, 66.86169509],\n", " [ 67.15387974, 72.12740249],\n", " [ 80.96017265, 77.39310988],\n", " [ 84.50699528, 82.65881728],\n", " [ 87.26011678, 87.92452467],\n", " [ 90.77740446, 93.19023207],\n", " [100.07004239, 98.45593946],\n", " [107.93739931, 103.72164686],\n", " [105.51085343, 108.98735425],\n", " [118.07859596, 114.25306165]])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import sklearn.neural_network as sknn\n", "\n", "# Expects a 2D array of inputs.\n", "w2d = w.reshape(-1, 1)\n", "\n", "# Train the neural network.\n", "regr = sknn.MLPRegressor(max_iter=10000).fit(w2d, d)\n", "\n", "# Show the predictions.\n", "np.array([d, regr.predict(w2d)]).T" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9895666755071412" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The score.\n", "regr.score(w2d, d)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "#### End" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, 
"nbformat": 4, "nbformat_minor": 2 }