In [1]:
"""
This is common code used by multiple code snippets
in this chapter. We have factored it out. It will
be presented only once here.
"""
import numpy as np
import torch
import matplotlib.pyplot as plt

In [2]:
def update_parameters(params, learning_rate):
    """
    Take one gradient-descent step on every tensor in ``params``.

    Each entry is replaced (in the same list object) by
    ``p - learning_rate * p.grad``, computed with gradient tracking
    disabled. The replacement tensors are then marked
    ``requires_grad=True`` so the next forward pass can track them.
    Because the entries are brand-new tensors, their ``.grad`` is None
    afterwards.

    Parameters
    ----------
    params : list of torch.Tensor
        Parameters whose ``.grad`` has been populated (e.g. by
        ``loss.backward()``). Entries are replaced in place.
    learning_rate : float
        Step size of the update.

    Raises
    ------
    ValueError
        If any parameter has no gradient. The check runs before any
        entry is replaced, so ``params`` is left untouched on failure.
    """
    # Validate first: failing half-way through would leave some entries
    # updated but with requires_grad=False, silently breaking training.
    missing = [i for i, p in enumerate(params) if p.grad is None]
    if missing:
        raise ValueError(
            'parameters at indices %s have no gradient; '
            'call backward() before update_parameters()' % missing)

    # Don't track gradients while computing the new parameter values.
    with torch.no_grad():
        updated = [p - learning_rate * p.grad for p in params]

    # The new tensors were created under no_grad; restore gradient
    # tracking so the next forward pass builds a graph through them.
    for p in updated:
        p.requires_grad = True

    # Write back into the caller's list object, entry by entry.
    for i, p in enumerate(updated):
        params[i] = p

In [3]:
def draw_line(m, c, min_x=0, max_x=10,
              color='magenta', label=None, num_points=100):
    """
    Plot the straight line y = m*x + c over the interval [min_x, max_x].

    Parameters
    ----------
    m, c : float
        Slope and intercept of the line.
    min_x, max_x : float
        Ends of the x interval to draw.
    color : str
        Matplotlib colour for the line.
    label : str or None
        Legend label. When falsy (None or ''), a label such as
        'y=2.00x+3.00' or 'y=2.00x-3.00' is generated from m and c.
    num_points : int
        Number of equally spaced x samples (default 100).
    """
    # linspace creates num_points equally spaced values between min_x
    # and max_x (both ends included).
    x = np.linspace(min_x, max_x, num_points)
    y = m * x + c

    if not label:
        # The '+' flag always emits the sign, so a negative intercept
        # renders as 'x-3.00' instead of 'x+-3.00'.
        label = 'y=%0.2fx%+0.2f' % (m, c)
    plt.plot(x, y, color=color, label=label)

In [4]:
def draw_parabola(w0, w1, w2, min_x=0, max_x=10,
                  color='magenta', label=None, num_points=100):
    """
    Plot y = w0 + w1*x + w2*x^2 over the interval [min_x, max_x].

    Parameters
    ----------
    w0, w1, w2 : float
        Constant, linear and quadratic coefficients.
    min_x, max_x : float
        Ends of the x interval to draw.
    color : str
        Matplotlib colour for the curve.
    label : str or None
        Legend label. When falsy (None or ''), a label such as
        'y=1.00 + 2.00x - 3.00x^2' is generated from the coefficients.
    num_points : int
        Number of equally spaced x samples (default 100).
    """
    x = np.linspace(min_x, max_x, num_points)
    y = w0 + w1 * x + w2 * (x ** 2)

    if not label:
        def signed(coef):
            # Render as '+ 2.00' / '- 2.00' so negative coefficients
            # don't show up as '+ -2.00', with uniform spacing.
            text = '%+0.2f' % coef
            return text[0] + ' ' + text[1:]
        label = 'y=%0.2f %sx %sx^2' % (w0, signed(w1), signed(w2))
    plt.plot(x, y, color=color, label=label)

In [5]:
def draw_subplot(pos, step,
                 true_draw_func, true_draw_params,
                 pred_draw_func, pred_draw_params):
    """
    Draw one panel of a 2x2 figure that overlays two curves.

    The panel at grid position ``pos`` (as passed to
    ``plt.subplot(2, 2, pos)``) is titled with the training ``step`` and
    shows
      (i)  the true function, i.e. the one used to generate the
           observations that a trained model tries to predict, vis-a-vis
      (ii) the model function, i.e. the one used to make the predictions.
    When the predictor is good, the two curves more or less coincide, so
    the panel visualizes the goodness of the current approximation.

    Each ``*_draw_func`` is called with its ``*_draw_params`` dict
    unpacked as keyword arguments, true curve first.
    """
    plt.subplot(2, 2, pos)
    plt.title('Step %d' % step)

    curves = ((true_draw_func, true_draw_params),
              (pred_draw_func, pred_draw_params))
    for draw, kwargs in curves:
        draw(**kwargs)

    plt.xlabel('x')
    plt.ylabel('y')
    plt.legend(loc='upper left')