{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Multiple Regression from scratch" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: plotly in /usr/local/lib/python3.6/dist-packages (4.5.0)\n", "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from plotly) (1.14.0)\n", "Requirement already satisfied: retrying>=1.3.3 in /usr/local/lib/python3.6/dist-packages (from plotly) (1.3.3)\n" ] } ], "source": [ "!pip3 install plotly" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from random import random\n", "from typing import Union, List\n", "from plotly import express as px\n", "from plotly import graph_objects as go" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# The data set we're using (source: https://miabellaai.net/)\n", "data = [\n", " [1.15, 0.59, 4.18],\n", " [-1.23, 1.65, -1],\n", " [-4.77, 4.02, -1.77],\n", " [0.57, 0.13, 6.06],\n", " [3.29, 3.14, 2.43],\n", " [2.22, 3.58, 4.49],\n", " [-3.98, 1.15, -2.64],\n", " [-0.98, 1.01, 1.87],\n", " [-2.26, 1.09, 3.43],\n", " [-0.96, 0.24, 1.03],\n", " [-2.92, 1.62, 1.8],\n", " [3.88, -1.13, 5.59],\n", " [0.01, 2.66, 4.42],\n", " [3.3, 1.04, 3.7],\n", " [0.44, 0.14, 1.2],\n", " [4.7, -0.73, 6.95],\n", " [-0.05, 1.3, 0.93],\n", " [3.74, -1.46, 3.97],\n", " [-3.69, 2.85, -2.07],\n", " [-4.39, 7.78, -2.8],\n", " [2.95, -1.02, 2.7],\n", " [1.19, -0.35, 4.24],\n", " [3.83, -1.72, 3.25],\n", " [-4.57, 2.72, -0.6],\n", " [-2.07, 5.79, 0.4],\n", " [-1.56, 1.34, -0.61],\n", " [0.85, 0.07, 1.06],\n", " [3.13, -0.98, 2.88],\n", " [-2.22, 0.6, 1.53],\n", " [-2.98, 2.43, 2.04],\n", " [2.59, 4.8, 1.8],\n", " [1.43, -0.91, 2.92],\n", " [-3.48, 2.24, 2.44],\n", " [2.69, 2.38, 7.48],\n", " [0.42, 4.33, 4.32],\n", " [1.75, -0.23, 3.57],\n", " [-4.17, 2.25, -0.3],\n", " [1.35, 0.13, 
3.63],\n",
 " [-3.68, 1.77, -1.43],\n",
 " [-3.34, 4.32, 3.05],\n",
 " [-0.79, 0.62, 1.33],\n",
 " [4.56, -1.85, 3.36],\n",
 " [-4.25, 6.17, 0.95],\n",
 " [-2.96, 1.8, 4.44],\n",
 " [3.36, -1.06, 2.76],\n",
 " [1.13, 1.79, 4.03],\n",
 " [0.07, 0.72, 3.46],\n",
 " [3.94, 4.01, 7.62],\n",
 " [-0.81, 6.04, 0.31],\n",
 " [2.21, 4.37, 5.33],\n",
 " [-3.11, 6.65, -0.5],\n",
 " [3.88, -1.07, 7.86],\n",
 " [0.82, -0.46, -0.07],\n",
 " [4.27, -1.21, 3.77],\n",
 " [-3.98, 8.22, -2.81],\n",
 " [-0.54, 0.34, 2.92],\n",
 " [-1.34, 2.23, 3.63],\n",
 " [-4.96, 2.03, -2.55],\n",
 " [3.2, -1.22, 3.18],\n",
 " [-2.17, 5.18, 1.87],\n",
 " [-4.13, 7.58, -1.77],\n",
 " [2.82, 3.2, 7.1],\n",
 " [-1.16, 1.14, 0.71],\n",
 " [-4.22, 1.29, 1.58],\n",
 " [-1.21, 0.9, 0.16],\n",
 " [-2.53, 1.82, -1.66],\n",
 " [-3.56, 5.63, -2.12],\n",
 " [3.39, -0.33, 7.96],\n",
 " [4.2, -0.8, 3.76],\n",
 " [0.52, 2.22, 0.51],\n",
 " [3.86, -0.22, 3.88],\n",
 " [2.05, 5.4, 1.56],\n",
 " [1.27, 3.06, 1.48],\n",
 " [4.81, 0.65, 3.43],\n",
 " [4.58, -0.91, 7.02],\n",
 " [3.16, -0.23, 4.17],\n",
 " [2.51, 0.19, 2.9],\n",
 " [-4.09, 5.52, -2.09],\n",
 " [2.61, -0.66, 1.98],\n",
 " [4.86, 1.16, 5.41],\n",
 " [4.24, 2.87, 5.67],\n",
 " [-3.27, 3.01, 1.81],\n",
 " [-2.43, 3.56, 4.22],\n",
 " [1.34, 0.17, 3.5],\n",
 " [-0.74, 1.17, 1.41],\n",
 " [4.38, -2.08, 4.16],\n",
 " [4.42, -0.21, 4.72],\n",
 " [4.87, 2.71, 7.01],\n",
 " [-1.69, 4.08, -0.38],\n",
 " [0.34, 0.65, 1.18],\n",
 " [1.4, 4.44, 0.79],\n",
 " [4.28, 0.77, 7.04],\n",
 " [1.36, 3.11, 0.87],\n",
 " [0.42, 5.54, 2.76],\n",
 " [0.61, 1.6, 2.93],\n",
 " [-1.12, 2.63, 1.65],\n",
 " [0.49, 2.54, -0.23],\n",
 " [-3.19, 6.53, 2.05],\n",
 " [-2.45, 4.7, 1.29],\n",
 " [4.07, -1.54, 2.2]\n",
 "]" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [
 "# Split every data row into its feature vector `[x1, x2]` (collected in `xs`) and its target value `y` (collected in `ys`)\n",
 "xs: List[List[float]] = [[row[0], row[1]] for row in data]\n",
 "ys: List[float] = [row[2] for row in data]"
] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [
 "# A convenience function which creates a 3D scatter plot with an optional plane of predictions\n",
 "def plot(xs: List[List[float]], ys: List[float], ys_pred: Union[List[float], None] = None) -> None:\n",
 "    \"\"\"Show the data points as a 3D scatter plot, optionally together with the model's predictions.\n",
 "\n",
 "    `xs` holds the `[x1, x2]` feature vectors (x and y axes) and `ys` the observed targets (z axis).\n",
 "    `ys_pred`, if given, holds one predicted target per vector in `xs`; for a linear model these\n",
 "    points lie on a single plane, which is drawn as a semi-transparent surface named 'Guess'.\n",
 "    \"\"\"\n",
 "    # Translate our `xs` and `ys` into data Plotly understands\n",
 "    x: List[float] = [item[0] for item in xs]  # x1\n",
 "    y: List[float] = [item[1] for item in xs]  # x2\n",
 "    z: List[float] = ys\n",
 "    fig = px.scatter_3d(x=x, y=y, z=z, labels={'x': 'x1', 'y': 'x2', 'z': 'y'})\n",
 "    # Compare against `None` so that an empty list is not mistaken for \"no predictions\"\n",
 "    if ys_pred is not None:\n",
 "        assert len(ys_pred) == len(xs)\n",
 "        # `Mesh3d` triangulates the points (Delaunay over the x/y projection by default) into a\n",
 "        # surface. A `Scatter3d` trace would instead connect the points in data order, drawing a\n",
 "        # zig-zag line rather than the plane.\n",
 "        fig.add_trace(\n",
 "            go.Mesh3d(x=x, y=y, z=ys_pred, name='Guess', opacity=0.5, showlegend=True)\n",
 "        )\n",
 "    fig.show()"
] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "hoverlabel": { "namelength": 0 }, "hovertemplate": "x1=%{x}
x2=%{y}
y=%{z}", "legendgroup": "", "marker": { "color": "#636efa", "symbol": "circle" }, "mode": "markers", "name": "", "scene": "scene", "showlegend": false, "type": "scatter3d", "x": [ 1.15, -1.23, -4.77, 0.57, 3.29, 2.22, -3.98, -0.98, -2.26, -0.96, -2.92, 3.88, 0.01, 3.3, 0.44, 4.7, -0.05, 3.74, -3.69, -4.39, 2.95, 1.19, 3.83, -4.57, -2.07, -1.56, 0.85, 3.13, -2.22, -2.98, 2.59, 1.43, -3.48, 2.69, 0.42, 1.75, -4.17, 1.35, -3.68, -3.34, -0.79, 4.56, -4.25, -2.96, 3.36, 1.13, 0.07, 3.94, -0.81, 2.21, -3.11, 3.88, 0.82, 4.27, -3.98, -0.54, -1.34, -4.96, 3.2, -2.17, -4.13, 2.82, -1.16, -4.22, -1.21, -2.53, -3.56, 3.39, 4.2, 0.52, 3.86, 2.05, 1.27, 4.81, 4.58, 3.16, 2.51, -4.09, 2.61, 4.86, 4.24, -3.27, -2.43, 1.34, -0.74, 4.38, 4.42, 4.87, -1.69, 0.34, 1.4, 4.28, 1.36, 0.42, 0.61, -1.12, 0.49, -3.19, -2.45, 4.07 ], "y": [ 0.59, 1.65, 4.02, 0.13, 3.14, 3.58, 1.15, 1.01, 1.09, 0.24, 1.62, -1.13, 2.66, 1.04, 0.14, -0.73, 1.3, -1.46, 2.85, 7.78, -1.02, -0.35, -1.72, 2.72, 5.79, 1.34, 0.07, -0.98, 0.6, 2.43, 4.8, -0.91, 2.24, 2.38, 4.33, -0.23, 2.25, 0.13, 1.77, 4.32, 0.62, -1.85, 6.17, 1.8, -1.06, 1.79, 0.72, 4.01, 6.04, 4.37, 6.65, -1.07, -0.46, -1.21, 8.22, 0.34, 2.23, 2.03, -1.22, 5.18, 7.58, 3.2, 1.14, 1.29, 0.9, 1.82, 5.63, -0.33, -0.8, 2.22, -0.22, 5.4, 3.06, 0.65, -0.91, -0.23, 0.19, 5.52, -0.66, 1.16, 2.87, 3.01, 3.56, 0.17, 1.17, -2.08, -0.21, 2.71, 4.08, 0.65, 4.44, 0.77, 3.11, 5.54, 1.6, 2.63, 2.54, 6.53, 4.7, -1.54 ], "z": [ 4.18, -1, -1.77, 6.06, 2.43, 4.49, -2.64, 1.87, 3.43, 1.03, 1.8, 5.59, 4.42, 3.7, 1.2, 6.95, 0.93, 3.97, -2.07, -2.8, 2.7, 4.24, 3.25, -0.6, 0.4, -0.61, 1.06, 2.88, 1.53, 2.04, 1.8, 2.92, 2.44, 7.48, 4.32, 3.57, -0.3, 3.63, -1.43, 3.05, 1.33, 3.36, 0.95, 4.44, 2.76, 4.03, 3.46, 7.62, 0.31, 5.33, -0.5, 7.86, -0.07, 3.77, -2.81, 2.92, 3.63, -2.55, 3.18, 1.87, -1.77, 7.1, 0.71, 1.58, 0.16, -1.66, -2.12, 7.96, 3.76, 0.51, 3.88, 1.56, 1.48, 3.43, 7.02, 4.17, 2.9, -2.09, 1.98, 5.41, 5.67, 1.81, 4.22, 3.5, 1.41, 4.16, 4.72, 7.01, -0.38, 1.18, 0.79, 
7.04, 0.87, 2.76, 2.93, 1.65, -0.23, 2.05, 1.29, 2.2 ] } ], "layout": { "legend": { "tracegroupgap": 0 }, "margin": { "t": 60 }, "scene": { "domain": { "x": [ 0, 1 ], "y": [ 0, 1 ] }, "xaxis": { "title": { "text": "x1" } }, "yaxis": { "title": { "text": "x2" } }, "zaxis": { "title": { "text": "y" } } }, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, 
"colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { 
"outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, 
"#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", 
"linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } } } }, "text/html": [ "
\n", " \n", " \n", "
\n", " \n", "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Look at the raw data before fitting anything\n", "plot(xs, ys)" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [
 "# The function which predicts a `y` value based on `x` and the `alpha` and `beta` parameters\n",
 "def predict(alpha: float, beta: List[float], x: List[float]) -> float:\n",
 "    \"\"\"Return the linear model's prediction `alpha + beta[0] * x[0] + beta[1] * x[1] + ...`.\n",
 "\n",
 "    Computed as the dot product of `[alpha] + beta` with `[1] + x`: prepending the constant 1 to `x`\n",
 "    lets the intercept `alpha` be treated like any other coefficient. Neither input list is modified.\n",
 "    \"\"\"\n",
 "    assert len(beta) == len(x)\n",
 "    # Terms of the dot product (https://en.wikipedia.org/wiki/Dot_product); the first one is `1 * alpha`\n",
 "    terms: List[float] = [alpha] + [x_i * b_i for x_i, b_i in zip(x, beta)]\n",
 "    return sum(terms)\n",
 "\n",
 "# (5 * 1) + (1 * 3) + (2 * 4) = 16 <-- the 5 and 1 are the prepended `alpha` and constant values\n",
 "assert predict(5, [1, 2], [3, 4]) == 16"
] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [
 "# SSE (sum of squared estimate of errors), the function we use to calculate how \"wrong\" we are\n",
 "# \"How much do the actual y values (`ys`) differ from our predicted y values (`ys_pred`)?\"\n",
 "def sum_squared_error(ys: List[float], ys_pred: List[float]) -> float:\n",
 "    \"\"\"Return the sum over all points of `(actual - predicted) ** 2`.\"\"\"\n",
 "    assert len(ys) == len(ys_pred)\n",
 "    return sum((y - y_p) ** 2 for y, y_p in zip(ys, ys_pred))\n",
 "\n",
 "assert sum_squared_error([1, 2, 3], [4, 5, 6]) == 27"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Find the best fitting hyperplane through the data points via (batch) Gradient Descent on the SSE\n",
 "alpha: float = random()\n",
 "beta: List[float] = [random(), random()]\n",
 "\n",
 "print(f'Starting with \"alpha\": {alpha}')\n",
 "print(f'Starting with \"beta\": {beta}')\n",
 "\n",
 "epochs: int = 1000\n",
 "# Step size. The gradient below is a sum over all data points, so its magnitude grows with the\n",
 "# data set; gradient descent on the SSE is stable while learning_rate * (largest eigenvalue of\n",
 "# the Hessian 2 * X^T X) stays below 2, which 0.0001 satisfies for this data.\n",
 "learning_rate: float = 0.0001\n",
 "# Print the loss every `log_every` epochs, starting with epoch 0\n",
 "log_every: int = 100\n",
 "\n",
 "for epoch in range(epochs):\n",
 "    # Calculate predictions for `y` values given the current `alpha` and `beta`\n",
 "    ys_pred: List[float] = [predict(alpha, beta, x) for x in xs]\n",
 "\n",
 "    # Calculate and print the error\n",
 "    if epoch % log_every == 0:\n",
 "        loss: float = sum_squared_error(ys, ys_pred)\n",
 "        print(f'Epoch {epoch} --> loss: {loss}')\n",
 "\n",
 "    # Calculate the gradient of the SSE by summing every data point's contribution:\n",
 "    #   d SSE / d alpha   = sum_i 2 * (y_pred_i - y_i)\n",
 "    #   d SSE / d beta[j] = sum_i 2 * (y_pred_i - y_i) * x_i[j]\n",
 "    # (the `ys_pred` computed above is reused instead of calling `predict` again)\n",
 "    grad_alpha: float = 0.0\n",
 "    grad_beta: List[float] = [0.0] * len(beta)\n",
 "    for x, y, y_pred in zip(xs, ys, ys_pred):\n",
 "        error: float = y_pred - y\n",
 "        grad_alpha += 2 * error\n",
 "        for j, x_j in enumerate(x):\n",
 "            grad_beta[j] += 2 * error * x_j\n",
 "\n",
 "    # Take a small step against the gradient, i.e. in the direction of greatest decrease\n",
 "    alpha = alpha - learning_rate * grad_alpha\n",
 "    beta = [b - learning_rate * gb for b, gb in zip(beta, grad_beta)]\n",
 "\n",
 "# Recompute the predictions so they (and the plot below) reflect the final `alpha` and `beta`\n",
 "ys_pred = [predict(alpha, beta, x) for x in xs]\n",
 "\n",
 "print(f'Final loss: {sum_squared_error(ys, ys_pred)}')\n",
 "print(f'Best estimate for \"alpha\": {alpha}')\n",
 "print(f'Best estimate for \"beta\": {beta}')"
] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": {
"application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "hoverlabel": { "namelength": 0 }, "hovertemplate": "x1=%{x}
x2=%{y}
y=%{z}", "legendgroup": "", "marker": { "color": "#636efa", "symbol": "circle" }, "mode": "markers", "name": "", "scene": "scene", "showlegend": false, "type": "scatter3d", "x": [ 1.15, -1.23, -4.77, 0.57, 3.29, 2.22, -3.98, -0.98, -2.26, -0.96, -2.92, 3.88, 0.01, 3.3, 0.44, 4.7, -0.05, 3.74, -3.69, -4.39, 2.95, 1.19, 3.83, -4.57, -2.07, -1.56, 0.85, 3.13, -2.22, -2.98, 2.59, 1.43, -3.48, 2.69, 0.42, 1.75, -4.17, 1.35, -3.68, -3.34, -0.79, 4.56, -4.25, -2.96, 3.36, 1.13, 0.07, 3.94, -0.81, 2.21, -3.11, 3.88, 0.82, 4.27, -3.98, -0.54, -1.34, -4.96, 3.2, -2.17, -4.13, 2.82, -1.16, -4.22, -1.21, -2.53, -3.56, 3.39, 4.2, 0.52, 3.86, 2.05, 1.27, 4.81, 4.58, 3.16, 2.51, -4.09, 2.61, 4.86, 4.24, -3.27, -2.43, 1.34, -0.74, 4.38, 4.42, 4.87, -1.69, 0.34, 1.4, 4.28, 1.36, 0.42, 0.61, -1.12, 0.49, -3.19, -2.45, 4.07 ], "y": [ 0.59, 1.65, 4.02, 0.13, 3.14, 3.58, 1.15, 1.01, 1.09, 0.24, 1.62, -1.13, 2.66, 1.04, 0.14, -0.73, 1.3, -1.46, 2.85, 7.78, -1.02, -0.35, -1.72, 2.72, 5.79, 1.34, 0.07, -0.98, 0.6, 2.43, 4.8, -0.91, 2.24, 2.38, 4.33, -0.23, 2.25, 0.13, 1.77, 4.32, 0.62, -1.85, 6.17, 1.8, -1.06, 1.79, 0.72, 4.01, 6.04, 4.37, 6.65, -1.07, -0.46, -1.21, 8.22, 0.34, 2.23, 2.03, -1.22, 5.18, 7.58, 3.2, 1.14, 1.29, 0.9, 1.82, 5.63, -0.33, -0.8, 2.22, -0.22, 5.4, 3.06, 0.65, -0.91, -0.23, 0.19, 5.52, -0.66, 1.16, 2.87, 3.01, 3.56, 0.17, 1.17, -2.08, -0.21, 2.71, 4.08, 0.65, 4.44, 0.77, 3.11, 5.54, 1.6, 2.63, 2.54, 6.53, 4.7, -1.54 ], "z": [ 4.18, -1, -1.77, 6.06, 2.43, 4.49, -2.64, 1.87, 3.43, 1.03, 1.8, 5.59, 4.42, 3.7, 1.2, 6.95, 0.93, 3.97, -2.07, -2.8, 2.7, 4.24, 3.25, -0.6, 0.4, -0.61, 1.06, 2.88, 1.53, 2.04, 1.8, 2.92, 2.44, 7.48, 4.32, 3.57, -0.3, 3.63, -1.43, 3.05, 1.33, 3.36, 0.95, 4.44, 2.76, 4.03, 3.46, 7.62, 0.31, 5.33, -0.5, 7.86, -0.07, 3.77, -2.81, 2.92, 3.63, -2.55, 3.18, 1.87, -1.77, 7.1, 0.71, 1.58, 0.16, -1.66, -2.12, 7.96, 3.76, 0.51, 3.88, 1.56, 1.48, 3.43, 7.02, 4.17, 2.9, -2.09, 1.98, 5.41, 5.67, 1.81, 4.22, 3.5, 1.41, 4.16, 4.72, 7.01, -0.38, 1.18, 0.79, 
7.04, 0.87, 2.76, 2.93, 1.65, -0.23, 2.05, 1.29, 2.2 ] }, { "name": "Guess", "surfaceaxis": 1, "type": "scatter3d", "x": [ 1.15, -1.23, -4.77, 0.57, 3.29, 2.22, -3.98, -0.98, -2.26, -0.96, -2.92, 3.88, 0.01, 3.3, 0.44, 4.7, -0.05, 3.74, -3.69, -4.39, 2.95, 1.19, 3.83, -4.57, -2.07, -1.56, 0.85, 3.13, -2.22, -2.98, 2.59, 1.43, -3.48, 2.69, 0.42, 1.75, -4.17, 1.35, -3.68, -3.34, -0.79, 4.56, -4.25, -2.96, 3.36, 1.13, 0.07, 3.94, -0.81, 2.21, -3.11, 3.88, 0.82, 4.27, -3.98, -0.54, -1.34, -4.96, 3.2, -2.17, -4.13, 2.82, -1.16, -4.22, -1.21, -2.53, -3.56, 3.39, 4.2, 0.52, 3.86, 2.05, 1.27, 4.81, 4.58, 3.16, 2.51, -4.09, 2.61, 4.86, 4.24, -3.27, -2.43, 1.34, -0.74, 4.38, 4.42, 4.87, -1.69, 0.34, 1.4, 4.28, 1.36, 0.42, 0.61, -1.12, 0.49, -3.19, -2.45, 4.07 ], "y": [ 0.59, 1.65, 4.02, 0.13, 3.14, 3.58, 1.15, 1.01, 1.09, 0.24, 1.62, -1.13, 2.66, 1.04, 0.14, -0.73, 1.3, -1.46, 2.85, 7.78, -1.02, -0.35, -1.72, 2.72, 5.79, 1.34, 0.07, -0.98, 0.6, 2.43, 4.8, -0.91, 2.24, 2.38, 4.33, -0.23, 2.25, 0.13, 1.77, 4.32, 0.62, -1.85, 6.17, 1.8, -1.06, 1.79, 0.72, 4.01, 6.04, 4.37, 6.65, -1.07, -0.46, -1.21, 8.22, 0.34, 2.23, 2.03, -1.22, 5.18, 7.58, 3.2, 1.14, 1.29, 0.9, 1.82, 5.63, -0.33, -0.8, 2.22, -0.22, 5.4, 3.06, 0.65, -0.91, -0.23, 0.19, 5.52, -0.66, 1.16, 2.87, 3.01, 3.56, 0.17, 1.17, -2.08, -0.21, 2.71, 4.08, 0.65, 4.44, 0.77, 3.11, 5.54, 1.6, 2.63, 2.54, 6.53, 4.7, -1.54 ], "z": [ 2.520736104854845, 0.9302012402927735, -1.2434074461559583, 1.9591564057090194, 4.799301874418488, 4.07537628270149, -1.3253299704826313, 0.9692387190508467, -0.004868854096976238, 0.7982749405264222, -0.3887722977488838, 4.223080478257401, 2.1372504918426984, 4.2984626920914515, 1.8606777131388885, 4.956406759503769, 1.7613014393413924, 4.034495314907502, -0.6885195839970125, -0.03782966408963628, 3.527894094541742, 2.324123052668325, 4.0413799583538665, -1.4030238462020261, 1.280898209789131, 0.5989897846404213, 2.1619493265740246, 3.6772902526393114, -0.09249602763674944, -0.23916722433143356, 
4.65802774282617, 2.3747742140502597, -0.6732627187602856, 4.149541753459267, 2.8599341693785885, 2.7878346963432663, -1.2063901540328916, 2.5645600116160385, -0.9423241080580893, -0.060843582895259596, 1.0222544000215297, 4.576490831418468, -0.3190947331279339, -0.3762242850546477, 3.8364314332199934, 2.795841945197272, 1.7139697676011951, 5.51451091139138, 2.319405591229004, 4.258945462745775, 0.6819775255491112, 4.237611928743634, 2.0103034265389534, 4.506407013895933, 0.386959278478479, 1.148480581697, 0.9852944967240863, -1.8728373297472136, 3.6734955435321894, 1.0555459496012736, 0.11553336959192673, 4.449038844422286, 0.8610150039976032, -1.477701285268196, 0.7640812785970935, -0.03763232650793119, 0.08567155618294797, 4.036515501542508, 4.551374166047382, 2.4265273408572208, 4.4279509745830365, 4.384216674368256, 3.2120864994981533, 5.376007552289397, 4.819673391751683, 3.8822181377905696, 3.4794352862716837, -0.3523334250042738, 3.3511889179612075, 5.538332804877956, 5.471260892886424, -0.3237824923402022, 0.4613955845039788, 2.566486060582412, 1.194267286267576, 4.3810784134478284, 4.8650216256998835, 5.921490193796754, 1.1616920891938873, 1.9065791953607119, 3.6472104616660133, 4.993706464632735, 3.294050303790055, 3.152985087517618, 2.3462232813861883, 1.252925696503516, 2.4807436560437828, 0.5908219470477201, 0.7219699743601729, 4.27125234239934 ] } ], "layout": { "legend": { "tracegroupgap": 0 }, "margin": { "t": 60 }, "scene": { "domain": { "x": [ 0, 1 ], "y": [ 0, 1 ] }, "xaxis": { "title": { "text": "x1" } }, "yaxis": { "title": { "text": "x2" } }, "zaxis": { "title": { "text": "y" } } }, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", 
"linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, 
"#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, 
"ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", 
"mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } } } }, "text/html": [ "
\n", " \n", " \n", "
\n", " \n", "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot(xs, ys, ys_pred)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 4 }