# Stochastic gradient descent

From the [Data Science from Scratch book](https://www.oreilly.com/library/view/data-science-from/9781492041122/).

## Libraries and helper functions

In [1]:
from typing import List
import random

In [2]:
# Vectors are plain Python lists of floats; the helper functions below
# annotate their arguments and results with this alias.
Vector = List[float]

In [3]:
def add(vector1: Vector, vector2: Vector) -> Vector:
    """Return the element-wise sum of two vectors.

    The vectors are required (via ``assert``) to have equal length.
    """
    assert len(vector1) == len(vector2)
    return list(map(lambda left, right: left + right, vector1, vector2))


def scalar_multiply(c: float, vector: Vector) -> Vector:
    """Return a new vector with every component of ``vector`` multiplied by ``c``."""
    scaled = []
    for component in vector:
        scaled.append(c * component)
    return scaled


def gradient_step(v: Vector, gradient: Vector, step_size: float) -> Vector:
    """Return ``v`` moved by ``step_size * gradient``.

    The step is the gradient scaled by ``step_size`` and added to ``v``.
    A negative ``step_size`` therefore moves against the gradient, which is
    how the training loop performs descent.
    """
    return add(v, scalar_multiply(step_size, gradient))

def linear_gradient(x: float, y: float, theta: Vector) -> Vector:
    """Gradient of the squared error of a line's prediction at one point.

    ``theta`` is unpacked as ``[slope, intercept]``. For the loss
    ``(slope * x + intercept - y) ** 2``, the partial derivatives are
    ``2 * error * x`` with respect to the slope and ``2 * error`` with respect
    to the intercept, where ``error`` is the prediction minus ``y``.
    Returns ``[d_loss/d_slope, d_loss/d_intercept]``.
    """
    slope, intercept = theta
    residual = slope * x + intercept - y  # prediction minus target
    return [2 * residual * x, 2 * residual]



## Stochastic gradients

Here we use one training example at a time to compute each gradient step, rather than computing the gradient over the whole dataset before every update.

In [4]:
# Synthetic training data drawn from the line y = 20x + 5, so the fitted
# theta should approach [20, 5].
inputs = [(x, 20 * x + 5) for x in range(-50, 50)]

# theta = [slope, intercept], initialised to small random values.
theta = [random.uniform(-1, 1), random.uniform(-1, 1)]
learning_rate = 0.001

for epoch in range(100):
    # Stochastic gradient descent: update theta after every single example
    # instead of after a full pass over the data.
    for x, y in inputs:
        grad = linear_gradient(x, y, theta)
        # Negative step size moves theta against the gradient (descent).
        theta = gradient_step(theta, grad, -learning_rate)
    # Report once per epoch, after every example has been processed.
    print(epoch, theta)

0 [20.108274621088928, -0.3890550572184463]
1 [20.103628550173042, -0.15784430337372637]
2 [20.09918250047512, 0.06344662581205483]
3 [20.094927182760102, 0.2752433412342787]
4 [20.090854449810823, 0.47795318030884215]
5 [20.086956448727392, 0.6719660044610173]
6 [20.0832257045743, 0.8576549486610185]
7 [20.07965500742264, 1.0353771386943718]
8 [20.076237509713653, 1.2054743780282082]
9 [20.07296662801438, 1.3682738056445862]
10 [20.06983608174483, 1.5240885250583422]
11 [20.06683986242782, 1.6732182068542554]
12 [20.063972181799524, 1.8159496643823287]
13 [20.06122752813499, 1.9525574051756995]
14 [20.05860064833006, 2.083304159961789]
15 [20.056086446554076, 2.2084413869124764]
16 [20.05368014168373, 2.328209755991209]
17 [20.051377044125662, 2.442839611270341]
18 [20.049172772009843, 2.552551414011119]
19 [20.047063092087537, 2.6575561678668613]
20 [20.045043898683737, 2.758055822780714]
21 [20.04311134626141, 2.8542436641396707]
22 [20.04126169735246, 2.9463046849793786]
23 [20.039