{ "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2-final" }, "orig_nbformat": 2, "kernelspec": { "name": "Python 3.8.2 64-bit ('stackoverflow': conda)", "display_name": "Python 3.8.2 64-bit ('stackoverflow': conda)", "metadata": { "interpreter": { "hash": "f3131ab0c53afd587e929ab5f3e0081bb3b1675889ab2c259292a9bd8773e05a" } } } }, "nbformat": 4, "nbformat_minor": 2, "cells": [ { "source": [ "# Stochastic gradient descent" ], "cell_type": "markdown", "metadata": {} }, { "source": [ "From the [Data Science from Scratch book](https://www.oreilly.com/library/view/data-science-from/9781492041122/)." ], "cell_type": "markdown", "metadata": {} }, { "source": [ "## Libraries and helper functions" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from typing import List\n", "import random" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "Vector = List[float]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def add(vector1: Vector, vector2: Vector) -> Vector:\n", " assert len(vector1) == len(vector2)\n", " return [v1 + v2 for v1, v2 in zip(vector1, vector2)]\n", "\n", "\n", "def scalar_multiply(c: float, vector: Vector) -> Vector:\n", " return [c * v for v in vector]\n", "\n", "\n", "def gradient_step(v: Vector, gradient: Vector, step_size: float) -> Vector:\n", " \"\"\"Return vector adjusted with step. Step is gradient times step size.\n", " \"\"\"\n", " step = scalar_multiply(step_size, gradient)\n", " return add(v, step)\n", "\n", "def linear_gradient(x: float, y: float, theta: Vector) -> Vector:\n", " slope, intercept = theta\n", " predicted = slope * x + intercept\n", " error = (predicted - y) #** 2\n", " # print(x, y, theta, predicted, error)\n", " return [2 * error * x, 2 * error]\n", "\n" ] }, { "source": [ "## Stochastic gradients" ], "cell_type": "markdown", "metadata": {} }, { "source": [ "Here we use one training example at a time to calculate the gradient steps" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "0 [20.108274621088928, -0.3890550572184463]\n1 [20.103628550173042, -0.15784430337372637]\n2 [20.09918250047512, 0.06344662581205483]\n3 [20.094927182760102, 0.2752433412342787]\n4 [20.090854449810823, 0.47795318030884215]\n5 [20.086956448727392, 0.6719660044610173]\n6 [20.0832257045743, 0.8576549486610185]\n7 [20.07965500742264, 1.0353771386943718]\n8 [20.076237509713653, 1.2054743780282082]\n9 [20.07296662801438, 1.3682738056445862]\n10 [20.06983608174483, 1.5240885250583422]\n11 [20.06683986242782, 1.6732182068542554]\n12 [20.063972181799524, 1.8159496643823287]\n13 [20.06122752813499, 1.9525574051756995]\n14 [20.05860064833006, 2.083304159961789]\n15 [20.056086446554076, 2.2084413869124764]\n16 [20.05368014168373, 2.328209755991209]\n17 [20.051377044125662, 2.442839611270341]\n18 [20.049172772009843, 2.552551414011119]\n19 [20.047063092087537, 2.6575561678668613]\n20 [20.045043898683737, 2.758055822780714]\n21 [20.04311134626141, 2.8542436641396707]\n22 [20.04126169735246, 2.9463046849793786]\n23 [20.039491436402585, 3.0344159418586236]\n24 [20.037797100119768, 3.118746894557255]\n25 
[20.0361754521855, 3.19945973173537]\n26 [20.034623390938133, 3.276709684335131]\n27 [20.03313793652982, 3.3506453237233305]\n28 [20.03171617359179, 3.421408845892839]\n29 [20.030355435914238, 3.4891363463999814]\n30 [20.02905307411413, 3.5539580824049044]\n31 [20.02780658595349, 3.6159987219880825]\n32 [20.026613585687915, 3.6753775847836647]\n33 [20.025471758004617, 3.732208870958282]\n34 [20.024378926653775, 3.7866018810682345]\n35 [20.023332973880006, 3.8386612263190933]\n36 [20.022331891127603, 3.8884870294569893]\n37 [20.021373782454166, 3.936175118240505]\n38 [20.020456760427248, 3.981817208697045]\n39 [20.019579095445714, 4.0255010817522585]\n40 [20.018739087243826, 4.067310752589061]\n41 [20.017935091021968, 4.107326630985043]\n42 [20.017165620461952, 4.14562567751338]\n43 [20.016429140217177, 4.182281550851487]\n44 [20.015724268710606, 4.217364749099938]\n45 [20.015049636287394, 4.250942746103223]\n46 [20.014403952714495, 4.283080120704653]\n47 [20.01378597502875, 4.313838681185705]\n48 [20.013194514460565, 4.343277583993619]\n49 [20.012628414045615, 4.371453447141006]\n50 [20.01208658825722, 4.398420459215999]\n51 [20.011568035372246, 4.424230484791519]\n52 [20.011071730850883, 4.448933163459061]\n53 [20.010596714895236, 4.472576004466116]\n54 [20.01014208044218, 4.495204478772731]\n55 [20.009706939528638, 4.5168621063046]\n56 [20.009290475161325, 4.5375905399679235]\n57 [20.008891875158, 4.557429645750311]\n58 [20.008510370231065, 4.576417578971538]\n59 [20.00814524887334, 4.594590858346274]\n60 [20.007795789585977, 4.611984435823321]\n61 [20.00746133605139, 4.628631763741193]\n62 [20.00714119205082, 4.644564858428533]\n63 [20.00683481657478, 4.659814363112007]\n64 [20.00654157694565, 4.674409606833622]\n65 [20.00626093700885, 4.688378660053233]\n66 [20.0059923009138, 4.701748388289354]\n67 [20.005735205576336, 4.714544504461379]\n68 [20.00548916033653, 4.726791619391082]\n69 [20.00525365468124, 4.738513287317169]\n70 [20.005028243503418, 4.749732051371429]\n71 [20.004812513452155, 4.760469488056405]\n72 [20.0046060409254, 4.770746248364355]\n73 [20.004408419316054, 4.780582096925131]\n74 [20.004219291985194, 4.789995950689194]\n75 [20.004038256610244, 4.799005914654827]\n76 [20.00386500754045, 4.807629317187821]\n77 [20.003699183947933, 4.815882743439104]\n78 [20.00354046799362, 4.823782066495483]\n79 [20.003388552379008, 4.831342478409528]\n80 [20.003243202425267, 4.838578520565049]\n81 [20.003104049179598, 4.8455041097305935]\n82 [20.002970873552083, 4.852132564899691]\n83 [20.0028434121415, 4.858476634395222]\n84 [20.00272141143189, 4.864548519281559]\n85 [20.002604655271004, 4.870359897369896]\n86 [20.002492920554257, 4.875921945826025]\n87 [20.00238595286676, 4.881245361497983]\n88 [20.002283589604943, 4.886340382432449]\n89 [20.0021855928701, 4.8912168074015]\n90 [20.002091842376988, 4.8958840153869385]\n91 [20.002002090950178, 4.90035098289899]\n92 [20.001916203724086, 4.904626300833264]\n93 [20.001833987421463, 4.908718191643037]\n94 [20.00175530458321, 4.912634524898607]\n95 [20.001680007220763, 4.916382832992619]\n96 [20.001607906672934, 4.919970324321837]\n97 [20.001538916328176, 4.923403898248704]\n98 [20.001472898762803, 4.926690159008146]\n99 [20.001409710601003, 4.929835427066589]\n" ] } ], "source": [ "inputs = [(x, 20 * x + 5) for x in range(-50, 50)]\n", "\n", "theta = [random.uniform(-1, 1), random.uniform(-1, 1)]\n", "learning_rate = 0.001\n", "\n", "\n", "for epoch in range(100):\n", " for x, y in inputs:\n", " grad = linear_gradient(x, y, theta)\n", " 
theta = gradient_step(theta, grad, -learning_rate)  # negative step size: move against the gradient\n", "    print(epoch, theta)" ] },
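{ "source": [ "To make the \"one training example at a time\" idea concrete, here is a single update worked by hand (a sketch; the point `(1, 25)` and the zero starting vector are chosen purely for illustration). The prediction is `0`, the error is `-25`, so `linear_gradient` returns `[-50.0, -50.0]`, and one `gradient_step` with the same learning rate of 0.001 moves the parameters to roughly `[0.05, 0.05]`." ], "cell_type": "markdown", "metadata": {} },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# One stochastic update by hand (sketch), reusing the helpers defined above.\n", "example_theta = [0.0, 0.0]\n", "grad = linear_gradient(1, 25, example_theta)  # predicted = 0, error = -25 -> [-50.0, -50.0]\n", "example_theta = gradient_step(example_theta, grad, -0.001)\n", "print(grad, example_theta)  # gradient [-50.0, -50.0]; updated parameters about [0.05, 0.05]" ] },
{ "source": [ "Since every `(x, y)` pair was generated from `y = 20 * x + 5`, `theta` should end up close to a slope of 20 and an intercept of 5; the last value printed above is roughly `[20.001, 4.930]`. The cell below is a minimal sanity-check sketch, assuming a tolerance of 0.1 on each parameter." ], "cell_type": "markdown", "metadata": {} },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sanity check (sketch): the data come from y = 20 * x + 5, so the fitted\n", "# slope and intercept should be close to 20 and 5 respectively.\n", "slope, intercept = theta\n", "assert 19.9 < slope < 20.1, \"slope should be about 20\"\n", "assert 4.9 < intercept < 5.1, \"intercept should be about 5\"\n", "print(f\"slope = {slope:.3f}, intercept = {intercept:.3f}\")" ] } ] }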