{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Gradient Descent from scratch\n", "\n", "Minimizes the paraboloid $z = x^2 + y^2$ with plain gradient descent implemented using only the Python standard library. Plotly is used solely to visualize the surface." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Pin the version this notebook was last run with; %pip (unlike !pip3) installs into the running kernel's environment\n", "%pip install plotly==4.5.0" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from random import randint\n", "from typing import List\n", "from plotly import graph_objects as go" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# The function we want to run Gradient Descent on\n", "# See: https://en.wikipedia.org/wiki/Paraboloid or https://www.wolframalpha.com/input/?i=x%5E2+%2B+y%5E2\n", "def paraboloid(x: float, y: float) -> float:\n", "    \"\"\"Return the height z = x^2 + y^2 at (x, y); the only minimum is z = 0 at (0, 0).\"\"\"\n", "    return x ** 2 + y ** 2" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "xs: [-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n", "\n", "ys: [-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n", "\n", "zs: [[200, 181, 164, 149, 136, 125, 116, 109, 104, 101, 100, 101, 104, 109, 116, 125, 136, 149, 164, 181, 200], [181, 162, 145, 130, 117, 106, 97, 90, 85, 82, 81, 82, 85, 90, 97, 106, 117, 130, 145, 162, 181], [164, 145, 128, 113, 100, 89, 80, 73, 68, 65, 64, 65, 68, 73, 80, 89, 100, 113, 128, 145, 164], [149, 130, 113, 98, 85, 74, 65, 58, 53, 50, 49, 50, 53, 58, 65, 74, 85, 98, 113, 130, 149], [136, 117, 100, 85, 72, 61, 52, 45, 40, 37, 36, 37, 40, 45, 52, 61, 72, 85, 100, 117, 136]] ...\n", "\n" ] } ], "source": [ "# Test data generation (only really necessary for the plotting below)\n", "xs_start = ys_start = -10\n", "xs_stop = ys_stop = 11\n", "xs_step = ys_step = 1\n", "\n", "xs: List[float] = list(range(xs_start, xs_stop, xs_step))\n", "ys: List[float] = list(range(ys_start, ys_stop, ys_step))\n", "\n", "# Plotly's Surface reads z[row][col] as the height at (xs[col], ys[row]), so build one row per y value.\n", "# (The paraboloid is symmetric in x and y, so the orientation only becomes visible for an asymmetric function.)\n", "zs: List[List[float]] = [[paraboloid(x, y) for x in xs] for y in ys]\n", "\n", "print(f'xs: {xs}\\n')\n", "print(f'ys: {ys}\\n')\n", "print(f'zs: {zs[:5]} ...\\n')" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "colorscale": [ [ 0, "#440154" ], [ 0.1111111111111111, "#482878" ], [ 0.2222222222222222, "#3e4989" ], [ 0.3333333333333333, "#31688e" ], [ 0.4444444444444444, "#26828e" ], [ 0.5555555555555556, "#1f9e89" ], [ 0.6666666666666666, "#35b779" ], [ 0.7777777777777778, "#6ece58" ], [ 0.8888888888888888, "#b5de2b" ], [ 1, "#fde725" ] ], "type": "surface", "x": [ -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ], "y": [ -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ], "z": [ [ 200, 181, 164, 149, 136, 125, 116, 109, 104, 101, 100, 101, 104, 109, 116, 125, 136, 149, 164, 181, 200 ], [ 181, 162, 145, 130, 117, 106, 97, 90, 85, 82, 81, 82, 85, 90, 97, 106, 117, 130, 145, 162, 181 ], [ 164, 145, 128, 113, 100, 89, 80, 73, 68, 65, 64, 65, 68, 73, 80, 89, 100, 113, 128, 145, 164 ], [ 149, 130, 113, 98, 85, 74, 65, 58, 53, 50, 49, 50, 53, 58, 65, 74, 85, 98, 113, 130, 149 ], [ 136, 117, 100, 85, 72, 61, 52, 45, 40, 37, 36, 37, 40, 45, 52, 61, 72, 85, 100, 117, 136 ], [ 125, 106, 89, 74, 61, 50, 41, 34, 29, 26, 25, 26, 29, 34, 41, 50, 61, 74, 89, 106, 125 ], [ 116, 97, 80, 65, 52, 41, 32, 25,
20, 17, 16, 17, 20, 25, 32, 41, 52, 65, 80, 97, 116 ], [ 109, 90, 73, 58, 45, 34, 25, 18, 13, 10, 9, 10, 13, 18, 25, 34, 45, 58, 73, 90, 109 ], [ 104, 85, 68, 53, 40, 29, 20, 13, 8, 5, 4, 5, 8, 13, 20, 29, 40, 53, 68, 85, 104 ], [ 101, 82, 65, 50, 37, 26, 17, 10, 5, 2, 1, 2, 5, 10, 17, 26, 37, 50, 65, 82, 101 ], [ 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100 ], [ 101, 82, 65, 50, 37, 26, 17, 10, 5, 2, 1, 2, 5, 10, 17, 26, 37, 50, 65, 82, 101 ], [ 104, 85, 68, 53, 40, 29, 20, 13, 8, 5, 4, 5, 8, 13, 20, 29, 40, 53, 68, 85, 104 ], [ 109, 90, 73, 58, 45, 34, 25, 18, 13, 10, 9, 10, 13, 18, 25, 34, 45, 58, 73, 90, 109 ], [ 116, 97, 80, 65, 52, 41, 32, 25, 20, 17, 16, 17, 20, 25, 32, 41, 52, 65, 80, 97, 116 ], [ 125, 106, 89, 74, 61, 50, 41, 34, 29, 26, 25, 26, 29, 34, 41, 50, 61, 74, 89, 106, 125 ], [ 136, 117, 100, 85, 72, 61, 52, 45, 40, 37, 36, 37, 40, 45, 52, 61, 72, 85, 100, 117, 136 ], [ 149, 130, 113, 98, 85, 74, 65, 58, 53, 50, 49, 50, 53, 58, 65, 74, 85, 98, 113, 130, 149 ], [ 164, 145, 128, 113, 100, 89, 80, 73, 68, 65, 64, 65, 68, 73, 80, 89, 100, 113, 128, 145, 164 ], [ 181, 162, 145, 130, 117, 106, 97, 90, 85, 82, 81, 82, 85, 90, 97, 106, 117, 130, 145, 162, 181 ], [ 200, 181, 164, 149, 136, 125, 116, 109, 104, 101, 100, 101, 104, 109, 116, 125, 136, 149, 164, 181, 200 ] ] } ], "layout": { "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 
0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, 
"ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 
0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { 
"backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } } } }, "text/html": [ "
\n", " \n", " \n", "
\n", " \n", "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plotting the generated test data\n", "fig = go.Figure(go.Surface(x=xs, y=ys, z=zs, colorscale='Viridis'))\n", "fig.update_layout(\n", "    title='Paraboloid z = x^2 + y^2',\n", "    scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),\n", ")\n", "fig.show()" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# The Gradient is a vector pointing in the direction of greatest increase\n", "# This function computes gradients for our Paraboloid function (defined above)\n", "# See: https://www.wolframalpha.com/input/?i=gradient+of+x%5E2+%2B+y%5E2\n", "def compute_gradient(vec: List[float]) -> List[float]:\n", "    \"\"\"Return the gradient [dz/dx, dz/dy] = [2 * x, 2 * y] of the paraboloid at vec = [x, y].\n", "\n", "    Raises AssertionError if vec does not have exactly two components.\n", "    \"\"\"\n", "    assert len(vec) == 2, f'expected a 2D point [x, y], got {len(vec)} components'\n", "    x, y = vec\n", "    # The derivative of z with respect to x is 2 * x\n", "    # The derivative of z with respect to y is 2 * y\n", "    return [2 * x, 2 * y]" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# This function computes the next position based on the current position, its computed gradient and the learning rate\n", "def compute_step(curr_pos: List[float], learning_rate: float) -> List[float]:\n", "    \"\"\"Take one gradient-descent step: next = curr_pos - learning_rate * gradient(curr_pos).\n", "\n", "    We move against the gradient because the gradient points uphill. curr_pos is not\n", "    modified; a new list holding the next position is returned.\n", "    \"\"\"\n", "    grad: List[float] = compute_gradient(curr_pos)\n", "    return [pos - learning_rate * slope for pos, slope in zip(curr_pos, grad)]" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[4, 7]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Pick a random starting position on the surface of our Paraboloid\n", "import random\n", "\n", "# Seed the generator behind `randint` so that Restart & Run All reproduces the same run\n", "random.seed(42)\n", "\n", "start_pos: List[float]\n", "\n", "# Ensure that we don't start at the minimum ((0, 0) in our case).\n", "# `randint` includes its upper bound whereas `range` excludes its stop, hence the `- 1`\n", "# to stay within the plotted domain.\n", "while True:\n", "    start_x: int = randint(xs_start, xs_stop - 1)\n", "    start_y: int = randint(ys_start, ys_stop - 1)\n", "    if (start_x, start_y) != (0, 0):\n", "        start_pos = [start_x, start_y]\n", "        break\n", "\n", "start_pos" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: [3.992, 6.986]\n", "Epoch 500: [1.4671049293897798, 2.5674336264321123]\n", "Epoch 1000: [0.539177573607161, 0.9435607538125328]\n", "Epoch 1500: [0.19815382666720605, 0.34676919666761064]\n", "Epoch 2000: [0.07282376149321286, 0.12744158261312272]\n", "Epoch 2500: [0.026763551969789107, 0.04683621594713106]\n", "Epoch 3000: [0.009835906568851993, 0.017212836495491053]\n", "Epoch 3500: [0.003614806365776567, 0.006325911140109013]\n", "Epoch 4000: [0.0013284820235521919, 0.0023248435412163435]\n", "Epoch 4500: [0.00048823209553084355, 0.00085440616717898]\n", "Best guess for a minimum: [0.00017979037083174428, 0.00031463314895555317]\n" ] } ], "source": [ "epochs: int = 5000\n", "learning_rate: float = 0.001\n", "log_every: int = 500  # how often (in epochs) to print progress\n", "\n", "# The latest position; once the loop finishes it is our best guess for the minimum\n", "best_pos: List[float] = start_pos\n", "\n", "for i in range(epochs):\n", "    next_pos: List[float] = compute_step(best_pos, learning_rate)\n", "    # Print some debug information every once in a while\n", "    if i % log_every == 0:\n", "        print(f'Epoch {i}: {next_pos}')\n", "    best_pos = next_pos\n", "\n", "# Sanity check: descending must have lowered the height we started from\n", "assert paraboloid(*best_pos) < paraboloid(*start_pos)\n", "\n", "print(f'Best guess for a minimum: {best_pos}')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 4 }