{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "! [ -e /content ] && pip install -Uqq fastbook\n", "import fastbook\n", "fastbook.setup_book()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from fastai.vision.all import *\n", "from fastbook import *\n", "\n", "matplotlib.rc('image', cmap='Greys')" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "[[chapter_mnist_basics]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Under the Hood: Training a Digit Classifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Having seen what it looks like to actually train a variety of models in Chapter 2, let’s now look under the hood and see exactly what is going on. We’ll start by using computer vision to introduce fundamental tools and concepts for deep learning.\n", "\n", "To be exact, we'll discuss the roles of arrays and tensors and of broadcasting, a powerful technique for using them expressively. We'll explain stochastic gradient descent (SGD), the mechanism for learning by updating weights automatically. We'll discuss the choice of a loss function for our basic classification task, and the role of mini-batches. We'll also describe the math that a basic neural network is actually doing. Finally, we'll put all these pieces together.\n", "\n", "In future chapters we’ll do deep dives into other applications as well, and see how these concepts and tools generalize. But this chapter is about laying foundation stones. To be frank, that also makes this one of the hardest chapters, because of how these concepts all depend on each other. Like an arch, all the stones need to be in place for the structure to stay up. Also like an arch, once that happens, it's a powerful structure that can support other things. But it requires some patience to assemble.\n", "\n", "Let's begin. 
The first step is to consider how images are represented in a computer." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pixels: The Foundations of Computer Vision" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to understand what happens in a computer vision model, we first have to understand how computers handle images. We'll use one of the most famous datasets in computer vision, [MNIST](https://en.wikipedia.org/wiki/MNIST_database), for our experiments. MNIST contains images of handwritten digits, collected by the National Institute of Standards and Technology and collated into a machine learning dataset by Yann LeCun and his colleagues. LeCun used MNIST in 1998 in [LeNet-5](http://yann.lecun.com/exdb/lenet/), the first computer system to demonstrate practically useful recognition of handwritten digit sequences. This was one of the most important breakthroughs in the history of AI." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sidebar: Tenacity and Deep Learning" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The story of deep learning is one of tenacity and grit by a handful of dedicated researchers. After early hopes (and hype!), neural networks went out of favor in the 1990s and 2000s, and just a handful of researchers kept trying to make them work well. Three of them, Yann LeCun, Yoshua Bengio, and Geoffrey Hinton, were awarded the highest honor in computer science, the Turing Award (generally considered the \"Nobel Prize of computer science\"), in 2018 after triumphing despite the deep skepticism and disinterest of the wider machine learning and statistics community.\n", "\n", "Geoff Hinton has told of how even academic papers showing dramatically better results than anything previously published would be rejected by top journals and conferences, just because they used a neural network. 
Yann LeCun's work on convolutional neural networks, which we will study in the next section, showed that these models could read handwritten text, something that had never been achieved before. However, his breakthrough was ignored by most researchers, even as it was used commercially to read 10% of the checks in the US!\n", "\n", "In addition to these three Turing Award winners, there are many other researchers who have battled to get us to where we are today. For instance, Jürgen Schmidhuber (who many believe should have shared in the Turing Award) pioneered many important ideas, including working with his student Sepp Hochreiter on the long short-term memory (LSTM) architecture (widely used for speech recognition and other text modeling tasks, and used in the IMDb example in <>). Perhaps most important of all, Paul Werbos in 1974 invented back-propagation for neural networks, the technique shown in this chapter and used universally for training neural networks ([Werbos 1994](https://books.google.com/books/about/The_Roots_of_Backpropagation.html?id=WdR3OOM2gBwC)). His development was almost entirely ignored for decades, but today it is considered the most important foundation of modern AI.\n", "\n", "There is a lesson here for all of us! On your deep learning journey you will face many obstacles, both technical, and (even more difficult) posed by people around you who don't believe you'll be successful. There's one *guaranteed* way to fail, and that's to stop trying. We've seen that the only consistent trait amongst every fast.ai student that's gone on to be a world-class practitioner is that they are all very tenacious." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## End sidebar" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this initial tutorial we are just going to try to create a model that can classify any image as a 3 or a 7. 
So let's download a sample of MNIST that contains images of just these digits:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "Path.BASE_PATH = path" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see what's in this directory by using `ls`, a method added by fastai. This method returns an object of a special fastai class called `L`, which has all the same functionality of Python's built-in `list`, plus a lot more. One of its handy features is that, when printed, it displays the count of items, before listing the items themselves (if there are more than 10 items, it just shows the first few):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#9) [Path('cleaned.csv'),Path('item_list.txt'),Path('trained_model.pkl'),Path('models'),Path('valid'),Path('labels.csv'),Path('export.pkl'),Path('history.csv'),Path('train')]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path.ls()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The MNIST dataset follows a common layout for machine learning datasets: separate folders for the training set and the validation set (and/or test set). Let's see what's inside the training set:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#2) [Path('train/7'),Path('train/3')]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(path/'train').ls()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There's a folder of 3s, and a folder of 7s. In machine learning parlance, we say that \"3\" and \"7\" are the *labels* (or targets) in this dataset. 
Let's take a look in one of these folders (using `sorted` to ensure we all get the same order of files):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#6131) [Path('train/3/10.png'),Path('train/3/10000.png'),Path('train/3/10011.png'),Path('train/3/10031.png'),Path('train/3/10034.png'),Path('train/3/10042.png'),Path('train/3/10052.png'),Path('train/3/1007.png'),Path('train/3/10074.png'),Path('train/3/10091.png')...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "threes = (path/'train'/'3').ls().sorted()\n", "sevens = (path/'train'/'7').ls().sorted()\n", "threes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we might expect, it's full of image files. Let’s take a look at one now. Here’s an image of a handwritten number 3, taken from the famous MNIST dataset of handwritten numbers:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA9ElEQVR4nM3Or0sDcRjH8c/pgrfBVBjCgibThiKIyTWbWF1bORhGwxARxH/AbtW0JoIGwzXRYhJhtuFY2q1ocLgbe3sGReTuuWbwkx6+r+/zQ/pncX6q+YOldSe6nG3dn8U/rTQ70L8FCGJUewvxl7NTmezNb8xIkvKugr1HSeMP6SrWOVkoTEuSyh0Gm2n3hQyObMnXnxkempRrvgD+gokzwxFAr7U7YXHZ8x4A/Dl7rbu6D2yl3etcw/F3nZgfRVI7rXM7hMUUqzzBec427x26rkmlkzEEa4nnRqnSOH2F0UUx0ePzlbuqMXAHgN6GY9if5xP8dmtHFfwjuQAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "im3_path = threes[1]\n", "im3 = Image.open(im3_path)\n", "im3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we are using the `Image` class from the *Python Imaging Library* (PIL), which is the most widely used Python package for opening, manipulating, and viewing images. Jupyter knows about PIL images, so it displays the image for us automatically.\n", "\n", "In a computer, everything is represented as a number. 
To view the numbers that make up this image, we have to convert it to a *NumPy array* or a *PyTorch tensor*. For instance, here's what a section of the image looks like, converted to a NumPy array:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 29],\n", " [ 0, 0, 0, 48, 166, 224],\n", " [ 0, 93, 244, 249, 253, 187],\n", " [ 0, 107, 253, 253, 230, 48],\n", " [ 0, 3, 20, 20, 15, 0]], dtype=uint8)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "array(im3)[4:10,4:10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `4:10` indicates we requested the rows from index 4 (included) to 10 (not included) and the same for the columns. NumPy indexes from top to bottom and left to right, so this section is located in the top-left corner of the image. Here's the same thing as a PyTorch tensor:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 29],\n", " [ 0, 0, 0, 48, 166, 224],\n", " [ 0, 93, 244, 249, 253, 187],\n", " [ 0, 107, 253, 253, 230, 48],\n", " [ 0, 3, 20, 20, 15, 0]], dtype=torch.uint8)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tensor(im3)[4:10,4:10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can slice the array to pick just the part with the top of the digit in it, and then use a Pandas DataFrame to color-code the values using a gradient, which shows us clearly how the image is created from the pixel values:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
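"<table>\n", 
"<tr><th></th><th>0</th><th>1</th><th>2</th><th>3</th><th>4</th><th>5</th><th>6</th><th>7</th><th>8</th><th>9</th><th>10</th><th>11</th><th>12</th><th>13</th><th>14</th><th>15</th><th>16</th><th>17</th></tr>\n", 
"<tr><th>0</th><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td></tr>\n", 
"<tr><th>1</th><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>29</td><td>150</td><td>195</td><td>254</td><td>255</td><td>254</td><td>176</td><td>193</td><td>150</td><td>96</td><td>0</td><td>0</td><td>0</td></tr>\n", 
"<tr><th>2</th><td>0</td><td>0</td><td>0</td><td>48</td><td>166</td><td>224</td><td>253</td><td>253</td><td>234</td><td>196</td><td>253</td><td>253</td><td>253</td><td>253</td><td>233</td><td>0</td><td>0</td><td>0</td></tr>\n", 
"<tr><th>3</th><td>0</td><td>93</td><td>244</td><td>249</td><td>253</td><td>187</td><td>46</td><td>10</td><td>8</td><td>4</td><td>10</td><td>194</td><td>253</td><td>253</td><td>233</td><td>0</td><td>0</td><td>0</td></tr>\n", 
"<tr><th>4</th><td>0</td><td>107</td><td>253</td><td>253</td><td>230</td><td>48</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>192</td><td>253</td><td>253</td><td>156</td><td>0</td><td>0</td><td>0</td></tr>\n", 
"<tr><th>5</th><td>0</td><td>3</td><td>20</td><td>20</td><td>15</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>43</td><td>224</td><td>253</td><td>245</td><td>74</td><td>0</td><td>0</td><td>0</td></tr>\n", 
"<tr><th>6</th><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>249</td><td>253</td><td>245</td><td>126</td><td>0</td><td>0</td><td>0</td><td>0</td></tr>\n", 
"<tr><th>7</th><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>14</td><td>101</td><td>223</td><td>253</td><td>248</td><td>124</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td></tr>\n", 
"<tr><th>8</th><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>11</td><td>166</td><td>239</td><td>253</td><td>253</td><td>253</td><td>187</td><td>30</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td></tr>\n", 
"<tr><th>9</th><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>16</td><td>248</td><td>250</td><td>253</td><td>253</td><td>253</td><td>253</td><td>232</td><td>213</td><td>111</td><td>2</td><td>0</td><td>0</td></tr>\n", 
"<tr><th>10</th><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>43</td><td>98</td><td>98</td><td>208</td><td>253</td><td>253</td><td>253</td><td>253</td><td>187</td><td>22</td><td>0</td></tr>\n", 
"</table>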
" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#hide_output\n", "im3_t = tensor(im3)\n", "df = pd.DataFrame(im3_t[4:15,4:22])\n", "df.style.set_properties(**{'font-size':'6pt'}).background_gradient('Greys')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can see that the background white pixels are stored as the number 0, black is the number 255, and shades of gray are between the two. The entire image contains 28 pixels across and 28 pixels down, for a total of 784 pixels. (This is much smaller than an image that you would get from a phone camera, which has millions of pixels, but is a convenient size for our initial learning and experiments. We will build up to bigger, full-color images soon.)\n", "\n", "So, now you've seen what an image looks like to a computer, let's recall our goal: create a model that can recognize 3s and 7s. How might you go about getting a computer to do that?\n", "\n", "> Warning: Stop and Think!: Before you read on, take a moment to think about how a computer might be able to recognize these two different digits. What kinds of features might it be able to look at? How might it be able to identify these features? How could it combine them together? Learning works best when you try to solve problems yourself, rather than just reading somebody else's answers; so step away from this book for a few minutes, grab a piece of paper and pen, and jot some ideas down…" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## First Try: Pixel Similarity" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So, here is a first idea: how about we find the average pixel value for every pixel of the 3s, then do the same for the 7s. This will give us two group averages, defining what we might call the \"ideal\" 3 and 7. 
Then, to classify an image as one digit or the other, we see which of these two ideal digits the image is most similar to. This certainly seems like it should be better than nothing, so it will make a good baseline." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> jargon: Baseline: A simple model which you are confident should perform reasonably well. It should be very simple to implement, and very easy to test, so that you can then test each of your improved ideas, and make sure they are always better than your baseline. Without starting with a sensible baseline, it is very difficult to know whether your super-fancy models are actually any good. One good approach to creating a baseline is doing what we have done here: think of a simple, easy-to-implement model. Another good approach is to search around to find other people that have solved similar problems to yours, and download and run their code on your dataset. Ideally, try both of these!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Step one for our simple model is to get the average of pixel values for each of our two groups. In the process of doing this, we will learn a lot of neat Python numeric programming tricks!\n", "\n", "Let's create a tensor containing all of our 3s stacked together. We already know how to create a tensor containing a single image. 
To create a tensor containing all the images in a directory, we will first use a Python list comprehension to create a plain list of the single image tensors.\n", "\n", "We will use Jupyter to do some little checks of our work along the way, in this case making sure that the number of returned items seems reasonable:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(6131, 6265)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "seven_tensors = [tensor(Image.open(o)) for o in sevens]\n", "three_tensors = [tensor(Image.open(o)) for o in threes]\n", "len(three_tensors),len(seven_tensors)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> note: List Comprehensions: List and dictionary comprehensions are a wonderful feature of Python. Many Python programmers use them every day, including the authors of this book; they are part of \"idiomatic Python.\" But programmers coming from other languages may have never seen them before. There are a lot of great tutorials just a web search away, so we won't spend a long time discussing them now. Here is a quick explanation and example to get you started. A list comprehension looks like this: `new_list = [f(o) for o in a_list if o>0]`. This will return every element of `a_list` that is greater than 0, after passing it to the function `f`. There are three parts here: the collection you are iterating over (`a_list`), an optional filter (`if o>0`), and something to do to each element (`f(o)`). It's not only shorter to write but also usually faster than building the same list with an explicit loop." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll also check that one of the images looks okay. 
Since we now have tensors (which Jupyter by default will print as values), rather than PIL images (which Jupyter by default will display as images), we need to use fastai's `show_image` function to display it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAADjElEQVR4nO2aPyh9YRjHP/f4k38L5X+ysohsUpTBhEVMJGUyGAwWg0kGkcFqlMFIyv+kSGIwKWUiUvKn5P/9DXrvcR+He+695957+vV8llPnvvd9n77n2/s8z3tOIBgMothYqQ7Ab6ggAhVEoIIIVBBBeoTf/+cUFHC6qQ4RqCACFUSggghUEIEKIlBBBCqIQAURRKpUPeHh4QGAyclJAI6PjwFYXl4GIBgMEgh8FY59fX0A3N7eAlBTUwNAU1MTAC0tLQmNVR0iCEQ4MYupl7m4uABgYmICgJWVFQDOz8/DxhUVFQFQX18fGvMbxcXFAFxeXsYSkhPay7jBkz1ke3sbgLa2NgBeX18BeH9/B6CzsxOAnZ0dAAoLCwFC+4ZlWXx8fISNXVpa8iK0qFGHCDxxyN3dHQBPT09h98vLywGYmpoCoKys7Nc5LMsKu0p6enrijtMN6hCBJ1nm8/MTgOfn57D75mlnZWVFnOPq6gqAxsZGwM5I2dnZAOzu7gJQW1vrJiQ3aJZxgyd7iHFCTk5OzHNUVlYCdmYyzjDVrYfO+BN1iCApvYzk5eUFgM3NTQCGhoZCzsjMzARgenoagIGBgaTGpg4RJMUhpnIdHh4GYH5+HrDrl++0t7cD0NXVlYzQfqAOESSk25WY+iQ/Px8g1LeYqxMlJSUAlJaWAjAyMgLYvY7pg+LAcYKkCCIxRdjJyUno3tjYGAD7+/t//tcIMjc3B0Bubm6sYWhh5oaUOMSJt7c3wHaPScn9/f2O4w8PDwGoq6uLdUl1iBtSUpg5kZGRAUBFRQUAvb29AKyurgKwsLAQNn5tbQ2IyyGOqEMEvnGIxKTV39JrdXV1QtZVhwh8k2Uke3t7ADQ3NwP2sYDh5uYGgIKCgliX0CzjBt/tIWdnZwAMDg4CP51h6pK8vLyErK8OEfhmDzF1RUdHB2AfIhnMEePp6Slg1y1xoHuIG1K6h1xfXwMwOzvL+Pg48PVpxHfMS+6trS3AE2f8iTpE4KlDzBPf2NgA7I9bHh8fATg4OADg6OgIsM807u/vQ3OkpaUB9qvLmZkZIHFZRaIOEXiaZbq7uwFYXFyMOpDW1lYARkdHAWhoaIh6jijRLOMGTx1iPnIxtUQkzEHy+vo6VVVVXwHFf3jsFnWIG3xTqaYAdYgbVBCBCiJQQQQqiCBSL5O0osAvqEMEKohABRGoIAIVRKCCCP4B/PMI7HrW9/wAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "show_image(three_tensors[1]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For every pixel position, we want to compute the average over all the images of the intensity of that pixel. To do this we first combine all the images in this list into a single three-dimensional tensor. The most common way to describe such a tensor is to call it a *rank-3 tensor*. We often need to stack up individual tensors in a collection into a single tensor. Unsurprisingly, PyTorch comes with a function called `stack` that we can use for this purpose.\n", "\n", "Some operations in PyTorch, such as taking a mean, require us to *cast* our integer types to float types. Since we'll be needing this later, we'll also cast our stacked tensor to `float` now. Casting in PyTorch is as simple as typing the name of the type you wish to cast to, and treating it as a method.\n", "\n", "Generally when images are floats, the pixel values are expected to be between 0 and 1, so we will also divide by 255 here:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([6131, 28, 28])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "stacked_sevens = torch.stack(seven_tensors).float()/255\n", "stacked_threes = torch.stack(three_tensors).float()/255\n", "stacked_threes.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Perhaps the most important attribute of a tensor is its *shape*. This tells you the length of each axis. In this case, we can see that we have 6,131 images, each of size 28×28 pixels. There is nothing specifically about this tensor that says that the first axis is the number of images, the second is the height, and the third is the width—the semantics of a tensor are entirely up to us, and how we construct it. 
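
As a rough illustration of this point, the sketch below uses random stand-in data rather than the MNIST files (the names `fake_images`, `stacked`, and `moved` are made up for this example). It stacks a few tensors and then rearranges the axes, showing that the same numbers can be laid out with the image axis first or last:

```python
import torch

# Random stand-in data: 5 fake grayscale images of 28x28 pixels (not real MNIST).
fake_images = [torch.rand(28, 28) for _ in range(5)]

# Stack the list into one rank-3 tensor; the new axis 0 indexes the images.
stacked = torch.stack(fake_images)
print(stacked.shape)   # torch.Size([5, 28, 28])

# Moving the image axis to the end keeps exactly the same numbers,
# but now the last axis indexes the images.
moved = stacked.permute(1, 2, 0)
print(moved.shape)     # torch.Size([28, 28, 5])

# The first image is still recoverable; we just index a different axis.
same_image = torch.equal(stacked[0], moved[:, :, 0])
print(same_image)      # True
```

Either layout holds identical values; which axis means what is a convention we choose and have to keep track of ourselves.
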
As far as PyTorch is concerned, it is just a bunch of numbers in memory.\n", "\n", "The *length* of a tensor's shape is its rank:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(stacked_threes.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is really important for you to commit to memory and practice these bits of tensor jargon: _rank_ is the number of axes or dimensions in a tensor; _shape_ is the size of each axis of a tensor.\n", "\n", "> A: Watch out because the term \"dimension\" is sometimes used in two ways. Consider that we live in \"three-dimensional space\" where a physical position can be described by a 3-vector `v`. But according to PyTorch, the attribute `v.ndim` (which sure looks like the \"number of dimensions\" of `v`) equals one, not three! Why? Because `v` is a vector, which is a tensor of rank one, meaning that it has only one _axis_ (even if that axis has a length of three). In other words, sometimes dimension is used for the size of an axis (\"space is three-dimensional\"); other times, it is used for the rank, or the number of axes (\"a matrix has two dimensions\"). When confused, I find it helpful to translate all statements into terms of rank, axis, and length, which are unambiguous terms." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also get a tensor's rank directly with `ndim`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "stacked_threes.ndim" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we can compute what the ideal 3 looks like. We calculate the mean of all the image tensors by taking the mean along dimension 0 of our stacked, rank-3 tensor. 
This is the dimension that indexes over all the images.\n", "\n", "In other words, for every pixel position, this will compute the average of that pixel over all images. The result will be one value for every pixel position, or a single image. Here it is:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAE1klEQVR4nO2byU8jPRTEf2En7AgQO4gDi9hO8P9fOIE4AGIR+xoU1kAgQAKZA6o4eUMU6O7R983IdbE66XYHv3K98rOJ5fN5PByq/usf8H+DHxADPyAGfkAM/IAY1FT4/l9OQbGvPvQMMfADYuAHxMAPiIEfEINKWSYSaL1kW/t9MWKx2JfX5T6PCp4hBpEwxEb+/f0dgFwuB0A2mwXg6emppH1+fgbg9fWVj48PAGprawGIx+MANDc3A9DU1ARAfX19yX3V1dVAeQb9FJ4hBqEYIkZYJmQyGQBub28BuLi4AGBnZweAvb09AM7PzwFIJpOFPsSEvr4+AMbHxwGYnZ0FYHR0FICuri7AMUiMqar6jHFQpniGGARiSDmtkDZcXV0BsL+/D8Da2hoAu7u7ABwcHACOOalUipeXl5J3tLW1AXB6elrS58LCAgBTU1MA9Pf3A44pYkhQeIYYhGKIMoMYoigXZw+AmprP13R2dgJuvo+NjQGf2qNnbm5uSvp4e3sD4O7uDnC6I41Rn8pK+m1eQyJCJFlG0VDkW1tbARgaGgKgo6MDcFEX5CFyuVwhIx0dHQFwcnICOF3SO8RGsbOc+w0KzxCDUAyRoivSmse6lvIrGymqgqKayWRIJBIAXF9fl/Ste6RD8il6V11dXcn93qlGjEAMURQUFUVP19ISO8/FFGWfx8dH4NNjbGxsALC1tQU4DWloaAAc2wYGBgCXXRobGwHHyrDwDDGIREMsY8QMtVrjKMtcXl4CsLm5CcDq6irr6+uAc7fqc35+HoDe3l7AOVNlsqjWMIW/KdTT/yBCaUglVyjNSKVSgIv+0tISACsrKwAsLy8XHKhYJScqBrS0tABOU6JmhuAZYhBKQyxTBF0rm2ilKp3Q6lcMSSQSBWbIX6gyJv2RPxHburu7S+6zlbOgiLTIbBd9tjygHytBnJ6eBmBkZKTQh50KelZlABWZ2tvbATeFoiol+iljEMnizk4Zu9iTiVIZUNcyZvl8vsAmlR+TySRAwdI/PDwAcHh4CMDw8DDgmGItfFCj5hliEKpAZDWjXDnAFnGkGcXPSyskmipEizlKy/peDBJTrFELWijyDDH4EUMsI2xrmaPoKDVqnlsUM0T3SDO03NcC0pYrldpticFrSEQIpCG2uKxWURLEEEVLGcBuFcRisd88ixiirKN32ixiWesXdxEjlIbYTWw7nxUt6YLdqC4uHIsRcqTb29uA8yF6l5yptEV9e6f6hxDKh2i+q/CjzSQ5UGUCRdEu3IR0Os3x8THgmCEfoj61ua1FXU9PD+BKi9apBoVniMGPGGLnp90qSKfTgCsEyV1KY3SfXcmmUqlCiUDPaAtzcHAQcI50cnIScAUkOVT5FJ9lIkYgDZGiKyrSBm0JCPf394DTA2UQfS6NyWazBdboGMTMzAwAc3NzACwuLgIwMTEBOA0
pVw8JCs8Qg0AaomhK2VUA1haBXcvY+8/OzgDnQuPxeEEjdERCtZNymmFLh2Gzi+AZYhCrcIzgyy8rHcOUY1V2kQtVK99SfBRTkbetdMkew4xg+8H/e8h3EIghZW8uc4S70tFu+N3jVGojgGfIdxApQ/4yeIZ8B35ADPyAGFRyqtH+d85fAM8QAz8gBn5ADPyAGPgBMfADYvALMumtb+Vr5kIAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "mean3 = stacked_threes.mean(0)\n", "show_image(mean3);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "According to this dataset, this is the ideal number 3! (You may not like it, but this is what peak number 3 performance looks like.) You can see how it's very dark where all the images agree it should be dark, but it becomes wispy and blurry where the images disagree. \n", "\n", "Let's do the same thing for the 7s, but put all the steps together at once to save some time:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAElUlEQVR4nO2bSUszWRSGn3JKUGOM84giCA7oQnHj33ejiKCI4MIpxhES5ylx6oW8dTvnMyaWJd1033dzqVRqyLlPnelWgvf3d7yc6v7pG/i3yRvEyBvEyBvEyBvEqKHK/v9yCAo++9ATYuQNYuQNYuQNYuQNYuQNYuQNYuQNYuQNYlQtU42kaj2Wz/YHwaeJY+TvRZUnxOhHhGimNb69vZWNr6+vX25r/Ep1dR9zVl9f/3HDDQ1l29qvUQRFJckTYvQtQiwRmvGXlxcAnp6eALi9vQXg+voagMvLSwAuLi7KxoeHh/A4nUOjrtHU1ARAR0cHAENDQwAMDw8D0NvbC0BraysAjY2NgCPou6R4QowiEaJZfH5+BhwR+XwegIODAwC2t7cB2NvbA+Dw8BCAo6MjAK6urgAoFovhuUSdRhEiMubn5wFYXFwEYGFhoWx/JZ9SqzwhRpEIUXTQrN7f3wNQKBQA2N/fB2B3dxdwpIicm5ubsvMlEgkSiQTgyBB1Oufj4yPgfMXo6GjZtXWc5KNMTKqJkGqZp2ZDnr25uRmA9vZ2AEZGRgDo7u4GnF/Q/nQ6HRKiGRdNq6urgPM3IsFe0+YlUeUJMaqJEJv9aRaUNYqIzs5OAMbGxsr29/f3A44Mbff19QEffkEzrBxlaWkJgLOzM8DlOJlMpmxsa2sDXP4RNbpInhCjb0WZaoTYGkVEaDuVSgGOJM1uQ0NDmNtYKdroXD09PYDzS+l0GnCE/LQa9oQYRap2K1WglhRFDn2/paUF+LPuCIIgPEY5irLb8/NzwNUyqmEGBgbKrql7+akiPTKSNYx+oH64DCKDCXttS6+vr2G43dzcBGB9fR1wj8zU1BTgHhWFbHsuSamCT91/qB81iKyTFSkiwTo6jXo8NIulUolsNgvA8vIyADs7O4CjTKFcox6VuFuKnhCjSIRU8iX2ubUNJdtYUnGYz+dZWVkBYG1tDXCJ2OzsLACTk5OAC7vJZLLs2nGR4gkximUZwhZalZrOkvarpM/lcmxsbAAuzKoQVENoZmYGcOFXfsq2Cn1iFrNiiTJ22/oSjXYZQknY1tYWuVwOcDM/NzcHOB8yODgIuKRO+YdtGVa
6t1rlCTGK1YdUyg7tfvmO09NTALLZbLgkoTxjenoagImJCcBlprbMj4sMyRNiFOtityVBks8olUqAawIpGy0UCiEBKt7Gx8cB6OrqAqrnHT4P+SXFSkiljFRkaGnz5OQEcO3BIAjCdqKWF7TwpMrZRpXfei3CE2IUCyGVXouwC1la6jw+PgZcvZJKpcJ2oho/2lZe8tPlhVrlCTH6FR+ihrF8h7peWmy6u7sDXE6RyWTCaKJqVv0O+Y64apVq8oQY/corVYou8hHKTLUtMuQnEolE+OLL3z+D6C++RJUnxCiWarfaYriI0EKVljIVhZLJZLjgpIxVmWmlrvpvyRNiFFSZ3Zr+YmZ9iH3lStFGPqRYLJYdFwRBSJHIkA+xL9HF2CHzfzGrRbEQ8sdBFc5po9JX37cE/EIe4gmpRdUI+d/JE2LkDWLkDWLkDWLkDWLkDWL0F7hnDWZImx+vAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "mean7 = stacked_sevens.mean(0)\n", "show_image(mean7);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's now pick an arbitrary 3 and measure its *distance* from our \"ideal digits.\"\n", "\n", "> stop: Stop and Think!: How would you calculate how similar a particular image is to each of our ideal digits? Remember to step away from this book and jot down some ideas before you move on! Research shows that recall and understanding improves dramatically when you are engaged with the learning process by solving problems, experimenting, and trying new ideas yourself\n", "\n", "Here's a sample 3:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAADjElEQVR4nO2aPyh9YRjHP/f4k38L5X+ysohsUpTBhEVMJGUyGAwWg0kGkcFqlMFIyv+kSGIwKWUiUvKn5P/9DXrvcR+He+695957+vV8llPnvvd9n77n2/s8z3tOIBgMothYqQ7Ab6ggAhVEoIIIVBBBeoTf/+cUFHC6qQ4RqCACFUSggghUEIEKIlBBBCqIQAURRKpUPeHh4QGAyclJAI6PjwFYXl4GIBgMEgh8FY59fX0A3N7eAlBTUwNAU1MTAC0tLQmNVR0iCEQ4MYupl7m4uABgYmICgJWVFQDOz8/DxhUVFQFQX18fGvMbxcXFAFxeXsYSkhPay7jBkz1ke3sbgLa2NgBeX18BeH9/B6CzsxOAnZ0dAAoLCwFC+4ZlWXx8fISNXVpa8iK0qFGHCDxxyN3dHQBPT09h98vLywGYmpoCoKys7Nc5LMsKu0p6enrijtMN6hCBJ1nm8/MTgOfn57D75mlnZWVFnOPq6gqAxsZGwM5I2dnZAOzu7gJQW1vrJiQ3aJZxgyd7iHFCTk5OzHNUVlYCdmYyzjDVrYfO+BN1iCApvYzk5eUFgM3NTQCGhoZCzsjMzARgenoagIGBgaTGpg4RJMUhpnIdHh4GYH5+HrDrl++0t7cD0NXVlYzQfqAOESSk25WY+iQ/Px8g1LeYqxMlJSUAlJaWAjAyMgLYvY7pg+LAcYKkCCIxRdjJyUno3tjYGAD7+/t//tcIMjc3B0Bubm6sYWhh5oaUOMSJt7c3wHaPScn9/f2O4w8PDwGoq6uLdUl1iBtSUpg5kZGRAUBFRQUAvb29AKyurgKwsLAQNn5tbQ2IyyGOqEMEvnGIxKTV39JrdXV1QtZVhwh8k2Uke3t7ADQ3NwP2sYDh5uYGgIKCgliX0CzjBt/tIWdnZwAMDg4CP51h6pK8vLyErK8OEfhmDzF1RUdHB2AfIhnMEePp6Slg1y1xoHuIG1K6h1xfXwMwOzvL+Pg48PVp
xHfMS+6trS3AE2f8iTpE4KlDzBPf2NgA7I9bHh8fATg4OADg6OgIsM807u/vQ3OkpaUB9qvLmZkZIHFZRaIOEXiaZbq7uwFYXFyMOpDW1lYARkdHAWhoaIh6jijRLOMGTx1iPnIxtUQkzEHy+vo6VVVVXwHFf3jsFnWIG3xTqaYAdYgbVBCBCiJQQQQqiCBSL5O0osAvqEMEKohABRGoIAIVRKCCCP4B/PMI7HrW9/wAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "a_3 = stacked_threes[1]\n", "show_image(a_3);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "How can we determine its distance from our ideal 3? We can't just add up the differences between the pixels of this image and the ideal digit. Some differences will be positive while others will be negative, and these differences will cancel out, resulting in a situation where an image that is too dark in some places and too light in others might be shown as having zero total differences from the ideal. That would be misleading!\n", "\n", "To avoid this, there are two main ways data scientists measure distance in this context:\n", "\n", "- Take the mean of the *absolute value* of differences (absolute value is the function that replaces negative values with positive values). This is called the *mean absolute difference* or *L1 norm*.\n", "- Take the mean of the *square* of differences (which makes everything positive) and then take the *square root* (which undoes the squaring). This is called the *root mean squared error* (RMSE) or *L2 norm*.\n", "\n", "> important: It's Okay to Have Forgotten Your Math: In this book we generally assume that you have completed high school math, and remember at least some of it... But everybody forgets some things! It all depends on what you happen to have had reason to practice in the meantime. Perhaps you have forgotten what a _square root_ is, or exactly how it works. No problem! Any time you come across a math concept that is not explained fully in this book, don't just keep moving on; instead, stop and look it up. Make sure you understand the basic idea, how it works, and why we might be using it. One of the best places to refresh your understanding is Khan Academy. 
For instance, Khan Academy has a great [introduction to square roots](https://www.khanacademy.org/math/algebra/x2f8bb11595b61c86:rational-exponents-radicals/x2f8bb11595b61c86:radicals/v/understanding-square-roots)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's try both of these now:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.1114), tensor(0.2021))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dist_3_abs = (a_3 - mean3).abs().mean()\n", "dist_3_sqr = ((a_3 - mean3)**2).mean().sqrt()\n", "dist_3_abs,dist_3_sqr" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.1586), tensor(0.3021))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dist_7_abs = (a_3 - mean7).abs().mean()\n", "dist_7_sqr = ((a_3 - mean7)**2).mean().sqrt()\n", "dist_7_abs,dist_7_sqr" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In both cases, the distance between our 3 and the \"ideal\" 3 is less than the distance to the ideal 7. So our simple model will give the right prediction in this case." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch already provides both of these as *loss functions*. 
You'll find these inside `torch.nn.functional`, which the PyTorch team recommends importing as `F` (and is available by default under that name in fastai):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.1586), tensor(0.3021))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "F.l1_loss(a_3.float(),mean7), F.mse_loss(a_3,mean7).sqrt()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here `mse` stands for *mean squared error*, and `l1` refers to the standard mathematical jargon for *mean absolute value* (in math it's called the *L1 norm*)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> S: Intuitively, the difference between L1 norm and mean squared error (MSE) is that the latter will penalize bigger mistakes more heavily than the former (and be more lenient with small mistakes)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> J: When I first came across this \"L1\" thingie, I looked it up to see what on earth it meant. I found on Google that it is a _vector norm_ using _absolute value_, so looked up _vector norm_ and started reading: _Given a vector space V over a field F of the real or complex numbers, a norm on V is a nonnegative-valued any function p: V → \\[0,+∞) with the following properties: For all a ∈ F and all u, v ∈ V, p(u + v) ≤ p(u) + p(v)..._ Then I stopped reading. \"Ugh, I'll never understand math!\" I thought, for the thousandth time. Since then I've learned that every time these complex mathy bits of jargon come up in practice, it turns out I can replace them with a tiny bit of code! Like, the _L1 loss_ is just equal to `(a-b).abs().mean()`, where `a` and `b` are tensors. I guess mathy folks just think differently than me... 
I'll make sure in this book that every time some mathy jargon comes up, I'll give you the little bit of code it's equal to as well, and explain in common-sense terms what's going on." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We just completed various mathematical operations on PyTorch tensors. If you've done some numeric programming in NumPy before, you may recognize these as being similar to NumPy arrays. Let's have a look at those two very important data structures." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### NumPy Arrays and PyTorch Tensors" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[NumPy](https://numpy.org/) is the most widely used library for scientific and numeric programming in Python. It provides very similar functionality and a very similar API to that provided by PyTorch; however, it does not support using the GPU or calculating gradients, which are both critical for deep learning. Therefore, in this book we will generally use PyTorch tensors instead of NumPy arrays, where possible.\n", "\n", "(Note that fastai adds some features to NumPy and PyTorch to make them a bit more similar to each other. If any code in this book doesn't work on your computer, it's possible that you forgot to include a line like this at the start of your notebook: `from fastai.vision.all import *`.)\n", "\n", "But what are arrays and tensors, and why should you care?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Python is slow compared to many languages. Anything fast in Python, NumPy, or PyTorch is likely to be a wrapper for a compiled object written (and optimized) in another language—specifically C. In fact, **NumPy arrays and PyTorch tensors can finish computations many thousands of times faster than using pure Python.**\n", "\n", "A NumPy array is a multidimensional table of data, with all items of the same type. 
Since that can be any type at all, they can even be arrays of arrays, with the innermost arrays potentially being different sizes—this is called a \"jagged array.\" By \"multidimensional table\" we mean, for instance, a list (dimension of one), a table or matrix (dimension of two), a \"table of tables\" or \"cube\" (dimension of three), and so forth. If the items are all of some simple type such as integer or float, then NumPy will store them as a compact C data structure in memory. This is where NumPy shines. NumPy has a wide variety of operators and methods that can run computations on these compact structures at the same speed as optimized C, because they are written in optimized C.\n", "\n", "A PyTorch tensor is nearly the same thing as a NumPy array, but with an additional restriction that unlocks some additional capabilities. It's the same in that it, too, is a multidimensional table of data, with all items of the same type. However, the restriction is that a tensor cannot use just any old type—it has to use a single basic numeric type for all components. For example, a PyTorch tensor cannot be jagged. It is always a regularly shaped multidimensional rectangular structure.\n", "\n", "The vast majority of methods and operators supported by NumPy on these structures are also supported by PyTorch, but PyTorch tensors have additional capabilities. One major capability is that these structures can live on the GPU, in which case their computation will be optimized for the GPU and can run much faster (given lots of values to work on). In addition, PyTorch can automatically calculate derivatives of these operations, including combinations of operations. As you'll see, it would be impossible to do deep learning in practice without this capability.\n", "\n", "> S: If you don't know what C is, don't worry as you won't need it at all. 
In a nutshell, it's a low-level language (low-level means closer to the language that computers use internally) that is very fast compared to Python. To take advantage of its speed while programming in Python, avoid writing loops as much as possible, and replace them with commands that work directly on arrays or tensors.\n", "\n", "Perhaps the most important new coding skill for a Python programmer to learn is how to effectively use the array/tensor APIs. We will be showing lots more tricks later in this book, but here's a summary of the key things you need to know for now." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To create an array or tensor, pass a list (or list of lists, or list of lists of lists, etc.) to `array()` or `tensor()`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = [[1,2,3],[4,5,6]]\n", "arr = array (data)\n", "tns = tensor(data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1, 2, 3],\n", " [4, 5, 6]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "arr # numpy" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1, 2, 3],\n", " [4, 5, 6]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tns # pytorch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "All the operations that follow are shown on tensors, but the syntax and results for NumPy arrays are identical.\n", "\n", "You can select a row (note that, like lists in Python, tensors are 0-indexed so 1 refers to the second row/column):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([4, 5, 6])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tns[1]"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "or a column, by using `:` to indicate *all of the first axis* (we sometimes refer to the dimensions of tensors/arrays as *axes*):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([2, 5])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tns[:,1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can combine these with Python slice syntax (`[start:end]` with `end` being excluded) to select part of a row or column:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([5, 6])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tns[1,1:3]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And you can use the standard operators such as `+`, `-`, `*`, `/`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[2, 3, 4],\n", " [5, 6, 7]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tns+1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Tensors have a type:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'torch.LongTensor'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tns.type()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And will automatically change type as needed, for example from `int` to `float`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1.5000, 3.0000, 4.5000],\n", " [6.0000, 7.5000, 9.0000]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tns*1.5" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "So, is our baseline model any good? To quantify this, we must define a metric." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Computing Metrics Using Broadcasting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Recall that a metric is a number that is calculated based on the predictions of our model, and the correct labels in our dataset, in order to tell us how good our model is. For instance, we could use either of the functions we saw in the previous section, mean squared error, or mean absolute error, and take the average of them over the whole dataset. However, neither of these are numbers that are very understandable to most people; in practice, we normally use *accuracy* as the metric for classification models.\n", "\n", "As we've discussed, we want to calculate our metric over a *validation set*. This is so that we don't inadvertently overfit—that is, train a model to work well only on our training data. This is not really a risk with the pixel similarity model we're using here as a first try, since it has no trained components, but we'll use a validation set anyway to follow normal practices and to be ready for our second try later.\n", "\n", "To get a validation set we need to remove some of the data from training entirely, so it is not seen by the model at all. As it turns out, the creators of the MNIST dataset have already done this for us. Do you remember how there was a whole separate directory called *valid*? That's what this directory is for!\n", "\n", "So to start with, let's create tensors for our 3s and 7s from that directory. 
These are the tensors we will use to calculate a metric measuring the quality of our first-try model, which measures distance from an ideal image:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1010, 28, 28]), torch.Size([1028, 28, 28]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "valid_3_tens = torch.stack([tensor(Image.open(o)) \n", " for o in (path/'valid'/'3').ls()])\n", "valid_3_tens = valid_3_tens.float()/255\n", "valid_7_tens = torch.stack([tensor(Image.open(o)) \n", " for o in (path/'valid'/'7').ls()])\n", "valid_7_tens = valid_7_tens.float()/255\n", "valid_3_tens.shape,valid_7_tens.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's good to get in the habit of checking shapes as you go. Here we see two tensors, one representing the 3s validation set of 1,010 images of size 28×28, and one representing the 7s validation set of 1,028 images of size 28×28.\n", "\n", "We ultimately want to write a function, `is_3`, that will decide if an arbitrary image is a 3 or a 7. It will do this by deciding which of our two \"ideal digits\" this arbitrary image is closer to. 
For that we need to define a notion of distance—that is, a function that calculates the distance between two images.\n", "\n", "We can write a simple function that calculates the mean absolute error using an expression very similar to the one we wrote in the last section:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.1114)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def mnist_distance(a,b): return (a-b).abs().mean((-1,-2))\n", "mnist_distance(a_3, mean3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is the same value we previously calculated for the distance between these two images, the ideal 3 `mean3` and the arbitrary sample 3 `a_3`, which are both single-image tensors with a shape of `[28,28]`.\n", "\n", "But in order to calculate a metric for overall accuracy, we will need to calculate the distance to the ideal 3 for _every_ image in the validation set. How do we do that calculation? We could write a loop over all of the single-image tensors that are stacked within our validation set tensor, `valid_3_tens`, which has a shape of `[1010,28,28]` representing 1,010 images. 
But there is a better way.\n", "\n", "Something very interesting happens when we take this exact same distance function, designed for comparing two single images, but pass in as an argument `valid_3_tens`, the tensor that represents the 3s validation set:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([0.1050, 0.1526, 0.1186, ..., 0.1122, 0.1170, 0.1086]),\n", " torch.Size([1010]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "valid_3_dist = mnist_distance(valid_3_tens, mean3)\n", "valid_3_dist, valid_3_dist.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Instead of complaining about shapes not matching, it returned the distance for every single image as a vector (i.e., a rank-1 tensor) of length 1,010 (the number of 3s in our validation set). How did that happen?\n", "\n", "Take another look at our function `mnist_distance`, and you'll see we have there the subtraction `(a-b)`. The magic trick is that PyTorch, when it tries to perform a simple subtraction operation between two tensors of different ranks, will use *broadcasting*. That is, it will automatically expand the tensor with the smaller rank to have the same size as the one with the larger rank. Broadcasting is an important capability that makes tensor code much easier to write.\n", "\n", "After broadcasting so the two argument tensors have the same rank, PyTorch applies its usual logic for two tensors of the same rank: it performs the operation on each corresponding element of the two tensors, and returns the tensor result. 
For instance:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([2, 3, 4])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tensor([1,2,3]) + tensor(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So in this case, PyTorch treats `mean3`, a rank-2 tensor representing a single image, as if it were 1,010 copies of the same image, and then subtracts each of those copies from each 3 in our validation set. What shape would you expect this tensor to have? Try to figure it out yourself before you look at the answer below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1010, 28, 28])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(valid_3_tens-mean3).shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are calculating the difference between our \"ideal 3\" and each of the 1,010 3s in the validation set, at each of the 28×28 pixel positions, resulting in the shape `[1010,28,28]`.\n", "\n", "There are a couple of important points about how broadcasting is implemented, which make it valuable not just for expressivity but also for performance:\n", "\n", "- PyTorch doesn't *actually* copy `mean3` 1,010 times. It *pretends* that `mean3` is a tensor of that shape, but doesn't actually allocate any additional memory.\n", "- It does the whole calculation in C (or, if you're using a GPU, in CUDA, the equivalent of C on the GPU), tens of thousands of times faster than pure Python (up to millions of times faster on a GPU!).\n", "\n", "This is true of all broadcasting and elementwise operations and functions done in PyTorch. *It's the most important technique for you to know to create efficient PyTorch code.*\n", "\n", "Next in `mnist_distance` we see `abs`. You might be able to guess now what this does when applied to a tensor.
It applies the method to each individual element in the tensor, and returns a tensor of the results (that is, it applies the method \"elementwise\"). So in this case, we'll get back 1,010 matrices of absolute values.\n", "\n", "Finally, our function calls `mean((-1,-2))`. The tuple `(-1,-2)` represents a range of axes. In Python, `-1` refers to the last element, and `-2` refers to the second-to-last. So in this case, this tells PyTorch that we want to take the mean ranging over the values indexed by the last two axes of the tensor. The last two axes are the horizontal and vertical dimensions of an image. After taking the mean over the last two axes, we are left with just the first tensor axis, which indexes over our images, which is why our final size was `(1010)`. In other words, for every image, we averaged the intensity of all the pixels in that image.\n", "\n", "We'll be learning lots more about broadcasting throughout this book, especially in <>, and will be practicing it regularly too.\n", "\n", "We can use `mnist_distance` to figure out whether an image is a 3 or not by using the following logic: if the distance between the digit in question and the ideal 3 is less than the distance to the ideal 7, then it's a 3. 
This function will automatically do broadcasting and be applied elementwise, just like all PyTorch functions and operators:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def is_3(x): return mnist_distance(x,mean3) < mnist_distance(x,mean7)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's test it on our example case:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(True), tensor(1.))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "is_3(a_3), is_3(a_3).float()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that when we convert the Boolean response to a float, we get `1.0` for `True` and `0.0` for `False`. Thanks to broadcasting, we can also test it on the full validation set of 3s:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([True, True, True, ..., True, True, True])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "is_3(valid_3_tens)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can calculate the accuracy for each of the 3s and 7s by taking the average of that function for all 3s and its inverse for all 7s:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.9168), tensor(0.9854), tensor(0.9511))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy_3s = is_3(valid_3_tens).float() .mean()\n", "accuracy_7s = (1 - is_3(valid_7_tens).float()).mean()\n", "\n", "accuracy_3s,accuracy_7s,(accuracy_3s+accuracy_7s)/2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This looks like a pretty good start! 
We're getting over 90% accuracy on both 3s and 7s, and we've seen how to define a metric conveniently using broadcasting.\n", "\n", "But let's be honest: 3s and 7s are very different-looking digits. And we're only classifying 2 out of the 10 possible digits so far. So we're going to need to do better!\n", "\n", "To do better, perhaps it is time to try a system that does some real learning—that is, that can automatically modify itself to improve its performance. In other words, it's time to talk about the training process, and SGD." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Stochastic Gradient Descent (SGD)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Do you remember the way that Arthur Samuel described machine learning, which we quoted in <>?\n", "\n", "> : Suppose we arrange for some automatic means of testing the effectiveness of any current weight assignment in terms of actual performance and provide a mechanism for altering the weight assignment so as to maximize the performance. We need not go into the details of such a procedure to see that it could be made entirely automatic and to see that a machine so programmed would \"learn\" from its experience.\n", "\n", "As we discussed, this is the key to allowing us to have a model that can get better and better—that can learn. But our pixel similarity approach does not really do this. We do not have any kind of weight assignment, or any way of improving based on testing the effectiveness of a weight assignment. In other words, we can't really improve our pixel similarity approach by modifying a set of parameters. 
In order to take advantage of the power of deep learning, we will first have to represent our task in the way that Arthur Samuel described it.\n", "\n", "Instead of trying to find the similarity between an image and an \"ideal image,\" we could instead look at each individual pixel and come up with a set of weights for each one, such that the highest weights are associated with those pixels most likely to be black for a particular category. For instance, pixels toward the bottom right are not very likely to be activated for a 7, so they should have a low weight for a 7, but they are likely to be activated for an 8, so they should have a high weight for an 8. This can be represented as a function and set of weight values for each possible category—for instance the probability of being the number 8:\n", "\n", "```\n", "def pr_eight(x,w): return (x*w).sum()\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we are assuming that `x` is the image, represented as a vector—in other words, with all of the rows stacked up end to end into a single long line. And we are assuming that the weights are a vector `w`. If we have this function, then we just need some way to update the weights to make them a little bit better. With such an approach, we can repeat that step a number of times, making the weights better and better, until they are as good as we can make them.\n", "\n", "We want to find the specific values for the vector `w` that cause the result of our function to be high for those images that are actually 8s, and low for those images that are not. Searching for the best vector `w` is a way to search for the best function for recognizing 8s. (Because we are not yet using a deep neural network, we are limited by what our function can actually do—we are going to fix that constraint later in this chapter.)
\n", "\n", "To be more specific, here are the steps that we are going to require, to turn this function into a machine learning classifier:\n", "\n", "1. *Initialize* the weights.\n", "1. For each image, use these weights to *predict* whether it appears to be a 3 or a 7.\n", "1. Based on these predictions, calculate how good the model is (its *loss*).\n", "1. Calculate the *gradient*, which measures, for each weight, how changing that weight would change the loss.\n", "1. *Step* (that is, change) all the weights based on that calculation.\n", "1. Go back to step 2, and *repeat* the process.\n", "1. Iterate until you decide to *stop* the training process (for instance, because the model is good enough or you don't want to wait any longer)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These seven steps, illustrated in <>, are the key to the training of all deep learning models. That deep learning turns out to rely entirely on these steps is extremely surprising and counterintuitive. It's amazing that this process can solve such complex problems. But, as you'll see, it really does!"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "G\n", "\n", "\n", "\n", "init\n", "\n", "init\n", "\n", "\n", "\n", "predict\n", "\n", "predict\n", "\n", "\n", "\n", "init->predict\n", "\n", "\n", "\n", "\n", "\n", "loss\n", "\n", "loss\n", "\n", "\n", "\n", "predict->loss\n", "\n", "\n", "\n", "\n", "\n", "gradient\n", "\n", "gradient\n", "\n", "\n", "\n", "loss->gradient\n", "\n", "\n", "\n", "\n", "\n", "step\n", "\n", "step\n", "\n", "\n", "\n", "gradient->step\n", "\n", "\n", "\n", "\n", "\n", "step->predict\n", "\n", "\n", "repeat\n", "\n", "\n", "\n", "stop\n", "\n", "stop\n", "\n", "\n", "\n", "step->stop\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#id gradient_descent\n", "#caption The gradient descent process\n", "#alt Graph showing the steps for Gradient Descent\n", "gv('''\n", "init->predict->loss->gradient->step->stop\n", "step->predict[label=repeat]\n", "''')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are many different ways to do each of these seven steps, and we will be learning about them throughout the rest of this book. These are the details that make a big difference for deep learning practitioners, but it turns out that the general approach to each one generally follows some basic principles. Here are a few guidelines:\n", "\n", "- Initialize:: We initialize the parameters to random values. This may sound surprising. 
There are certainly other choices we could make, such as initializing them to the percentage of times that pixel is activated for that category—but since we already know that we have a routine to improve these weights, it turns out that just starting with random weights works perfectly well.\n", "- Loss:: This is what Samuel referred to when he spoke of *testing the effectiveness of any current weight assignment in terms of actual performance*. We need some function that will return a number that is small if the performance of the model is good (the standard approach is to treat a small loss as good, and a large loss as bad, although this is just a convention).\n", "- Step:: A simple way to figure out whether a weight should be increased a bit, or decreased a bit, would be just to try it: increase the weight by a small amount, and see if the loss goes up or down. Once you find the correct direction, you could then change the weight by a bit more, or a bit less, until you find an amount that works well. However, this is slow! As we will see, the magic of calculus allows us to directly figure out in which direction, and by roughly how much, to change each weight, without having to try all these small changes. The way to do this is by calculating *gradients*. This is just a performance optimization; we would get exactly the same results by using the slower manual process as well.\n", "- Stop:: Once we've decided how many epochs to train the model for (a few suggestions for this were given in the earlier list), we apply that decision. For our digit classifier, we would keep training until the accuracy of the model started getting worse, or we ran out of time." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before applying these steps to our image classification problem, let's illustrate what they look like in a simpler case.
First we will define a very simple function, the quadratic—let's pretend that this is our loss function, and `x` is a weight parameter of the function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def f(x): return x**2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is a graph of that function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEMCAYAAADeYiHoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3yV5f3/8dcnm0xWElYGYQ9ZBmQ5aVW0ai1UQdwojmptrbVLq1XbX7Xj2yUqiiI46ihK3a1VVBCRsIQgG0ICJCSMkD0/vz/OoY3xBE5Czn2fJJ/n43E/esaVc7+9m5wP133d93WJqmKMMcY0FuJ2AGOMMcHJCoQxxhifrEAYY4zxyQqEMcYYn6xAGGOM8SnM7QCtpXv37pqenu52DGOMaVNWr15dpKqJvt5rNwUiPT2drKwst2MYY0ybIiI5Tb1np5iMMcb4ZAXCGGOMT1YgjDHG+GQFwhhjjE9WIIwxxvjkeIEQkQEiUikizzXxvojIwyJy0Ls9IiLidE5jjOno3LjM9VFg1XHenwN8GxgJKPBvYCfweOCjGWOMOcbRHoSIzACOAP85TrNrgD+oap6q7gX+AFwbqEzrco/w8LubA/XxxhgTMKrKr9/aRPa+4oB8vmMFQkTigQeAH52g6TBgfYPn672v+frMOSKSJSJZhYWFLcq1Ie8Ijy3dwca9gTnAxhgTKJ/tPMSTn+xiS35JQD7fyR7Eg8B8Vc09QbtYoOG3dTEQ62scQlXnqWqmqmYmJvq8U/yELh7Vm8iwEP6+ak+Lft4YY9zy0qo9xEWFMXV4z4B8viMFQkRGAd8A/s+P5qVAfIPn8UCpBmjpu4RO4VxwSk+WrNtHRXVdIHZhjDGtrri8hnc25nPJqF50iggNyD6c6kGcBaQDe0QkH7gLmCYia3y0zcYzQH3MSO9rAXNZZgollbW8s3F/IHdjjDGtZsn6vVTV1jNjbGrA9uFUgZgH9ANGebfHgbeA83y0XQjcKSK9RaQXnjGLBYEMNz6jK+ndonlp1YnOfhljTHB4aVUuQ3vGM7x3QsD24UiBUNVyVc0/tuE5jVSpqoUicrqIlDZo/gTwBrAB2IinkDwRyHwiwnczU1i56xA7C0tP/APGGOOijXuLyd53lBnjUgK6H1fupFbV+1X1Su/jT1Q1tsF7qqp3q2pX73Z3oMYfGpp+ah9CQ4SXs/ICvStjjDkpf1+1h8iwEC4Z2Tug+7GpNryS46M4e1Air67Oo6au3u04xhjjU0V1HUvW7WPq8B4kRIcHdF9WIBq4fGwqRaVVfLD5gNtRjDHGp7c37KekspbLxgb29BJYgfiKswclkhwfyYuf2z0Rxpjg9OLne+jbPYYJGd0Cvi8rEA2EhYZwWWYKH20tZO+RCrfjGGPMV2wrKCEr5zAzxqbgxBymViAauSzT022zS16NMcHmxc9zCQ8Vpp3ax5H9WYFoJKVrNGcMSOSVrFxqbbDaGBMkKmvqWLw2j3OH9aB7bKQj+7QC4cP
McSnsL67ko60tmwDQGGNa23vZ+Rwpr2FmAO+cbswKhA9ThiTTPTaSFz+300zGmODw4ud7SO0azcR+gR+cPsYKhA/hoSFcltmHDzYXkF9c6XYcY0wHt7OwlM92HmLGuBRCQpxbYNMKRBNmjE2lXm2w2hjjvhc/30NYiDDdocHpY6xANCG1WzRnDEzk76v22GC1McY1lTV1vLI6j/OG9SApLsrRfVuBOI5Zp6Wyv7iSD7fYYLUxxh3vbNzPkfIarjjNucHpY6xAHMeUwUkkx0fy/Moct6MYYzqo5z9z7s7pxqxAHEdYaAiXj03lo62F5B4qdzuOMaaD2Zx/lKycw1wxLtXRweljrECcwIyxKQjY/EzGGMe9sHIPEaEhjt053ZhjBUJEnhOR/SJyVES2isgNTbS7VkTqRKS0wXaWUzkb69W5E+cMTublrDyqa22w2hjjjPLqWl5bs5cLTulB15gIVzI42YP4f0C6qsYDFwMPicipTbRdoaqxDbaljqX0YdZ4zzTg72XnuxnDGNOB/HPdPkqqarnitDTXMjhWIFQ1W1Wrjj31bv2c2v/JOGNAIildO/HcZzZYbYwJPFVl4YocBiXHMTa9i2s5HB2DEJG5IlIObAb2A2830XS0iBR5T0XdKyJhTXzeHBHJEpGswsLAXYoaGiLMOi2NlbsOsbWgJGD7McYYgLW5R9i0/yhXTkhzZFrvpjhaIFT1ViAOOB1YDFT5aPYxMBxIAqYBM4EfN/F581Q1U1UzExMTAxPa67LMFCLCQqwXYYwJuOdW5BAbGcalowO75vSJOH4Vk6rWqeoyoA9wi4/3d6rqLlWtV9UNwAPAdKdzNtY1JoJvndKTxWv2UlpV63YcY0w7daismje/2M93xvQmNtLnyRPHuHmZaxj+jUEo4F4fq4ErJ6RRWlXL62v3uh3FGNNOvZyVS3VdPVeOd29w+hhHCoSIJInIDBGJFZFQETkPz6mjD3y0nSoiyd7Hg4F7gSVO5DyR0SmdGdYrnuc+y0FV3Y5jjGln6uqV51fmcFrfrgxMjnM7jmM9CMVzOikPOAz8HviBqi4RkVTvvQ7HJhqZAnwhImV4BrEXA79xKOdxiQhXjU9jc75nXVhjjGlNH28tJPdQBVdNcL/3AJ7TPAGnqoXAmU28tweIbfD8LuAuJ3K1xMWjevGbt7/k2U93Mza9q9txjDHtyLMrdpMYF8m5Q3u4HQWwqTaaLToijMsyU3h3Yz4FR20xIWNM69hdVMbSLYXMOi2ViLDg+GoOjhRtzFUT0qhT5fmVNj+TMaZ1LFyRQ3iouDKtd1OsQLRAWrcYzh6UxAsr99j8TMaYk1ZWVcsrq3OZOryn44sCHY8ViBa6ekIaRaVVvLNxv9tRjDFt3Gtr91JSWcs1E4NjcPoYKxAtdMaARPp2j+HZT3e7HcUY04Z55l3azfDe8YxJdW/eJV+sQLRQSIjnktc1e46wIa/Y7TjGmDZqxc6DbC0o5eoJ6a7Ou+SLFYiTMD2zD9ERoTzz6S63oxhj2qgFy3fTJTqci0f2cjvK11iBOAnxUeFMP7UPb67fT2GJr3kHjTGmabmHynn/ywJmjkslKjzU7ThfYwXiJF0zMZ3qunpesEtejTHNtHDFbs8MDUFy53RjViBOUr/EWM4alMhzK3PskldjjN/Kqmr5+6pcpg7vQc+ETm7H8ckKRCu4dmI6hSVVvLVhn9tRjDFtxOI1eZRU1nLdpHS3ozTJCkQrOGNAIhmJMTyzfLfN8mqMOaH6euWZT3czok9C0F3a2pAViFYQEiJcNzGdL/KKWbPHZnk1xhzfJ9uL2FlYxnWTgu/S1oasQLSS74zpQ3xUGE8v3+12FGNMkHt62S4S4yK58JTgu7S1ISsQrSQmMoyZ41J5Z8N+cg+Vux3HGBOkthWU8NHWQq4enxY0s7Y2xbF0IvKciOwXkaMislVEbjhO2x+KSL6IFIvI0yIS6VTOk3HNRE930abfMMY05enlu4gMC2FWECw
peiJOlq//B6SrajxwMfCQiJzauJF3OdKf4llZLh3IAH7lYM4W69W5Exec0pOXVuVSUlnjdhxjTJA5WFrFP9bs5Ttj+tA1JsLtOCfkWIFQ1WxVPXa7sXq3fj6aXgPM97Y/DDwIXOtMypM3e3JfSqpqeTkrz+0oxpgg87x3iYDZk9PdjuIXR0+AichcESkHNgP78aw53dgwYH2D5+uBZBHp5uPz5ohIlohkFRYWBiRzc41K6UxmWheeWb6Lunq75NUY41FVW8fCFTmcNSiR/klxbsfxi6MFQlVvBeKA04HFgK8JjGKBhtOjHnv8tSOqqvNUNVNVMxMTE1s7bovNntyXvMMV/Cs73+0oxpgg8c91+ygqrWL25L5uR/Gb40PoqlqnqsuAPsAtPpqUAvENnh97XBLobK3l3GE9SOnaiaeW2SyvxhjPmg/zl+1iUHIck/t3dzuO39y8xioM32MQ2cDIBs9HAgWqetCRVK0gNESYPakvq3MOszrnkNtxjDEu+3hbEZvzS7jxjIygvjGuMUcKhIgkicgMEYkVkVDvlUozgQ98NF8IzBaRoSLSBbgHWOBEztb03cwUEjqFM+/jnW5HMca47MmPd5IcHxmUaz4cj1M9CMVzOikPOAz8HviBqi4RkVQRKRWRVABVfRd4BPgQyPFu9zmUs9XERIZx1fg0/rWpgF1FZW7HMca4JHtfMcu2F3HdpL5Bf2NcY46kVdVCVT1TVTuraryqnqKqT3rf26Oqsaq6p0H7P6pqsrftdQ0uj21Trp6YRnhICPOXWS/CmI7qyY93EhMRysxxqW5Haba2Vc7amKS4KC4d3ZtXsvI4WNoma5wx5iTsO1LBG1/sZ8a4VBI6hbsdp9msQATYDaf3paq2nkWf5bgdxRjjsGeWe65kDOY1H47HCkSADUiOY8rgJBauyKGius7tOMYYhxSX1/DCyj18a0RP+nSJdjtOi1iBcMDNZ/XjUFk1L2fluh3FGOOQ51bmUFZdx01n+Lqav22wAuGAseldOTWtC09+spPaOlu32pj2rrKmjmeW7+KMgYkM7RV/4h8IUlYgHHLzmf3IO1zBWxv2ux3FGBNgr67Oo6i0mpvPzHA7ykmxAuGQKYOTGJAUy+Mf7bR1q41px+rqlSc/2cnIPglMyPjaHKNtihUIh4SECHPOyODL/Uf5aGtwzDxrjGl972zcT87Bcm4+s1+bmlbDFysQDrpkVG96JkTx2NIdbkcxxgSAqvL4Rzvo2z2Gc4f1cDvOSbMC4aCIsBBuOD2DlbsOsTrnsNtxjDGt7ONtRWzce5Sbz8wgNKRt9x7ACoTjZo5LoUt0OI8t3e52FGNMK5v74XZ6JkRx6eg+bkdpFVYgHBYdEcb1k/ry/pcH+HL/UbfjGGNaSdbuQ6zcdYgbT89oc5PyNaV9/Fe0MVdPSCc2MszGIoxpR+Yu3UHXmAhmjEtxO0qrsQLhgoTocGaNT+XNL/ax26YCN6bNy95XzAebD3DdxHSiI8LcjtNqnFowKFJE5otIjoiUiMhaEZnaRNtrRaTOu0bEse0sJ3I6afbkvoSFhvD4R9aLMKate2zpDmIjw7h6QrrbUVqVUz2IMCAXOBNIAO4FXhaR9Cbar/CuEXFsW+pISgclxUUxY2wK/1iTx94jFW7HMca00PYDpby1YT9Xjk8jIbrtTel9PE4tGFSmqver6m5VrVfVN4FdwKlO7D9Y3XSmZxKvJ6wXYUybNffD7USGhXDD6X3djtLqXBmDEJFkYCCQ3UST0SJSJCJbReReEfF5Uk9E5ohIlohkFRa2vbuTe3fuxLQxffj7qlwOHK10O44xpplyDpaxZP0+rjwtje6xkW7HaXWOFwgRCQeeB55V1c0+mnwMDAeSgGnATODHvj5LVeepaqaqZiYmJgYqckDdelZ/6uqVJz62ZUmNaWvmfriDUO80Ou2RowVCREKARUA1cJuvNqq6U1V3eU9FbQAeAKY7GNNRqd2iuWRUL55fmUORLUtqTJu
Rd7icf6zJY+bYFJLio9yOExCOFQjxzFo1H0gGpqlqjZ8/qkDbv2f9OL53dn+qaut58hPrRRjTVjz+0Q5E/jeW2B452YN4DBgCXKSqTV62IyJTvWMUiMhgPFc8LXEmojv6JcbyrRG9WLQih0Nl1W7HMcacwP7iCl5elcf0U1Po1bmT23ECxqn7INKAm4BRQH6D+xtmiUiq93Gqt/kU4AsRKQPeBhYDv3Eip5u+f05/KmrqmGdjEcYEvbkf7qBele+d3X57D+C5PyHgVDWH458mim3Q9i7groCHCjIDkuP41oheLFyxmxtP70u3dnhFhDHtwb4jFby0KpfvZqbQp0u023ECyqbaCCJ3TPH0Ip78ZJfbUYwxTZi7dDtK++89gBWIoNI/KY6LvL2Ig3ZFkzFBpyP1HsAKRND5vrcXMc+uaDIm6Dz6oWcdl++d3d/lJM6wAhFk+ifFcfHIXiz8NIfCEutFGBMs8g6X83KWp/fQux1fudSQFYggdMeUAVTV1tlMr8YEkb/+Zzsiwu3ndIzeA1iBCEoZibF8Z0wfFn2WQ36xzdFkjNt2FZXx6po8rhiXSs+EjtF7ACsQQeuOKQOor1f+9uE2t6MY0+H9+f2thIcKt3aAK5casgIRpFK6RnP52BReWpVL7qFyt+MY02FtLShhyfp9XDMxnaS49jnnUlP8KhAiEi0io0Ukzsd7k1o/lgG47Zz+iAh//cB6Eca45U/vbyUmIoybz+hYvQfwo0CIyDggB1gKFIjI3Y2avBOAXAbomdCJK09L49XVeewoLHU7jjEdzsa9xby9IZ/rJ6XTJSbC7TiO86cH8Qfg56qaAEwErhSRxxu8365nWnXbrWf3Iyo8lD/+a6vbUYzpcB55bwudo8O5oZ2u93Ai/hSI4cBTAKq6DpgMDBaRRd71HUwAdY+N5IbTM3hrw3425BW7HceYDmPFjoN8vLWQ753Vn/io9rXWtL/8+YIvB/67XJuqHgXO9772KtaDCLgbT+9Ll+hwHnnP1wJ8xpjWpqo88t5mesRHcdWENLfjuMafAvERcEXDF1S1ErgYCAc6zkXBLomLCufWs/rzybYiPt1R5HYcY9q9f28qYO2eI9zxjQFEhYe6Hcc1/hSIO/CxYI+qVgOXAme3dijzdVdNSKNnQhSPvLsFVXU7jjHtVl298vt/bSGjewzfPbWP23FcdcICoaqFwIhjz0Xk4gbv1arqxyf6DBGJFJH5IpIjIiUislZEph6n/Q9FJF9EikXkaRHp8IsjRIWH8oNvDGBd7hHey853O44x7dbiNXlsLSjlznMHEhbasYdZ/f2vTxKR60XkGjxrSjdXGJALnAkk4FlG9GURSW/cUETOA36KZ2W5dCAD+FUL9tnuTBvTh/5JsTzy7hZq6urdjmNMu1NZU8cf/72VkX0SuPCUnm7HcZ0/90GcAWwDbgBuBLZ6X/Obqpap6v2qultV61X1TWAXcKqP5tcA81U1W1UPAw8C1zZnf+1VWGgIPz1/MDuLyvj7qly34xjT7jyzfDf7iyv56dQhiNj1N/70IPoCaXgGo6O9j/uezE5FJBkYCGT7eHsYsL7B8/VAsoh08/E5c0QkS0SyCgsLTyZSmzFlSBLj0rvy5/e3UVZV63YcY9qNw2XVzF26nXMGJzGh39e+bjokf8YgngWKgee821Hvay0iIuHA88Czqurrus1Y7/6OOfb4a9N8qOo8Vc1U1czExMTGb7dLIsLPLhhMUWkVT9qiQsa0mkc/3E5ZVS0/OX+w21GChr9jEInAn4C/0OCeiOby3li3CKgGbmuiWSkQ3+D5scclLd1vezM6tQsXnNKDeR/v5ECJTQduzMnKPVTOwhU5TBvTh0E9vvZv0Q7L3wIh/O+GuBadmBPPCb35eAa5p6lqTRNNs4GRDZ6PBApU9WBL9tte3X3eYGrq6vm/f9sUHMacrIff3UxICNx57kC3owQVfwtEIZ77IW4HDrRwX48BQ4CLVLXiOO0WArNFZKi
IdAHuARa0cJ/tVnr3GK4an85Lq3LZnH/U7TjGtFmrcw7z5hf7mXN6RodaDMgf/lzFdDWe0zxXerd472t+E5E04CZgFJAvIqXebZaIpHofpwKo6rvAI8CHeGaRzQHua87+OorvT+lPXFQ4v37rS7t5zpgWUFUeemsTiXGR3HRmx5vO+0TC/GiT4/3fckAbPPebquZw/FNTsY3a/xH4Y3P309F0jo7g+1MG8OCbm1i6tZCzByW5HcmYNuWtDftZu+cIj0wbQUykP1+HHYs/VzF9hOeS1Ke820DvayYIXDU+jfRu0fzmrS+ptZvnjPFbZU0dD7+7mcE94pjWwafUaEpzxiB2quoC7+P/EpGZrR3K+C8iLISfTh3CtgOlvPj5HrfjGNNmPLN8N7mHKrjnwqGEhthNcb74VSBU9XXgVRF5GHgLQEQ6i8hL2DQYrjtvWDITMrrxx39v5Uh5tdtxjAl6B0oq+dsH2/jm0GQmD+judpyg1ZyZqEbiGWReJSKzgQ3AEWB0IIIZ/4kIv7xoKMUVNfzpfVu/2pgT+d27W6iuq+cXFwxxO0pQ87tAqOo+4Nven5kHvKOqN6lqWaDCGf8N6RnPzHGpLPosh20Fdk+hMU1Zn3uEV1bncf3kvqR3j3E7TlDzu0CIyCggC9gJXAKcIyIvikjnQIUzzXPnNwcSExHKA29usstejfFBVfnVG9l0j43ktrP7ux0n6DXnFNN/gD+q6re9s7GOxHPp64aAJDPN1i02kh98YyCfbCvi/S9bej+jMe3XknX7WLPnCHefP4i4DrrOdHM0p0CMVdX5x554p/CeDXyv9WOZlrpqQhoDkmJ54M1sKmvq3I5jTNAoqazh129/ycg+CUwfY5e1+qM5YxA+pw5V1X+2XhxzssJDQ/jVxcPIPVTBEx/ZbK/GHPOX/2yjqLSKBy4ZTohd1uqXjr2eXjs1sX93LhzRk7lLt5N7qNztOMa4bltBCc8s383lmSmMTLFhU39ZgWin7rlwCCEiPPTWJrejGOMqVeX+N7KJjgjlx+cNcjtOm2IFop3qmdCJ26f0573sApZusQFr03G9vSGf5dsP8uPzBtEtNtLtOG2KFYh27IbJGWQkxnDfP23A2nRMJZU1PPBmNsN6xXPFaWlux2lzrEC0YxFhITx0yXByDpYz98PtbscxxnF//PdWDpRU8etLT7H5llrACkQ7N7F/d749qhePf7STHYWlbscxxjEb9xbz7Ke7mXVaKqNsYLpFHCsQInKbiGSJSJWILDhOu2tFpK7BokKlInKWUznbo19cOJTI8BDufX2j3WFtOoT6euWe1zfSNSaCH5832O04bZaTPYh9wEPA0360XaGqsQ22pYGN1r4lxkVy9/mD+XTHQZas2+d2HGMC7oXP97Au9wj3XDiUhE52x3RLOVYgVHWxd9rwg07t0/zPFeM83ewH39zE4TKbEty0XwVHK3n4nc1M6t+NS0b1cjtOmxasYxCjRaRIRLaKyL0i4nMtQBGZ4z1tlVVYWOirifEKDRF+O+0Uiis80w0Y017dtySb6rp6fv3tUxCxgemTEYwF4mNgOJAETANmAj/21VBV56lqpqpmJiYmOhixbRrcI56bzszg1dV5LN9e5HYcY1rde9n5vJudzx3fGGBTebeCoCsQqrpTVXepar2qbgAeAKa7nau9uP2cAaR3i+bnr22weyNMu1JSWcN9S7IZ3COOG0/PcDtOuxB0BcIHBayf2EqiwkP5zXdOIedgOf/3/la34xjTah5+dzMFJZX8dtoIwkPbwldb8HPyMtcwEYkCQoFQEYnyNbYgIlNFJNn7eDBwL7DEqZwdwcR+3ZkxNoUnP97J+twjbscx5qSt2HGQ5z7bw3UT+9o9D63IyTJ7D1AB/BS40vv4HhFJ9d7rkOptNwX4QkTKgLeBxcBvHMzZIfz8wiEkxUVx96tfUF1b73YcY1qsvLqWn/zjC9K6RdtkfK3Myctc71dVabTdr6p7vPc67PG2u0t
Vk1U1RlUzVPWXqlrjVM6OIj4qnF9fOpwtBSX8zabhMG3YH/61lT2Hyvntd0bQKSLU7Tjtip2o68CmDEnm0tG9mfvhdjbtO+p2HGOabXXOYZ5evosrx6cyoV83t+O0O1YgOrhffmsonaPDueuV9XaqybQpFdV1/PiV9fRK6MRPpw5xO067ZAWig+sSE8GvLz2FTfuP8tcPtrkdxxi/PfLeZnYWlfHI9BHERvq8l9acJCsQhvOG9eA7Y3ozd+kO1tlVTaYN+HRHEc8s3801E9KY1L+723HaLSsQBoD7LhpGUlwkP3p5nd1AZ4JaSWUNP37lC9K7RfOTqTZTayBZgTAAJHQK5+FpI9hRWMbv3tvidhxjmvTQm1+yv7iCP1w2kugIO7UUSFYgzH+dMTCRq8anMX/ZLpZts7maTPB5d2M+L2XlctOZ/Tg1ravbcdo9KxDmK35+wRD6Jcbwo1fWcaTcpgU3wePA0Up+tvgLhveO54ffGOh2nA7BCoT5ik4Rofx5xmgOlVXz89c22Ap0JijU1yt3vfoFFTV1/Ony0USE2VeXE+wom68Z3juBO785iLc35PPq6jy34xjDsyt28/HWQn5x4VD6J8W6HafDsAJhfJpzRgan9e3Kff/MZmdhqdtxTAe2ad9R/t87mzlncBJXnpZ64h8wrcYKhPEpNET404xRRISFcPuLa6mqtUtfjfPKq2u57cU1dO4Uzu+mj7AV4hxmBcI0qWdCJ343fSTZ+47y23c2ux3HdED3LclmV1EZf5oxim6xkW7H6XCsQJjj+ubQZK6dmM4zy3fz/qYCt+OYDmTJur28sjqP287uz8R+dre0G5xcMOg2EckSkSoRWXCCtj8UkXwRKRaRp0XE/ungop9dMJihPeO569X15B0udzuO6QB2FJby88UbyEzrwh1TBrgdp8NysgexD3gIePp4jUTkPDyLCk0B0oEM4FeBDmeaFhkWyqOzxlBbp9z2wlqb9dUEVEV1Hbc+t4aIsBD+MnM0YbZ8qGucXDBosaq+Dhw8QdNrgPmqmq2qh4EHgWsDnc8cX9/uMfxu+gjW5R7hN29/6XYc047du2QjWw+U8KcZo+nVuZPbcTq0YCzNw4D1DZ6vB5JFxFYDcdnUU3py3aR0Fny6m7e+2O92HNMOvZyVy6ur87j97P6cOTDR7TgdXjAWiFiguMHzY4/jGjcUkTnecY2swsJCR8J1dD+bOoTRqZ25+9X1bD9Q4nYc045s3FvMva9vZGK/btxhU2kEhWAsEKVAfIPnxx5/7dtIVeepaqaqZiYm2r82nBARFsLcWWPoFBHKnEWrKam05cLNyTtUVs1Ni1bTNSaCv8wcTWiI3e8QDIKxQGQDIxs8HwkUqOqJxi6MQ3omdOJvV4wh52A5d768nvp6m6/JtFxtXT23v7iGwtIqHr/yVLrb/Q5Bw8nLXMNEJAoIBUJFJEpEfE3mvhCYLSJDRaQLcA+wwKmcxj/jM7rxiwuG8O9NBfztw+1uxzFt2O/e28Ly7Qd56NvDGZnS2e04pgEnexD3ABV4LmG90vv4HhFJFZFSEUkFUNV3gUeAD4Ec73afgzmNn66blM6lo3vzf8PDN34AABHlSURBVO9v5V/Z+W7HMW3Q62v38sTHO7lyfCqXZaa4Hcc0Iu1lOufMzEzNyspyO0aHU1lTx+VPrGDbgVL+cctEhvSMP/EPGQOs3XOYy+d9xpjUziyafRrhdr+DK0Rktapm+nrP/h8xJyUqPJR5V2cSFxXGDc9mUVRa5XYk0wbsL65gzqLV9IiP4rFZp1pxCFL2/4o5acnxUTx5dSYHy6q4edFqm/nVHFd5dS03LsyiorqOp67JpEtMhNuRTBOsQJhWMaJPZ37/3ZFk5Rzm7le/sJXojE919cr3X1zHpn1H+cvMUQxM/trtTSaI+LqKyJgW+daIXuw5VM4j724htWs0Pzp3kNuRTJB58M1NvP9lAQ9cMoxzBie7HcecgBUI06puObMfew6W89c
PtpPSJZrLxtqVKcbj6WW7WPDpbmZP7svVE9LdjmP8YAXCtCoR4cFvD2fvkQp+/toGkuIjOWtQktuxjMve2bCfB9/axHnDkvn5BUPcjmP8ZGMQptWFh3qm4xjUI45bnlvD2j2H3Y5kXPTpjiLu+Ps6xqR24U+X2zQabYkVCBMQcVHhLLhuHEnxkVy/YBXbD5S6Hcm4YOPeYuYsXE1at2jmX5NJp4hQtyOZZrACYQImMS6ShdePIzREuObpz9l3pMLtSMZBOQfLuPaZVcRHhbFw9jg6R9vlrG2NFQgTUGndYlhw3TiOVtRw5VMrKSyxG+k6gn1HKrjiyZXU1dfz7PXj6JlgC/+0RVYgTMAN753AM9eNZX9xJVfNX8mR8mq3I5kAOlBSyaynVnK0ooZFs09jgN3r0GZZgTCOyEzvypNXZ7KzsIxrnv7c1pFopw6XVXP1/M/JL65kwfVjGd47we1I5iRYgTCOmTygO3NnjSF731GutiLR7hwuq2bWUyvZWVTGU9dkcmpaV7cjmZNkBcI46htDk/nbFWPYkFfM1U9/zlErEu3CobJqrnhqJdsLS3nq6kwm9e/udiTTCqxAGMedP7wHj87yFon5ViTaukPHeg7e4nDGQFv+t71wckW5riLymoiUiUiOiFzRRLv7RaTGu4jQsS3DqZzGGecN6+E93VTMFU9+xkGbJrxNKjhayeVPrGBnYSlPWnFod5zsQTwKVAPJwCzgMREZ1kTbl1Q1tsG207GUxjHnDuvBvKsz2VZQyuXzPqPgaKXbkUwz5B4q57uPr2DfkQoWXDfOikM75EiBEJEYYBpwr6qWquoy4J/AVU7s3wSvswcl8ez148gvrmT645+y52C525GMH7YfKOG7j6+guKKG528cz4R+3dyOZALAqR7EQKBOVbc2eG090FQP4iIROSQi2SJyS1MfKiJzRCRLRLIKCwtbM69x0PiMbjx/w2mUVNbynceWsyGv2O1I5jhW7T7EtMdWUFuvvHTTeEaldHY7kgkQpwpELND4r74Y8HUHzcvAECARuBH4pYjM9PWhqjpPVTNVNTMx0bq3bdnIlM68evNEIsNCuXzeCj7aagU/GL27MZ8rn1pJt5gIXrt1IoN72Brk7ZlTBaIUaPybFA+UNG6oqptUdZ+q1qnqp8CfgekOZDQu658Uy+JbJ5LWLYbZC1bx8qpctyMZL1VlwfJd3PL8aob2iufVWyaS0jXa7VgmwJwqEFuBMBEZ0OC1kUC2Hz+rgM0P3EEkx0fx8k2ec9p3/+MLfvP2l9TV2/Klbqqpq+feJRu5/41NTBmczAs3jKerrSPdIThSIFS1DFgMPCAiMSIyCbgEWNS4rYhcIiJdxGMc8H1giRM5TXCIiwrn6WvHctX4NOZ9vJObFq2mtKrW7VgdUnFFDdcvWMVzn+3hpjMyeOKqU23K7g7EyctcbwU6AQeAF4FbVDVbRE4XkYaLBcwAtuM5/bQQeFhVn3UwpwkC4aEhPPjt4fzq4mF8sLmA78xdzs5CW1PCSVvyS7jkb8v4bOdBHpk+gp9dMMQW++lgRLV9dN8zMzM1KyvL7RgmAJZtK+L2F9dQW6f88fJRfHOoLXYfaG+s38fdr35BbFQYc2eNYWy6zavUXonIalXN9PWeTbVhgt7kAd154/bJpHWP5saFWfzuvc3U1tW7Hatdqqqt41dvZHP7i2sZ1iuet26fbMWhA7MCYdqEPl2iefXmiVyemcKjH+5gxrzP2Gsr1LWq3UVlTH9sBc8s3821E9N54cbxJMVHuR3LuMgKhGkzosJDeXj6CP48YxSb80u44M+f8O7G/W7HavNUldfX7uVbf13GnkPlPHHVqdx/8TAiwuzroaOz3wDT5lwyqjdvfX8yad2iufm5NfzwpXUUV9iMsC1xsLSKW59fww9eWsfgHnG8fcfpnDesh9uxTJAIczuAMS2R1i2Gf9wykb99sJ2/fbidFTsO8vD0EZxpE8b57b3sfH7x2gaOVtTyk/M
HM+eMDLtKyXyF9SBMmxUeGsIPvzmQ126dSGxUGNc8/Tl3/H0tRTZ1+HHlF1dy86LV3LRoNYlxUfzz9kncclY/Kw7ma+wyV9MuVNbUMXfpDh5bup3oiDB+NnUwl2WmEGJfev9VW1fP8yv38Lv3tlBTV88d3xjAjadnEB5q/07syI53masVCNOubCso4eevbWDV7sOc0juB+y4aSqZdpsny7UU88MYmthSUMLl/d3596XDSusW4HcsEASsQpkNRVZas28dv39lM/tFKvjWiJ3edO4j07h3vC3FbQQm/e28L/9pUQJ8unfjFBUM4f3gPRKxnZTysQJgOqby6lseX7uDJT3ZRXVfPZZkp3DFlAD0S2v+1/bmHyvnT+9t4bW0e0RFh3HxmBjecnkFUuM2jZL7KCoTp0A6UVPLoB9t54fM9iAjTT+3DTWdktMtTLNsPlPL4Rzt4fe1eQkKEayakcctZ/W32VdMkKxDG4PlX9dylO/jH6jxq6+u54JSeXDepL2NSO7fpUy6qyspdh1iwfDfvbconMiyEGWNTuenMDHomdHI7nglyViCMaeDA0UrmL9vFCyv3UFJVy7Be8Vw9IY0LR/QiNrLt3Bp0tLKGf67bx6IVOWwpKCGhUzhXjU/juknpdIuNdDueaSOsQBjjQ1lVLa+t3cvCFbvZWlBKp/BQzh/eg0tH92ZCv25BeflndW09y7YXsnjNXv61qYDq2nqG9ozn2onpXDSyl63VYJrNCoQxx6GqrM45zOK1e3lz/T6OVtaS0CmcKUOSOHdoDyb170ZcVLhr+YrLa1i2vYj3svP5cPMBSqpq6RIdzsUje3HpmD6M7JPQpk+RGXcFRYEQka7AfOBcoAj4maq+4KOdAL8FbvC+NB/4iZ4gqBUI0xoqa+pYuqWQf23K5z9fHqC4oobQEGFknwQm9e/OmLQujOzTOaCDvkWlVazPPcLqnMMs317Ehr3F1Ct0jYngG96idcbARJtMz7SK4xUIJ0+4PgpUA8nAKOAtEVmvqo3XpZ4DfBvPmtUK/BvYCTzuYFbTQUV5TzOdP7wHNXX1ZO32fEkv31HE3KU7/rs+dkrXTgxKjqd/Uiz9EmPo0yWaHglR9IiP8us0T1lVLflHKykoriTvcAU7CkvZUVjKl/tL/juNeViIMDq1M7efM4DJA7ozOqUzYUF42su0X470IEQkBjgMDFfVrd7XFgF7VfWnjdp+CixQ1Xne57OBG1V1/PH2YT0IE2ilVbVs3FvM+twjrM87wraCUnYfLKOm7qt/Q5FhIcRFhRETGUaE9wtd8YwflFXVUlJVS3XtVxc8iggLIaN7DP2TYhnZpzOjUjszrFc80RFtZ9DctE3B0IMYCNQdKw5e64EzfbQd5n2vYbthvj5URObg6XGQmpraOkmNaUJsZBjjM7oxPqPbf1+rrasn93AF+45UkF9cSf7RSo5W1FBSVUtpZS219f8rBOGhIcRGhhEbFUbnThH0SIgkOT6K3p070adLtE2WZ4KOUwUiFihu9FoxEOdH22IgVkSk8TiEt5cxDzw9iNaLa4x/wkJD6Ns9hr4dcBoP0/45dUKzFIhv9Fo8UOJH23ig9ESD1MYYY1qXUwViKxAmIgMavDYSaDxAjfe1kX60M8YYE0COFAhVLQMWAw+ISIyITAIuARb5aL4QuFNEeotIL+BHwAInchpjjPkfJ6+ZuxXoBBwAXgRuUdVsETldREobtHsCeAPYAGwE3vK+ZowxxkGOXUOnqofw3N/Q+PVP8AxMH3uuwN3ezRhjjEvsrhtjjDE+WYEwxhjjkxUIY4wxPrWb2VxFpBDIaeGPd8czgWCwsVzNY7maL1izWa7mOZlcaaqa6OuNdlMgToaIZDU1F4mbLFfzWK7mC9Zslqt5ApXLTjEZY4zxyQqEMcYYn6xAeMxzO0ATLFfzWK7mC9Zslqt5ApLLxiCMMcb4ZD0IY4wxPlmBMMYY45MVCGOMMT51uAIhIpEiMl9EckSkRETWisj
UE/zMD0UkX0SKReRpEYkMULbbRCRLRKpEZMEJ2l4rInUiUtpgO8vtXN72Th2vriLymoiUef//vOI4be8XkZpGxyvD6Szi8bCIHPRuj4hIwNYabUaugB6fRvtqzu+5I79Lzc3m8N9fs76zWvOYdbgCgWcG21w862EnAPcCL4tIuq/GInIe8FNgCpAOZAC/ClC2fcBDwNN+tl+hqrENtqVu53L4eD0KVAPJwCzgMRHxuX6510uNjtdOF7LMwTOr8UhgBPAt4KZWzNHSXBDY49OQX79PDv8uNSubl1N/f35/Z7X6MVPVDr8BXwDTmnjvBeA3DZ5PAfIDnOchYMEJ2lwLLHP4OPmTy5HjBcTg+eIb2OC1RcBvm2h/P/BcgI6L31mAT4E5DZ7PBj4LglwBOz4t/X1y42+vGdkc//trtH+f31mtfcw6Yg/iK0QkGRhI08uaDgPWN3i+HkgWkW6BzuaH0SJSJCJbReReEXFsfY/jcOp4DQTqVHVro30drwdxkYgcEpFsEbnFpSy+js/xMjuVCwJ3fFoqmP/2wKW/vxN8Z7XqMevQBUJEwoHngWdVdXMTzWKB4gbPjz2OC2Q2P3wMDAeSgGnATODHribycOp4Nd7PsX01tZ+XgSFAInAj8EsRmelCFl/HJzZA4xDNyRXI49NSwfq3By79/fnxndWqx6zdFQgRWSoi2sS2rEG7EDzd7WrgtuN8ZCkQ3+D5scclgcjlL1Xdqaq7VLVeVTcADwDTm/s5rZ0L545X4/0c25fP/ajqJlXdp6p1qvop8GdacLya0Jwsvo5PqXrPB7Qyv3MF+Pi0VKv8LgVCa/39NYef31mteszaXYFQ1bNUVZrYJoPnShJgPp6Bu2mqWnOcj8zGM6B4zEigQFUPtnauk6RAs/8VGoBcTh2vrUCYiAxotK+mThV+bRe04Hg1oTlZfB0ffzMHMldjrXl8WqpVfpccEtDj1YzvrFY9Zu2uQPjpMTzd6YtUteIEbRcCs0VkqIh0Ae4BFgQilIiEiUgUEAqEikhUU+c1RWSq91wkIjIYz5UNS9zOhUPHS1XLgMXAAyISIyKTgEvw/AvL13/DJSLSRTzGAd+nlY5XM7MsBO4Ukd4i0gv4EQH6fWpOrkAeHx/78vf3ybG/veZmc/Lvz8vf76zWPWZujcK7tQFpeKp9JZ7u2LFtlvf9VO/z1AY/cydQABwFngEiA5Ttfm+2htv9vnIBv/dmKgN24unihrudy+Hj1RV43XsM9gBXNHjvdDynbo49fxE46M26Gfi+E1l85BDgEeCQd3sE75xoTh4jp4+PP79Pbv4uNTebw39/TX5nBfqY2WR9xhhjfOqop5iMMcacgBUIY4wxPlmBMMYY45MVCGOMMT5ZgTDGGOOTFQhjjDE+WYEwxhjjkxUIY4wxPlmBMMYY45MVCGMCQET6eddWGON93su7dsBZLkczxm821YYxASIiN+KZF+dU4DVgg6re5W4qY/xnBcKYABKRfwJ98Uy2NlZVq1yOZIzf7BSTMYH1JJ6Vx/5qxcG0NdaDMCZARCQWz5rAHwJTgVNU9ZC7qYzxnxUIYwJEROYDcap6mYjMAzqr6mVu5zLGX3aKyZgAEJFLgPOBm70v3QmMEZFZ7qUypnmsB2GMMcYn60EYY4zxyQqEMcYYn6xAGGOM8ckKhDHGGJ+sQBhjjPHJCoQxxhifrEAYY4zxyQqEMcYYn/4/4/tCvGNWZU4AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_function(f, 'x', 'x**2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The sequence of steps we described earlier starts by picking some random value for a parameter, and calculating the value of the loss:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEMCAYAAADeYiHoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3yV5f3/8dcnm0xWEmYSwh6yDMhy0qpo1VqogrhRHNXaWrVDrdb1q3Z8OxQVRREcdRSlbmsVFUQkbIJsCGEkJIyQPT+/P86hjfEETkLOfZ8kn+fjcT96xpVzv72bnA/Xfd33dYmqYowxxtQX4nYAY4wxwckKhDHGGJ+sQBhjjPHJCoQxxhifrEAYY4zxKcztAM2lc+fOmpaW5nYMY4xpUVasWFGgqom+3ms1BSItLY3MzEy3YxhjTIsiItkNvWenmIwxxvhkBcIYY4xPViCMMcb4ZAXCGGOMT1YgjDHG+OR4gRCRviJSLiIvNvC+iMijInLAuz0mIuJ0TmOMaevcuMz1CWD5Md6fCfwQGAYo8G9gO/BU4KMZY4w5ytEehIhMBQ4D/zlGs6uAP6nqblXdA/wJuDpQmVbnHObRDzYG6uONMSZgVJWH391A1t7CgHy+YwVCROKBB4BfHKfpYGBNnedrvK/5+syZIpIpIpn5+flNyrVu92GeXLSN9XsCc4CNMSZQvtp+kGe+2MGm3KKAfL6TPYgHgTmqmnOcdrFA3W/rQiDW1ziEqs5W1QxVzUhM9Hmn+HFdOLw7kWEh/GP5rib9vDHGuOXV5buIiwpj0pCuAfl8RwqEiAwHvgf8nx/Ni4H4Os/jgWIN0NJ3Ce3COe+krixcvZeyyppA7MIYY5pdYWkV76/P5aLh3WgXERqQfTjVgzgDSAN2iUgucAcwWURW+mibhWeA+qhh3tcC5pKMnhSVV/P++n2B3I0xxjSbhWv2UFFdy9RRKQHbh1MFYjbQGxju3Z4C3gXO8dF2HnC7iHQXkW54xizmBjLcmPSOpHWK5tXlxzv7ZYwxweHV5TkM6hrPkO4JAduHIwVCVUtVNffohuc0Urmq5ovIqSJSXKf508DbwDpgPZ5C8nQg84kIP87oybIdB9meX3z8HzDGGBet31NI1t4jTB3dM6D7ceVOalW9X1Uv9z7+QlVj67ynqnqXqnb0bncFavyhrikn9yA0RHgtc3egd2WMMSfkH8t3ERkWwkXDugd0PzbVhldyfBRn9k/kjRW7qaqpdTuOMcb4VFZZw8LVe5k0pAsJ0eEB3ZcViDouHZVCQXEFn2zc73YUY4zx6b11+ygqr+aSUYE9vQRWIL7lzP6JJMdH8srXdk+EMSY4vfL1Lnp1jmFseqeA78sKRB1hoSFcktGTzzbns+dwmdtxjDHmW7bkFZGZfYipo3rixBymViDquSTD022zS16NMcHmla9zCA8VJp/cw5H9WYGop2fHaE7rm8jrmTlU22C1MSZIlFfVsGDVbs4e3IXOsZGO7NMKhA/TRvdkX2E5n21u2gSAxhjT3D7MyuVwaRXTAnjndH1W
IHyYODCZzrGRvPK1nWYyxgSHV77eRUrHaMb1Dvzg9FFWIHwIDw3hkowefLIxj9zCcrfjGGPauO35xXy1/SBTR/ckJMS5BTatQDRg6qgUatUGq40x7nvl612EhQhTHBqcPsoKRANSOkVzWr9E/rF8lw1WG2NcU15Vw+srdnPO4C4kxUU5um8rEMcw/ZQU9hWW8+kmG6w2xrjj/fX7OFxaxWWnODc4fZQViGOYOCCJ5PhIXlqW7XYUY0wb9dJXzt05XZ8ViGMICw3h0lEpfLY5n5yDpW7HMca0MRtzj5CZfYjLRqc4Ojh9lBWI45g6qicCNj+TMcZxLy/bRURoiGN3TtfnWIEQkRdFZJ+IHBGRzSJyXQPtrhaRGhEprrOd4VTO+rq1b8dZA5J5LXM3ldU2WG2McUZpZTVvrtzDeSd1oWNMhCsZnOxB/D8gTVXjgQuBh0Tk5AbaLlXV2DrbIsdS+jB9jGca8A+zct2MYYxpQ/61ei9FFdVcdkqqaxkcKxCqmqWqFUeferfeTu3/RJzWN5GeHdvx4lc2WG2MCTxVZd7SbPonxzEqrYNrORwdgxCRWSJSCmwE9gHvNdB0hIgUeE9F3SsiYQ183kwRyRSRzPz8wF2KGhoiTD8llWU7DrI5ryhg+zHGGIBVOYfZsO8Il49NdWRa74Y4WiBU9WYgDjgVWABU+Gj2OTAESAImA9OAOxv4vNmqmqGqGYmJiYEJ7XVJRk8iwkKsF2GMCbgXl2YTGxnGxSMCu+b08Th+FZOq1qjqYqAHcJOP97er6g5VrVXVdcADwBSnc9bXMSaCH5zUlQUr91BcUe12HGNMK3WwpJJ31u7jRyO7Exvp8+SJY9y8zDUM/8YgFHCvj1XH5WNTKa6o5q1Ve9yOYoxppV7LzKGyppbLx7g3OH2UIwVCRJJEZKqIxIpIqIicg+fU0Sc+2k4SkWTv4wHAvcBCJ3Iez4ie7RncLZ4Xv8pGVd2OY4xpZWpqlZeWZXNKr470S45zO45jPQjFczppN3AI+CPwM1VdKCIp3nsdjk40MhFYKyIleAaxFwCPOJTzmESEK8aksjHXsy6sMcY0p88355NzsIwrxrrfewDPaZ6AU9V84PQG3tsFxNZ5fgdwhxO5muLC4d145L1veOHLnYxK6+h2HGNMK/LC0p0kxkVy9qAubkcBbKqNRouOCOOSjJ58sD6XvCO2mJAxpnnsLChh0aZ8pp+SQkRYcHw1B0eKFuaKsanUqPLSMpufyRjTPOYtzSY8VFyZ1rshViCaILVTDGf2T+LlZbtsfiZjzAkrqajm9RU5TBrS1fFFgY7FCkQTXTk2lYLiCt5fv8/tKMaYFu7NVXsoKq/mqnHBMTh9lBWIJjqtbyK9Osfwwpc73Y5ijGnBPPMu7WRI93hGprg375IvViCaKCTEc8nryl2HWbe70O04xpgWaun2A2zOK+bKsWmuzrvkixWIEzAlowfREaE8/+UOt6MYY1qouUt20iE6nAuHdXM7yndYgTgB8VHhTDm5B++s2Ud+ka95B40xpmE5B0v5+Js8po1OISo81O0432EF4gRdNS6NyppaXrZLXo0xjTRv6U7PDA1Bcud0fVYgTlDvxFjO6J/Ii8uy7ZJXY4zfSiqq+cfyHCYN6ULXhHZux/HJCkQzuHpcGvlFFby7bq/bUYwxLcSClbspKq/mmvFpbkdpkBWIZnBa30TSE2N4fslOm+XVGHNctbXK81/uZGiPhKC7tLUuKxDNICREuGZcGmt3F7Jyl83yaow5ti+2FrA9v4Rrxgffpa11WYFoJj8a2YP4qDCeW7LT7SjGmCD33OIdJMZFcv5JwXdpa11WIJpJTGQY00an8P66feQcLHU7jjEmSG3JK+KzzflcOSY1aGZtbYhj6UTkRRHZJyJHRGSziFx3jLY/F5FcESkUkedEJNKpnCfiqnGe7qJNv2GMachzS3YQGRbC9CBYUvR4nCxf/w9IU9V44ELgIRE5uX4j73Kkv8Kzslwa
kA78zsGcTdatfTvOO6krry7Poai8yu04xpggc6C4gn+u3MOPRvagY0yE23GOy7ECoapZqnr0dmP1br19NL0KmONtfwh4ELjamZQnbsaEXhRVVPNa5m63oxhjgsxL3iUCZkxIczuKXxw9ASYis0SkFNgI7MOz5nR9g4E1dZ6vAZJFpJOPz5spIpkikpmfnx+QzI01vGd7MlI78PySHdTU2iWvxhiPiuoa5i3N5oz+ifRJinM7jl8cLRCqejMQB5wKLAB8TWAUC9SdHvXo4+8cUVWdraoZqpqRmJjY3HGbbMaEXuw+VMZHWbluRzHGBIl/rd5LQXEFMyb0cjuK3xwfQlfVGlVdDPQAbvLRpBiIr/P86OOiQGdrLmcP7kLPju14drHN8mqM8az5MGfxDvonxzGhT2e34/jNzWuswvA9BpEFDKvzfBiQp6oHHEnVDEJDhBnje7Ei+xArsg+6HccY47LPtxSwMbeI609LD+ob4+pzpECISJKITBWRWBEJ9V6pNA34xEfzecAMERkkIh2Ae4C5TuRsTj/O6ElCu3Bmf77d7SjGGJc98/l2kuMjg3LNh2NxqgeheE4n7QYOAX8EfqaqC0UkRUSKRSQFQFU/AB4DPgWyvdt9DuVsNjGRYVwxJpWPNuSxo6DE7TjGGJdk7S1k8dYCrhnfK+hvjKvPkbSqmq+qp6tqe1WNV9WTVPUZ73u7VDVWVXfVaf9nVU32tr2mzuWxLcqV41IJDwlhzmLrRRjTVj3z+XZiIkKZNjrF7SiN1rLKWQuTFBfFxSO683rmbg4Ut8gaZ4w5AXsPl/H22n1MHZ1CQrtwt+M0mhWIALvu1F5UVNcy/6tst6MYYxz2/BLPlYzBvObDsViBCLC+yXFMHJDEvKXZlFXWuB3HGOOQwtIqXl62ix8M7UqPDtFux2kSKxAOuPGM3hwsqeS1zBy3oxhjHPLismxKKmu44TRfV/O3DFYgHDAqrSMnp3bgmS+2U11j61Yb09qVV9Xw/JIdnNYvkUHd4o//A0HKCoRDbjy9N7sPlfHuun1uRzHGBNgbK3ZTUFzJjaenux3lhFiBcMjEAUn0TYrlqc+227rVxrRiNbXKM19sZ1iPBMamf2eO0RbFCoRDQkKEmael882+I3y2OThmnjXGNL/31+8j+0ApN57eu0VNq+GLFQgHXTS8O10Tonhy0Ta3oxhjAkBVeeqzbfTqHMPZg7u4HeeEWYFwUERYCNedms6yHQdZkX3I7TjGmGb2+ZYC1u85wo2npxMa0rJ7D2AFwnHTRvekQ3Q4Ty7a6nYUY0wzm/XpVromRHHxiB5uR2kWViAcFh0RxrXje/HxN/v5Zt8Rt+MYY5pJ5s6DLNtxkOtPTW9xk/I1pHX8V7QwV45NIzYyzMYijGlFZi3aRseYCKaO7ul2lGZjBcIFCdHhTB+Twjtr97LTpgI3psXL2lvIJxv3c824NKIjwtyO02ycWjAoUkTmiEi2iBSJyCoRmdRA26tFpMa7RsTR7QwncjppxoRehIWG8NRn1oswpqV7ctE2YiPDuHJsmttRmpVTPYgwIAc4HUgA7gVeE5G0Btov9a4RcXRb5EhKByXFRTF1VE/+uXI3ew6XuR3HGNNEW/cX8+66fVw+JpWE6JY3pfexOLVgUImq3q+qO1W1VlXfAXYAJzux/2B1w+meSbyetl6EMS3WrE+3EhkWwnWn9nI7SrNzZQxCRJKBfkBWA01GiEiBiGwWkXtFxOdJPRGZKSKZIpKZn9/y7k7u3r4dk0f24B/Lc9h/pNztOMaYRso+UMLCNXu5/JRUOsdGuh2n2TleIEQkHHgJeEFVN/po8jkwBEgCJgPTgDt9fZaqzlbVDFXNSExMDFTkgLr5jD7U1CpPf27LkhrT0sz6dBuh3ml0WiNHC4SIhADzgUrgFl9tVHW7qu7wnopaBzwATHEwpqNSOkVz0fBuvLQsmwJbltSYFmP3oVL+uXI300b1JCk+yu04AeFYgRDPrFVzgGRg
sqpW+fmjCrT8e9aP4Sdn9qGiupZnvrBehDEtxVOfbUPkf2OJrZGTPYgngYHABara4GU7IjLJO0aBiAzAc8XTQmciuqN3Yiw/GNqN+UuzOVhS6XYcY8xx7Css47Xlu5lyck+6tW/ndpyAceo+iFTgBmA4kFvn/obpIpLifZzibT4RWCsiJcB7wALgESdyuumnZ/WhrKqG2TYWYUzQm/XpNmpV+cmZrbf3AJ77EwJOVbM59mmi2Dpt7wDuCHioINM3OY4fDO3GvKU7uf7UXnRqhVdEGNMa7D1cxqvLc/hxRk96dIh2O05A2VQbQeS2iZ5exDNf7HA7ijGmAbMWbUVp/b0HsAIRVPokxXGBtxdxwK5oMibotKXeA1iBCDo/9fYiZtsVTcYEnSc+9azj8pMz+7icxBlWIIJMn6Q4LoyvZN5/viE/tiOkpcFLL7kdy5g2b/ehUl7L9PQeurfiK5fqsgIRbF56idsev5OKkDCeOmUyZGfDzJlWJIxx2d//sxUR4daz2kbvAaxABJ+77yZ97zZ+lPUJ80ecR25sJygthbvvdjuZMW3WjoIS3li5m8tGp9A1oW30HsAKRPDZtQuA25b8g1oJ4fGxl3zrdWOM8/768WbCQ4Wb28CVS3VZgQg2KZ77BXsW5nHp2o94ddjZ5MQn/fd1Y4yzNucVsXDNXq4al0ZSXOucc6khfhUIEYkWkREiEufjvfHNH6sNe/hhiPZcPnfL0lcRVf5+2uWe140xjvvLx5uJiQjjxtPaVu8B/CgQIjIayAYWAXkicle9Ju8HIFfbNX06zJ4Nqal0LT7I5dsW88bgM9l29kVuJzOmzVm/p5D31uVy7fg0OsREuB3Hcf70IP4E/EZVE4BxwOUi8lSd91v1TKuumD4ddu6E2lpunv8IURFh/PmjzW6nMqbNeezDTbSPDue6Vrrew/H4UyCGAM8CqOpqYAIwQETme9d3MAHUOTaS605N5911+1i3u9DtOMa0GUu3HeDzzfn85Iw+xEe1rrWm/eXPF3wp8N/l2lT1CHCu97U3sB5EwF1/ai86RIfz2Ie+FuAzxjQ3VeWxDzfSJT6KK8amuh3HNf4UiM+Ay+q+oKrlwIVAONB2Lgp2SVxUODef0YcvthTw5bYCt+MY0+r9e0Meq3Yd5rbv9SUqPNTtOK7xp0Dcho8Fe1S1ErgYOLO5Q5nvumJsKl0Tonjsg02oqttxjGm1amqVP360ifTOMfz45B5ux3HVcQuEquYDQ48+F5EL67xXraqfH+8zRCRSROaISLaIFInIKhGZdIz2PxeRXBEpFJHnRKTNL44QFR7Kz77Xl9U5h/kwK9ftOMa0WgtW7mZzXjG3n92PsNC2Pczq7399kohcKyJX4VlTurHCgBzgdCABzzKir4lIWv2GInIO8Cs8K8ulAenA75qwz1Zn8sge9EmK5bEPNlFVU+t2HGNanfKqGv78780M65HA+Sd1dTuO6/y5D+I0YAtwHXA9sNn7mt9UtURV71fVnapaq6rvADuAk300vwqYo6pZqnoIeBC4ujH7a63CQkP41bkD2F5Qwj+W57gdx5hW5/klO9lXWM6vJg1ExK6/8acH0QtIxTMYHe193OtEdioiyUA/IMvH24OBNXWerwGSRaSTj8+ZKSKZIpKZn59/IpFajIkDkxid1pG/fryFkopqt+MY02ocKqlk1qKtnDUgibG9v/N10yb5MwbxAlAIvOjdjnhfaxIRCQdeAl5QVV/XbcZ693fU0cffmeZDVWeraoaqZiQmJtZ/u1USEX593gAKiit4xhYVMqbZPPHpVkoqqvnluQPcjhI0/B2DSAT+AvyNOvdENJb3xrr5QCVwSwPNioH4Os+PPi5q6n5bmxEpHTjvpC7M/nw7+4vK3Y5jTIuXc7CUeUuzmTyyB/27fOffom2WvwVC+N8NcU06MSeeE3pz8AxyT1bVqgaaZgHD6jwfBuSp6oGm7Le1uuucAVTV1PJ//7YpOIw5UY9+sJGQELj9
7H5uRwkq/haIfDz3Q9wK7G/ivp4EBgIXqGrZMdrNA2aIyCAR6QDcA8xt4j5brbTOMVwxJo1Xl+ewMfeI23GMabFWZB/inbX7mHlqeptaDMgf/lzFdCWe0zyXe7d472t+E5FU4AZgOJArIsXebbqIpHgfpwCo6gfAY8CneGaRzQbua8z+2oqfTuxDXFQ4D7/7jd08Z0wTqCoPvbuBxLhIbji97U3nfTxhfrTJ9v5vKaB1nvtNVbM59qmp2Hrt/wz8ubH7aWvaR0fw04l9efCdDSzanM+Z/ZPcjmRMi/Luun2s2nWYxyYPJSbSn6/DtsWfq5g+w3NJ6rPerZ/3NRMErhiTSlqnaB559xuq7eY5Y/xWXlXDox9sZECXOCa38Sk1GtKYMYjtqjrX+/i/RGRac4cy/osIC+FXkwayZX8xr3xt61Yb46/nl+wk52AZ95w/iNAQuynOF78KhKq+BbwhIo8C7wKISHsReRWbBsN15wxOZmx6J/78780cLq10O44xQW9/UTmPf7KF7w9KZkLfzm7HCVqNmYlqGJ5B5uUiMgNYBxwGRgQimPGfiPDbCwZRWFbFXz7e4nYcY4LeHz7YRGVNLXefN9DtKEHN7wKhqnuBH3p/ZjbwvqreoKolgQpn/DewazzTRqcw/6tstuTZPYXGNGRNzmFeX7Gbayf0Iq1zjNtxgprfBUJEhgOZwHbgIuAsEXlFRNoHKpxpnNu/34+YiFAeeGeDXfZqjA+qyu/ezqJzbCS3nNnH7ThBrzGnmP4D/FlVf+idjXUYnktf1wUkmWm0TrGR/Ox7/fhiSwEff9PU+xmNab0Wrt7Lyl2Huevc/sS10XWmG6MxBWKUqs45+sQ7hfcM4CfNH8s01RVjU+mbFMsD72RRXlXjdhxjgkZReRUPv/cNw3okMGWkXdbqj8aMQficOlRV/9V8ccyJCg8N4XcXDibnYBlPf2azvRpz1N/+s4WC4goeuGgIIXZZq1/a9np6rdS4Pp05f2hXZi3aSs7BUrfjGOO6LXlFPL9kJ5dm9GRYTxs29ZcViFbqnvMHEiLCQ+9ucDuKMa5SVe5/O4voiFDuPKe/23FaFCsQrVTXhHbcOrEPH2blsWiTDVibtuu9dbks2XqAO8/pT6fYSLfjtChWIFqx6yakk54Yw33/sgFr0zYVlVfxwDtZDO4Wz2WnpLodp8WxAtGKRYSF8NBFQ8g+UMqsT7e6HccYx/3535vZX1TBwxefZPMtNYEViFZuXJ/O/HB4N576bDvb8ovdjmOMY9bvKeSFL3cy/ZQUhtvAdJM4ViBE5BYRyRSRChGZe4x2V4tITZ1FhYpF5AyncrZGd58/iMjwEO59a73dYW3ahNpa5Z631tMxJoI7zxngdpwWy8kexF7gIeA5P9ouVdXYOtuiwEZr3RLjIrnr3AF8ue0AC1fvdTuOMQH38te7WJ1zmHvOH0RCO7tjuqkcKxCqusA7bfgBp/Zp/uey0Z5u9oPvbOBQiU0JblqvvCPlPPr+Rsb36cRFw7u5HadFC9YxiBEiUiAim0XkXhHxuRagiMz0nrbKzM/P99XEeIWGCL+ffBKFZZ7pBoxpre5bmEVlTS0P//AkRGxg+kQEY4H4HBgCJAGTgWnAnb4aqupsVc1Q1YzExEQHI7ZMA7rEc8Pp6byxYjdLtha4HceYZvdhVi4fZOVy2/f62lTezSDoCoSqblfVHapaq6rrgAeAKW7nai1uPasvaZ2i+c2b6+zeCNOqFJVXcd/CLAZ0ieP6U9PdjtMqBF2B8EEB6yc2k6jwUB750UlkHyjl/z7e7HYcY5rNox9sJK+onN9PHkp4aEv4agt+Tl7mGiYiUUAoECoiUb7GFkRkkogkex8PAO4FFjqVsy0Y17szU0f15JnPt7Mm57DbcYw5YUu3HeDFr3Zxzbheds9DM3KyzN4DlAG/Ai73Pr5HRFK89zqkeNtNBNaKSAnwHrAAeMTBnG3Cb84fSFJcFHe9sZbK6lq34xjTZKWV
1fzyn2tJ7RRtk/E1Mycvc71fVaXedr+q7vLe67DL2+4OVU1W1RhVTVfV36pqlVM524r4qHAevngIm/KKeNym4TAt2J8+2syug6X8/kdDaRcR6nacVsVO1LVhEwcmc/GI7sz6dCsb9h5xO44xjbYi+xDPLdnB5WNSGNu7k9txWh0rEG3cb38wiPbR4dzx+ho71WRalLLKGu58fQ3dEtrxq0kD3Y7TKlmBaOM6xETw8MUnsWHfEf7+yRa34xjjt8c+3Mj2ghIemzKU2Eif99KaE2QFwnDO4C78aGR3Zi3axmq7qsm0AF9uK+D5JTu5amwq4/t0djtOq2UFwgBw3wWDSYqL5BevrbYb6ExQKyqv4s7X15LWKZpfTrKZWgPJCoQBIKFdOI9OHsq2/BL+8OEmt+MY06CH3vmGfYVl/OmSYURH2KmlQLICYf7rtH6JXDEmlTmLd7B4i83VZILPB+tzeTUzhxtO783JqR3djtPqWYEw3/Kb8wbSOzGGX7y+msOlNi24CR77j5Tz6wVrGdI9np9/r5/bcdoEKxDmW9pFhPLXqSM4WFLJb95cZyvQmaBQW6vc8cZayqpq+MulI4gIs68uJ9hRNt8xpHsCt3+/P++ty+WNFbvdjmMMLyzdyeeb87n7/EH0SYp1O06bYQXC+DTztHRO6dWR+/6Vxfb8YrfjmDZsw94j/L/3N3LWgCQuPyXl+D9gmo0VCONTaIjwl6nDiQgL4dZXVlFRbZe+GueVVlZzyysrad8unD9MGWorxDnMCoRpUNeEdvxhyjCy9h7h9+9vdDuOaYPuW5jFjoIS/jJ1OJ1iI92O0+ZYgTDH9P1ByVw9Lo3nl+zk4w15bscxbcjC1Xt4fcVubjmzD+N6293SbnBywaBbRCRTRCpEZO5x2v5cRHJFpFBEnhMR+6eDi3593gAGdY3njjfWsPtQqdtxTBuwLb+Y3yxYR0ZqB26b2NftOG2Wkz2IvcBDwHPHaiQi5+BZVGgikAakA78LdDjTsMiwUJ6YPpLqGuWWl1fZrK8moMoqa7j5xZVEhIXwt2kjCLPlQ13j5IJBC1T1LeDAcZpeBcxR1SxVPQQ8CFwd6Hzm2Hp1juEPU4ayOucwj7z3jdtxTCt278L1bN5fxF+mjqBb+3Zux2nTgrE0DwbW1Hm+BkgWEVsNxGWTTurKNePTmPvlTt5du8/tOKYVei0zhzdW7ObWM/twer9Et+O0ecFYIGKBwjrPjz6Oq99QRGZ6xzUy8/PzHQnX1v160kBGpLTnrjfWsHV/kdtxTCuyfk8h9761nnG9O3GbTaURFIKxQBQD8XWeH338nW8jVZ2tqhmqmpGYaP/acEJEWAizpo+kXUQoM+evoKjclgs3J+5gSSU3zF9Bx5gI/jZtBKEhdr9DMAjGApEFDKvzfBiQp6rHG7swDuma0I7HLxtJ9oFSbn9tDbW1Nl+TabrqmlpufWUl+cUVPHX5yXS2+x2ChpOXuYaJSBQQCoSKSJSI+JrMfR4wQ0QGiUgH4B5grlM5jX/GpHfi7vMG8u8NeTz+6Va345gW7A8fbgGUCaEAABIkSURBVGLJ1gM89MMhDOvZ3u04pg4nexD3AGV4LmG93Pv4HhFJEZFiEUkBUNUPgMeAT4Fs73afgzmNn64Zn8bFI7rzfx9v5qOsXLfjmBborVV7ePrz7Vw+JoVLMnq6HcfUI61lOueMjAzNzMx0O0abU15Vw6VPL2XL/mL+edM4BnaNP/4PGQOs2nWIS2d/xciU9syfcQrhdr+DK0Rkhapm+HrP/h8xJyQqPJTZV2YQFxXGdS9kUlBc4XYk0wLsKyxj5vwVdImP4snpJ1txCFL2/4o5YcnxUTxzZQYHSiq4cf4Km/nVHFNpZTXXz8ukrLKGZ6/KoENMhNuRTAOsQJhmMbRHe/7442FkZh/irjfW2kp0xqeaWuWnr6xmw94j/G3acPolf+f2JhNEfF1FZEyT/GBoN3YdLOWxDzaR0jGaX5zd3+1IJsg8
+M4GPv4mjwcuGsxZA5LdjmOOwwqEaVY3nd6bXQdK+fsnW+nZIZpLRtmVKcbjucU7mPvlTmZM6MWVY9PcjmP8YAXCNCsR4cEfDmHP4TJ+8+Y6kuIjOaN/ktuxjMveX7ePB9/dwDmDk/nNeQPdjmP8ZGMQptmFh3qm4+jfJY6bXlzJql2H3I5kXPTltgJu+8dqRqZ04C+X2jQaLYkVCBMQcVHhzL1mNEnxkVw7dzlb9xe7Hcm4YP2eQmbOW0Fqp2jmXJVBu4hQtyOZRrACYQImMS6SedeOJjREuOq5r9l7uMztSMZB2QdKuPr55cRHhTFvxmjaR9vlrC2NFQgTUKmdYph7zWiOlFVx+bPLyC+yG+nagr2Hy7jsmWXU1NbywrWj6ZpgC/+0RFYgTMAN6Z7A89eMYl9hOVfMWcbh0kq3I5kA2l9UzvRnl3GkrIr5M06hr93r0GJZgTCOyEjryDNXZrA9v4Srnvva1pFopQ6VVHLlnK/JLSxn7rWjGNI9we1I5gRYgTCOmdC3M7OmjyRr7xGutCLR6hwqqWT6s8vYXlDCs1dlcHJqR7cjmRNkBcI46nuDknn8spGs213Ilc99zRErEq3CwZJKLnt2GVvzi3n2ygzG9+nsdiTTDKxAGMedO6QLT0z3Fok5ViRauoNHew7e4nBaP1v+t7VwckW5jiLypoiUiEi2iFzWQLv7RaTKu4jQ0S3dqZzGGecM7uI93VTIZc98xQGbJrxFyjtSzqVPL2V7fjHPWHFodZzsQTwBVALJwHTgSREZ3EDbV1U1ts623bGUxjFnD+7C7Csz2JJXzKWzvyLvSLnbkUwj5Bws5cdPLWXv4TLmXjPaikMr5EiBEJEYYDJwr6oWq+pi4F/AFU7s3wSvM/sn8cK1o8ktLGfKU1+y60Cp25GMH7buL+LHTy2lsKyKl64fw9jendyOZALAqR5EP6BGVTfXeW0N0FAP4gIROSgiWSJyU0MfKiIzRSRTRDLz8/ObM69x0Jj0Trx03SkUlVfzoyeXsG53oduRzDEs33mQyU8upbpWefWGMQzv2d7tSCZAnCoQsUD9v/pCwNcdNK8BA4FE4HrgtyIyzdeHqupsVc1Q1YzEROvetmTDerbnjRvHERkWyqWzl/LZZiv4weiD9blc/uwyOsVE8ObN4xjQxdYgb82cKhDFQP3fpHigqH5DVd2gqntVtUZVvwT+CkxxIKNxWZ+kWBbcPI7UTjHMmLuc15bnuB3JeKkqc5fs4KaXVjCoWzxv3DSOnh2j3Y5lAsypArEZCBORvnVeGwZk+fGzCtj8wG1EcnwUr93gOad91z/X8sh731BTa8uXuqmqppZ7F67n/rc3MHFAMi9fN4aOto50m+BIgVDVEmAB8ICIxIjIeOAiYH79tiJykYh0EI/RwE+BhU7kNMEhLiqc564exRVjUpn9+XZumL+C4opqt2O1SYVlVVw7dzkvfrWLG05L5+krTrYpu9sQJy9zvRloB+wHXgFuUtUsETlVROouFjAV2Irn9NM84FFVfcHBnCYIhIeG8OAPh/C7CwfzycY8fjRrCdvzbU0JJ23KLeKixxfz1fYDPDZlKL8+b6At9tPGiGrr6L5nZGRoZmam2zFMACzeUsCtr6ykukb586XD+f4gW+w+0N5es5e73lhLbFQYs6aPZFSazavUWonIClXN8PWeTbVhgt6Evp15+9YJpHaO5vp5mfzhw41U19S6HatVqqiu4XdvZ3HrK6sY3C2ed2+dYMWhDbMCYVqEHh2ieePGcVya0ZMnPt3G1NlfscdWqGtWOwtKmPLkUp5fspOrx6Xx8vVjSIqPcjuWcZEVCNNiRIWH8uiUofx16nA25hZx3l+/4IP1+9yO1eKpKm+t2sMP/r6YXQdLefqKk7n/wsFEhNnXQ1tnvwGmxbloeHfe/ekEUjtFc+OLK/n5q6spLLMZYZviQHEFN7+0kp+9upoBXeJ477ZTOWdwF7djmSAR5nYAY5oitVMM/7xpHI9/
spXHP93K0m0HeHTKUE63CeP89mFWLne/uY4jZdX88twBzDwt3a5SMt9iPQjTYoWHhvDz7/fjzZvHERsVxlXPfc1t/1hFgU0dfky5heXcOH8FN8xfQWJcFP+6dTw3ndHbioP5DrvM1bQK5VU1zFq0jScXbSU6IoxfTxrAJRk9CbEvvf+qrqnlpWW7+MOHm6iqqeW27/Xl+lPTCQ+1fye2Zce6zNUKhGlVtuQV8Zs317F85yFO6p7AfRcMIsMu02TJ1gIeeHsDm/KKmNCnMw9fPITUTjFuxzJBwAqEaVNUlYWr9/L79zeSe6ScHwztyh1n9yetc9v7QtySV8QfPtzERxvy6NGhHXefN5Bzh3RBxHpWxsMKhGmTSiureWrRNp75YgeVNbVcktGT2yb2pUtC67+2P+dgKX/5eAtvrtpNdEQYN56eznWnphMVbvMomW+zAmHatP1F5TzxyVZe/noXIsKUk3tww2nprfIUy9b9xTz12TbeWrWHkBDhqrGp3HRGH5t91TTICoQxeP5VPWvRNv65YjfVtbWcd1JXrhnfi5Ep7Vv0KRdVZdmOg8xdspMPN+QSGRbC1FEp3HB6Ol0T2rkdzwQ5KxDG1LH/SDlzFu/g5WW7KKqoZnC3eK4cm8r5Q7sRG9lybg06Ul7Fv1bvZf7SbDblFZHQLpwrxqRyzfg0OsVGuh3PtBBWIIzxoaSimjdX7WHe0p1sziumXXgo5w7pwsUjujO2d6egvPyzsrqWxVvzWbByDx9tyKOyupZBXeO5elwaFwzrZms1mEazAmHMMagqK7IPsWDVHt5Zs5cj5dUktAtn4sAkzh7UhfF9OhEXFe5avsLSKhZvLeDDrFw+3bifoopqOkSHc+Gwblw8sgfDeiS06FNkxl1BUSBEpCMwBzgbKAB+raov+2gnwO+B67wvzQF+qccJagXCNIfyqhoWbcrnow25/Oeb/RSWVREaIgzrkcD4Pp0ZmdqBYT3aB3TQt6C4gjU5h1mRfYglWwtYt6eQWoWOMRF8z1u0TuuXaJPpmWZxrALh5AnXJ4BKIBkYDrwrImtUtf661DOBH+JZs1qBfwPbgacczGraqCjvaaZzh3ShqqaWzJ2eL+kl2wqYtWjbf9fH7tmxHf2T4+mTFEvvxBh6dIimS0IUXeKj/DrNU1JRTe6RcvIKy9l9qIxt+cVsyy/mm31F/53GPCxEGJHSnlvP6suEvp0Z0bM9YUF42su0Xo70IEQkBjgEDFHVzd7X5gN7VPVX9dp+CcxV1dne5zOA61V1zLH2YT0IE2jFFdWs31PImpzDrNl9mC15xew8UEJVzbf/hiLDQoiLCiMmMowI7xe64hk/KKmopqiimsrqby94FBEWQnrnGPokxTKsR3uGp7RncLd4oiNazqC5aZmCoQfRD6g5Why81gCn+2g72Pte3XaDfX2oiMzE0+MgJSWleZIa04DYyDDGpHdiTHqn/75WXVNLzqEy9h4uI7ewnNwj5Rwpq6Koopri8mqqa/9XCMJDQ4iNDCM2Koz27SLokhBJcnwU3du3o0eHaJsszwQdpwpELFBY77VCIM6PtoVArIhI/XEIby9jNnh6EM0X1xj/hIWG0KtzDL3a4DQepvVz6oRmMRBf77V4oMiPtvFA8fEGqY0xxjQvpwrEZiBMRPrWeW0YUH+AGu9rw/xoZ4wxJoAcKRCqWgIsAB4QkRgRGQ9cBMz30XwecLuIdBeRbsAvgLlO5DTGGPM/Tl4zdzPQDtgPvALcpKpZInKqiBTXafc08DawDlgPvOt9zRhjjIMcu4ZOVQ/iub+h/utf4BmYPvpcgbu8mzHGGJfYXTfGGGN8sgJhjDHGJysQxhhjfGo1s7mKSD6Q3cQf74xnAsFgY7kax3I1XrBms1yNcyK5UlU10dcbraZAnAgRyWxoLhI3Wa7GsVyNF6zZLFfjBCqXnWIyxhjjkxUIY4wxPlmB8JjtdoAGWK7GsVyNF6zZLFfjBCSXjUEYY4zxyXoQxhhj
fLICYYwxxicrEMYYY3xqcwVCRCJFZI6IZItIkYisEpFJx/mZn4tIrogUishzIhIZoGy3iEimiFSIyNzjtL1aRGpEpLjOdobbubztnTpeHUXkTREp8f7/edkx2t4vIlX1jle601nE41EROeDdHhORgK012ohcAT0+9fbVmN9zR36XGpvN4b+/Rn1nNecxa3MFAs8Mtjl41sNOAO4FXhORNF+NReQc4FfARCANSAd+F6Bse4GHgOf8bL9UVWPrbIvczuXw8XoCqASSgenAkyLic/1yr1frHa/tLmSZiWdW42HAUOAHwA3NmKOpuSCwx6cuv36fHP5dalQ2L6f+/vz+zmr2Y6aqbX4D1gKTG3jvZeCROs8nArkBzvMQMPc4ba4GFjt8nPzJ5cjxAmLwfPH1q/PafOD3DbS/H3gxQMfF7yzAl8DMOs9nAF8FQa6AHZ+m/j658bfXiGyO//3V27/P76zmPmZtsQfxLSKSDPSj4WVNBwNr6jxfAySLSKdAZ/PDCBEpEJHNInKviDi2vscxOHW8+gE1qrq53r6O1YO4QEQOikiWiNzkUhZfx+dYmZ3KBYE7Pk0VzH974NLf33G+s5r1mLXpAiEi4cBLwAuqurGBZrFAYZ3nRx/HBTKbHz4HhgBJwGRgGnCnq4k8nDpe9fdzdF8N7ec1YCCQCFwP/FZEprmQxdfxiQ3QOERjcgXy+DRVsP7tgUt/f358ZzXrMWt1BUJEFomINrAtrtMuBE93uxK45RgfWQzE13l+9HFRIHL5S1W3q+oOVa1V1XXAA8CUxn5Oc+fCueNVfz9H9+VzP6q6QVX3qmqNqn4J/JUmHK8GNCaLr+NTrN7zAc3M71wBPj5N1Sy/S4HQXH9/jeHnd1azHrNWVyBU9QxVlQa2CeC5kgSYg2fgbrKqVh3jI7PwDCgeNQzIU9UDzZ3rBCnQ6H+FBiCXU8drMxAmIn3r7auhU4Xf2QVNOF4NaEwWX8fH38yBzFVfcx6fpmqW3yWHBPR4NeI7q1mPWasrEH56Ek93+gJVLTtO23nADBEZJCIdgHuAuYEIJSJhIhIFhAKhIhLV0HlNEZnkPReJiAzAc2XDQrdz4dDxUtUSYAHwgIjEiMh44CI8/8Ly9d9wkYh0EI/RwE9ppuPVyCzzgNtFpLuIdAN+QYB+nxqTK5DHx8e+/P19cuxvr7HZnPz78/L3O6t5j5lbo/BubUAqnmpfjqc7dnSb7n0/xfs8pc7P3A7kAUeA54HIAGW735ut7na/r1zAH72ZSoDteLq44W7ncvh4dQTe8h6DXcBldd47Fc+pm6PPXwEOeLNuBH7qRBYfOQR4DDjo3R7DOyeak8fI6ePjz++Tm79Ljc3m8N9fg99ZgT5mNlmfMcYYn9rqKSZjjDHHYQXCGGOMT1YgjDHG+GQFwhhjjE9WIIwxxvhkBcIYY4xPViCMMcb4ZAXCGGOMT1YgjDHG+GQFwpgAEJHe3rUVRnqfd/OuHXCGy9GM8ZtNtWFMgIjI9XjmxTkZeBNYp6p3uJvKGP9ZgTAmgETkX0AvPJOtjVLVCpcjGeM3O8VkTGA9g2flsb9bcTAtjfUgjAkQEYnFsybwp8Ak4CRVPehuKmP8ZwXCmAARkTlAnKpeIiKzgfaqeonbuYzxl51iMiYAROQi4FzgRu9LtwMjRWS6e6mMaRzrQRhjjPHJehDGGGN8sgJhjDHGJysQxhhjfLICYYwxxicrEMYYY3yyAmGMMcYnKxDGGGN8sgJhjDHGp/8PWA5XvPpRWP4AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_function(f, 'x', 'x**2')\n", "plt.scatter(-1.5, f(-1.5), color='red');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we look to see what would happen if we increased or decreased our parameter by a little bit—the *adjustment*. This is simply the slope at a particular point:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\"A" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can change our weight by a little in the direction of the slope, calculate our loss and adjustment again, and repeat this a few times. Eventually, we will get to the lowest point on our curve:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\"An" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This basic idea goes all the way back to Isaac Newton, who pointed out that we can optimize arbitrary functions in this way. Regardless of how complicated our functions become, this basic approach of gradient descent will not significantly change. The only minor changes we will see later in this book are some handy ways we can make it faster, by finding better steps." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Calculating Gradients" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The one magic step is the bit where we calculate the gradients. As we mentioned, we use calculus as a performance optimization; it allows us to more quickly calculate whether our loss will go up or down when we adjust our parameters up or down. In other words, the gradients will tell us how much we have to change each weight to make our model better.\n", "\n", "You may remember from your high school calculus class that the *derivative* of a function tells you how much a change in its parameters will change its result. If not, don't worry, lots of us forget calculus once high school is behind us! 
But you will have to have some intuitive understanding of what a derivative is before you continue, so if this is all very fuzzy in your head, head over to Khan Academy and complete the [lessons on basic derivatives](https://www.khanacademy.org/math/differential-calculus/dc-diff-intro). You won't have to know how to calculate them yourselves, you just have to know what a derivative is.\n", "\n", "The key point about a derivative is this: for any function, such as the quadratic function we saw in the previous section, we can calculate its derivative. The derivative is another function. It calculates the change, rather than the value. For instance, the derivative of the quadratic function at the value 3 tells us how rapidly the function changes at the value 3. More specifically, you may recall that gradient is defined as *rise/run*, that is, the change in the value of the function, divided by the change in the value of the parameter. When we know how our function will change, then we know what we need to do to make it smaller. This is the key to machine learning: having a way to change the parameters of a function to make it smaller. Calculus provides us with a computational shortcut, the derivative, which lets us directly calculate the gradients of our functions." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One important thing to be aware of is that our function has lots of weights that we need to adjust, so when we calculate the derivative we won't get back one number, but lots of them—a gradient for every weight. But there is nothing mathematically tricky here; you can calculate the derivative with respect to one weight, and treat all the other ones as constant, then repeat that for each other weight. This is how all of the gradients are calculated, for every weight.\n", "\n", "We mentioned just now that you won't have to calculate any gradients yourself. How can that be? 
Amazingly enough, PyTorch is able to automatically compute the derivative of nearly any function! What's more, it does it very fast. Most of the time, it will be at least as fast as any derivative function that you can create by hand. Let's see an example.\n", "\n", "First, let's pick a tensor value at which we want gradients:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xt = tensor(3.).requires_grad_()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice the special method `requires_grad_`? That's the magical incantation we use to tell PyTorch that we want to calculate gradients with respect to that variable at that value. It essentially tags the variable, so that PyTorch keeps track of the calculations you perform on it and can compute their gradients when you ask.\n", "\n", "> a: This API might throw you off if you're coming from math or physics. In those contexts the \"gradient\" of a function is just another function (i.e., its derivative), so you might expect gradient-related APIs to give you a new function. But in deep learning, \"gradients\" usually means the _value_ of a function's derivative at a particular argument value. The PyTorch API also puts the focus on the argument, not the function you're actually computing the gradients of. It may feel backwards at first, but it's just a different perspective.\n", "\n", "Now we calculate our function with that value. 
Notice how PyTorch prints not just the value calculated, but also a note that it has a gradient function it'll be using to calculate our gradients when needed:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(9., grad_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "yt = f(xt)\n", "yt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we tell PyTorch to calculate the gradients for us:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "yt.backward()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The \"backward\" here refers to *backpropagation*, which is the name given to the process of calculating the derivative of each layer. We'll see exactly how this is done in a later chapter, when we calculate the gradients of a deep neural net from scratch. This is called the \"backward pass\" of the network, as opposed to the \"forward pass,\" which is where the activations are calculated. Life would probably be easier if `backward` was just called `calculate_grad`, but deep learning folks really do like to add jargon everywhere they can!" 
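, "\n", "\n", "As a sanity check on what `backward` computes, here is a minimal standalone sketch (an illustration, not one of this notebook's own cells) that compares the autograd result with a plain rise/run estimate. It assumes only that PyTorch is installed; `eps`, `autograd_slope`, and `numeric_slope` are names made up for this example, and 64-bit floats are an assumption chosen to keep rounding noise small:\n", "
```python
import torch

def f(x):
    # The same quadratic used in the text
    return x**2

# Autograd route: tag the input, run the forward pass, then call backward.
# float64 is an assumption made here to keep rounding noise small.
x = torch.tensor(3.0, dtype=torch.float64, requires_grad=True)
y = f(x)
y.backward()
autograd_slope = x.grad.item()

# Rise/run route: nudge the input a little each way and measure the change.
eps = 1e-3
with torch.no_grad():
    rise = f(x + eps) - f(x - eps)
numeric_slope = (rise / (2 * eps)).item()

print(autograd_slope, numeric_slope)
```
", "\n", "Both numbers should come out very close to `2*3`, which suggests that the slope autograd reports is the same rise/run slope described earlier." 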
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now view the gradients by checking the `grad` attribute of our tensor:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(6.)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xt.grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you remember your high school calculus rules, the derivative of `x**2` is `2*x`, and we have `x=3`, so the gradients should be `2*3=6`, which is what PyTorch calculated for us!\n", "\n", "Now we'll repeat the preceding steps, but with a vector argument for our function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 3., 4., 10.], requires_grad=True)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xt = tensor([3.,4.,10.]).requires_grad_()\n", "xt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And we'll add `sum` to our function so it can take a vector (i.e., a rank-1 tensor), and return a scalar (i.e., a rank-0 tensor):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(125., grad_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def f(x): return (x**2).sum()\n", "\n", "yt = f(xt)\n", "yt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our gradients are `2*xt`, as we'd expect!" 
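, "\n", "\n", "This also ties back to the earlier point that each weight's gradient is found by treating the other weights as constant. The sketch below is a standalone illustration, not one of this notebook's own cells: it nudges one element at a time and compares the resulting rise/run estimates with what autograd reports. The names `eps`, `nudge`, and `numeric` are made up for this example, and 64-bit floats are an assumption chosen to keep rounding noise small:\n", "
```python
import torch

def f(x):
    # Sum of squares: takes a rank-1 tensor and returns a rank-0 tensor
    return (x**2).sum()

# float64 is an assumption made here to keep rounding noise small.
x = torch.tensor([3.0, 4.0, 10.0], dtype=torch.float64, requires_grad=True)
f(x).backward()

# Estimate each partial derivative by nudging a single element
# while holding all the other elements fixed.
eps = 1e-4
numeric = []
with torch.no_grad():
    for i in range(len(x)):
        nudge = torch.zeros_like(x)
        nudge[i] = eps
        estimate = (f(x + nudge) - f(x - nudge)) / (2 * eps)
        numeric.append(estimate.item())

print(x.grad.tolist(), numeric)
```
", "\n", "Each estimate should land very close to the matching entry of `x.grad`, one partial derivative per element." 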
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 6., 8., 20.])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "yt.backward()\n", "xt.grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The gradients only tell us the slope of our function; they don't tell us exactly how far to adjust the parameters. But they do give us some idea of how far: if the slope is very large, then that may suggest that we have more adjustments to do, whereas if the slope is very small, that may suggest that we are close to the optimal value." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Stepping With a Learning Rate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Deciding how to change our parameters based on the values of the gradients is an important part of the deep learning process. Nearly all approaches start with the basic idea of multiplying the gradient by some small number, called the *learning rate* (LR). The learning rate is often a number between 0.001 and 0.1, although it could be anything. Often, people select a learning rate just by trying a few, and finding which results in the best model after training (we'll show you a better approach later in this book, called the *learning rate finder*). Once you've picked a learning rate, you can adjust your parameters using this simple update rule:\n", "\n", "```\n", "w -= gradient(w) * lr\n", "```\n", "\n", "This is known as *stepping* your parameters, using an *optimizer step*. Notice how we _subtract_ the `gradient * lr` from the parameter to update it. This moves the parameter downhill: it increases the parameter when the slope is negative and decreases it when the slope is positive. 
We adjust our parameters this way because our goal in deep learning is to _minimize_ the loss.\n", "\n", "If you pick a learning rate that's too low, it can mean having to do a lot of steps. The following figure illustrates that." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Figure: an illustration of gradient descent with a learning rate that is too low]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But picking a learning rate that's too high is even worse: it can actually result in the loss getting *worse*, as we see in the following figure!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Figure: an illustration of gradient descent with a learning rate that is too high]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If the learning rate is too high, it may also \"bounce\" around, rather than actually diverging; the following figure shows how this results in taking many steps to train successfully." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Figure: an illustration of gradient descent with a learning rate that makes the steps bounce around]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's apply all of this in an end-to-end example." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### An End-to-End SGD Example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We've seen how to use gradients to find a minimum. Now it's time to look at an SGD example and see how finding a minimum can be used to train a model to fit data better.\n", "\n", "Let's start with a simple, synthetic example model. Imagine you were measuring the speed of a roller coaster as it went over the top of a hump. It would start fast, and then get slower as it went up the hill; it would be slowest at the top, and it would then speed up again as it went downhill. You want to build a model of how the speed changes over time. 
If you were measuring the speed manually every second for 20 seconds, it might look something like this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "time = torch.arange(0,20).float(); time" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXMAAAD7CAYAAACYLnSTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAWy0lEQVR4nO3dfYxcV3nH8e8vtpWsbC+u48XFW9luDLGpExw3i4KIAkhJa0FLcWMkTFIISMhAlKqo1IK0OLh5UQBD/yiEF0sp5LUNhrUhpGA1SlJICikbXMdaYVt1UkPWKawhXrz2OjHu0z/mTjKezM7c8cydlzu/jzSS59wzd56czD5z5txzz1FEYGZm3e2sdgdgZmaNczI3M8sBJ3MzsxxwMjczywEnczOzHJjZjjddsGBBLF26tB1vbWbWtZ544onDETFQ6VhbkvnSpUsZGRlpx1ubmXUtSQenO+ZhFjOzHHAyNzPLASdzM7MccDI3M8sBJ3Mzsxxoy2yWM7Vj1xhbdu7j0JEpFs3rY+Oa5axdPdjusMzM2q5rkvmOXWNcP7yHqZOnABg7MsX1w3sAnNDNrOd1zTDLlp37XkzkRVMnT7Fl5742RWRm1jm6JpkfOjJVV7mZWS/pmmS+aF5fXeVmZr2ka5L5xjXL6Zs147Syvlkz2LhmeZsiMjPrHF1zAbR4kdOzWczMXq5rkjkUErqTt5nZy3XNMIuZmU3PydzMLAeczM3McsDJ3MwsB2omc0mTZY9Tkj5fcvxySXslHZf0sKQl2YZsZmblaibziJhTfAALgSlgG4CkBcAwsAmYD4wA92UXrpmZVVLvMMs7gV8CP0ieXwmMRsS2iDgBbAZWSVrRvBDNzKyWepP5NcCdERHJ85XA7uLBiDgGHEjKTyNpg6QRSSPj4+NnGq+ZmVWQOplLWgy8GbijpHgOMFFWdQKYW/76iNgaEUMRMTQwMHAmsZqZ2TTq6Zm/F3g0Ip4uKZsE+svq9QNHGw3MzMzSqzeZ31FWNgqsKj6RNBtYlpSbmVmLpErmkt4IDJLMYimxHbhA0jpJ5wA3AE9GxN7mhmlmZtWkXWjrGmA4Ik4bPomIcUnrgC8AdwOPA+ubG6KZWffLeg/jVMk8Ij5Y5diDgKcimplNoxV7GPt2fjOzjLViD2MnczOzjLViD2MnczOzjLViD2MnczOzjLViD+Ou2jbOzKwbtWIPYydzM7MWyHoPYw+zmJnlgJO5mVkOOJmbmeWAk7mZWQ44mZuZ5YCTuZlZDjiZm5nlgJO5mVkOOJmbmeWAk7mZWQ6kTuaS1kv6qaRjkg5Iuiwpv1zSXknHJT0saUl24ZqZWSWp1maR9EfAp4F3Af8JvCopXwAMAx8A7gduAu4D3pBFsI3KetsmM7N2SbvQ1t8DN0bEj5LnYwCSNgCjEbEte
b4ZOCxpRadt6tyKbZvMzNql5jCLpBnAEDAg6b8lPSPpC5L6gJXA7mLdiDgGHEjKy8+zQdKIpJHx8fHm/Rek1Iptm8zM2iXNmPlCYBbwTuAy4CJgNfAJYA4wUVZ/AphbfpKI2BoRQxExNDAw0FDQZ6IV2zaZmbVLmmRezHafj4hnI+Iw8A/A24BJoL+sfj9wtHkhNkcrtm0yM2uXmsk8Ip4DngGiwuFRYFXxiaTZwLKkvKO0YtsmM7N2SXsB9KvAX0r6HnAS+AjwHWA7sEXSOuAB4AbgyU67+Amt2bbJzPKr02fDpU3mNwELgP3ACeDrwC0RcSJJ5F8A7gYeB9ZnEWgzZL1tk5nlUzfMhkuVzCPiJHBt8ig/9iCwoslxmZl1jGqz4Tolmft2fjOzGrphNpyTuZlZDd0wG87J3Myshm6YDZf2AqiZWc/qhtlwTuZmZil0+mw4D7OYmeWAk7mZWQ44mZuZ5YCTuZlZDjiZm5nlgJO5mVkOOJmbmeWAk7mZWQ44mZuZ5YCTuZlZDjiZm5nlQKpkLukRSSckTSaPfSXHrpJ0UNIxSTskzc8uXDMzq6Senvl1ETEneSwHkLQS+ArwHmAhcBz4YvPDNDOzahpdNfFq4P6I+D6ApE3ATyXNjYijDUdnZmap1NMzv1XSYUmPSXpLUrYS2F2sEBEHgBeA88tfLGmDpBFJI+Pj443EbGZmZdIm848B5wGDwFbgfknLgDnARFndCWBu+QkiYmtEDEXE0MDAQAMhm5lZuVTJPCIej4ijEfF8RNwBPAa8DZgE+suq9wMeYjEza6EznZoYgIBRYFWxUNJ5wNnA/sZDMzOztGpeAJU0D7gE+Hfgt8C7gDcBH0le/0NJlwE/AW4Ehn3x08ystdLMZpkF3AysAE4Be4G1EbEPQNKHgHuAc4EHgfdnE6qZmU2nZjKPiHHg9VWO3wvc28ygzMysPr6d38wsBxq9aain7Ng1xpad+zh0ZIpF8/rYuGY5a1cPtjssMzMn87R27Brj+uE9TJ08BcDYkSmuH94D4IRuZm3nYZaUtuzc92IiL5o6eYotO/dN8wozs9ZxMk/p0JGpusrNzFrJyTylRfP66io3M2slJ/OUNq5ZTt+sGaeV9c2awcY1y9sUkZnZS3wBNKXiRU7PZjGzTuRkXoe1qwedvM2sI3mYxcwsB5zMzcxywMMsZtYT8n4Ht5O5meVeL9zB7WEWM8u9XriD28nczHKvF+7gdjI3s9zrhTu4nczNLPd64Q7uupK5pNdIOiHp7pKyqyQdlHRM0g5J85sfppnZmVu7epBbr7yQwXl9CBic18etV16Ym4ufUP9sltuAHxefSFoJfAX4EwobOm8Fvgisb1aAZmbNkPc7uFMnc0nrgSPAfwCvToqvBu6PiO8ndTYBP5U0NyKONjtYMzOrLNUwi6R+4Ebgo2WHVgK7i08i4gDwAnB+hXNskDQiaWR8fPzMIzYzs5dJO2Z+E3B7RPy8rHwOMFFWNgHMLT9BRGyNiKGIGBoYGKg/UjMzm1bNYRZJFwFXAKsrHJ4E+svK+gEPsZiZtVCaMfO3AEuBn0mCQm98hqQ/AL4HrCpWlHQecDawv9mBmpnZ9NIk863Av5Q8/xsKyf3DwCuBH0q6jMJslhuBYV/8NDNrrZrJPCKOA8eLzyVNAiciYhwYl/Qh4B7gXOBB4P0ZxWpmZtOoe9XEiNhc9vxe4N5mBWRmZvXz7fxmZjngZG5mlgNO5mZmOeBkbmaWA07mZmY54GRuZpYDTuZmZjngZG5mlgNO5mZmOeBkbmaWA07mZmY5UPfaLGZm7bBj1xhbdu7j0JEpFs3rY+Oa5bne07NeTuZm1vF27Brj+uE9TJ08BcDYkSmuH94D4ISe8DCLmXW8LTv3vZjIi6ZOnmLLzn1tiqjzOJmbWcc7dGSqrvJe5GRuZh1v0by+usp7UapkLuluSc9K+o2k/ZI+UHLsckl7JR2X9LCkJdmFa2a9aOOa5fTNmnFaWd+sG
Wxcs7xNEXWetD3zW4GlEdEP/Blws6SLJS0AhoFNwHxgBLgvk0jNrGetXT3IrVdeyOC8PgQMzuvj1isv9MXPEqlms0TEaOnT5LEMuBgYjYhtAJI2A4clrYiIvU2O1cx62NrVg07eVaQeM5f0RUnHgb3As8C/AiuB3cU6EXEMOJCUl79+g6QRSSPj4+MNB25mZi9Jncwj4lpgLnAZhaGV54E5wERZ1YmkXvnrt0bEUEQMDQwMnHnEZmb2MnXNZomIUxHxKPB7wIeBSaC/rFo/cLQ54ZmZWRpnOjVxJoUx81FgVbFQ0uyScjMza5GayVzSKyWtlzRH0gxJa4B3Aw8B24ELJK2TdA5wA/CkL36ambVWmp55UBhSeQZ4Dvgs8JGI+FZEjAPrgFuSY5cA6zOK1czMplFzamKSsN9c5fiDwIpmBpVXXvXNzLLiVRNbxKu+Wa9zZyZbXpulRbzqm/WyYmdm7MgUwUudmR27xtodWm44mbeIV32zXubOTPaczFvEq75ZL3NnJntO5i3iVd+sl7kzkz0n8xbxqm/Wy9yZyZ5ns7SQV32zXlX83Hs2S3aczM2sJdyZyZaHWczMcsDJ3MwsB5zMzcxywMnczCwHfAG0i3htCzObjpN5l/BCXWZWjYdZuoTXtjCzapzMu4TXtjCzatJsG3e2pNslHZR0VNIuSW8tOX65pL2Sjkt6WNKSbEPuTV7bwsyqSdMznwn8nMJuQ68ANgFfl7RU0gJgOCmbD4wA92UUa0/z2hZmVk2abeOOAZtLir4j6WngYuBcYDQitgFI2gwclrTCmzo3VzPWtvBsGLP8qns2i6SFwPnAKIWNnncXj0XEMUkHgJXA3rLXbQA2ACxevLiBkHtXI2tbeDaMWb7VdQFU0izgHuCOpOc9B5goqzYBzC1/bURsjYihiBgaGBg403jtDHk2jFm+pU7mks4C7gJeAK5LiieB/rKq/cDRpkRnTePZMGb5liqZSxJwO7AQWBcRJ5NDo8CqknqzgWVJuXUQz4Yxy7e0PfMvAa8F3h4RpV257cAFktZJOge4AXjSFz87j2fDmOVbmnnmS4APAhcB/ytpMnlcHRHjwDrgFuA54BJgfZYB25nxtnVm+aaIaPmbDg0NxcjISMvf18ysm0l6IiKGKh3z7fxmZjngZG5mlgNeAtfMUvEdxJ3NydzMavIdxJ3PwyxmVpPvIO58TuZmVpPvIO58TuZmVpPvIO58TuZmVpPvIO58vgBqZjU1Yz19y5aTuZml0sh6+pY9J3NLzfOMzTqXk7ml4nnGZp3NF0AtFc8zNutsTuaWiucZm3U2D7NYKovm9TFWIXHXM8/YY+5m2XHP3FJpdJ5xccx97MgUwUtj7jt2jWUQrVnvSbsH6HWSRiQ9L+lrZccul7RX0nFJDyc7E1nONLpTkcfc22/HrjEu/dRD/P7HH+DSTz3kL9KcSTvMcgi4GVgDvPi7WtICYBj4AHA/cBNwH/CG5oZpnaCRecYec28vz0bKv1Q984gYjogdwK/KDl0JjEbEtog4AWwGVkla0dwwrdt5bY/28i+j/Gt0zHwlsLv4JCKOAQeS8tNI2pAM1YyMj483+LbWbby2R3v5l1H+NZrM5wATZWUTwNzyihGxNSKGImJoYGCgwbe1btPomLs1xr+M8q/RqYmTQH9ZWT9wtMHzWg55bY/22bhm+Wlj5uBfRnnTaM98FFhVfCJpNrAsKTezDuFfRvmXqmcuaWZSdwYwQ9I5wG+B7cAWSeuAB4AbgCcjYm9G8ZrZGfIvo3xL2zP/BDAFfBz4i+Tfn4iIcWAdcAvwHHAJsD6DOM3MrIpUPfOI2Exh2mGlYw8CnopoZtZGvp3fzCwHnMzNzHLAydzMLAe8BK5Zl/ASwlaNk7lZF/BCWVaLh1nMuoAXyrJanMzNuoAXyrJaPMxiXaOXx4ybsW2f5Zt75tYV8rDtXCM7/XgJYavFydy6QrePGTf6ZeSFsqwWD7NYV+j2MeNqX0ZpE7IXy
rJq3DO3rtDtmyt0+5eRdT4nc+sK3T5m3O1fRtb5nMytK3T7mHG3fxlZ5/OYuXWNbh4zLsbdq1MrLXtO5mYt0s1fRtb5mjLMImm+pO2Sjkk6KOmqZpzXzMzSaVbP/DbgBWAhcBHwgKTdEeGNnS03evkOVOt8DffMJc2msA/opoiYjIhHgW8D72n03GadIg93oFq+NWOY5XzgVETsLynbDaxswrnNmqaR2+m7/Q5Uy79mDLPMASbKyiaAuaUFkjYAGwAWL17chLc1S6/R9cB90491umb0zCeB/rKyfuBoaUFEbI2IoYgYGhgYaMLbmqXXaM/aN/1Yp2tGMt8PzJT0mpKyVYAvflrHaLRn7Zt+rNM1nMwj4hgwDNwoabakS4F3AHc1em6zZmm0Z93td6Ba/jVrauK1wD8BvwR+BXzY0xKtk2xcs/y0MXOov2ftm36skzUlmUfEr4G1zTiXWRZ8O73lnW/nt57hnrXlmVdNNDPLASdzM7MccDI3M8sBJ3MzsxxwMjczywFFROvfVBoHDjZwigXA4SaFkwXH1xjH1xjH15hOjm9JRFRcD6UtybxRkkYiYqjdcUzH8TXG8TXG8TWm0+ObjodZzMxywMnczCwHujWZb213ADU4vsY4vsY4vsZ0enwVdeWYuZmZna5be+ZmZlbCydzMLAeczM3McqAjk7mk+ZK2Szom6aCkq6apJ0mflvSr5PEZSco4trMl3Z7EdVTSLklvnabu+ySdkjRZ8nhLlvEl7/uIpBMl71lxo8s2td9k2eOUpM9PU7cl7SfpOkkjkp6X9LWyY5dL2ivpuKSHJS2pcp6lSZ3jyWuuyDI+SW+Q9G+Sfi1pXNI2Sa+qcp5Un4smxrdUUpT9/9tU5Tytbr+ry2I7nsR78TTnyaT9mqUjkzlwG/ACsBC4GviSpJUV6m2gsCnGKuB1wJ8CH8w4tpnAz4E3A68ANgFfl7R0mvo/jIg5JY9HMo6v6LqS95xuO52Wt19pW1D4/zsFbKvykla03yHgZgq7Zb1I0gIKWyJuAuYDI8B9Vc7zz8Au4Fzg74BvSGrG7uUV4wN+h8LMi6XAEgqbqH+1xrnSfC6aFV/RvJL3vKnKeVrafhFxT9nn8VrgKeAnVc6VRfs1Rcclc0mzgXXApoiYjIhHgW8D76lQ/RrgcxHxTESMAZ8D3pdlfBFxLCI2R8T/RMT/RcR3gKeBit/mHa7l7VfmnRS2GvxBC9/zZSJiOCJ2UNjysNSVwGhEbIuIE8BmYJWkFeXnkHQ+8IfAJyNiKiK+Ceyh8FnOJL6I+G4S228i4jjwBeDSRt+vWfHVox3tV8E1wJ3RpVP8Oi6ZA+cDpyJif0nZbqBSz3xlcqxWvcxIWkgh5un2PF0t6bCk/ZI2SWrV7k63Ju/7WJWhiXa3X5o/nna1H5S1T7J5+QGm/yw+FRFHS8pa3Z5vYvrPYVGaz0WzHZT0jKSvJr92Kmlr+yXDZ28C7qxRtR3tl0onJvM5wERZ2QQwN0XdCWBO1uO+RZJmAfcAd0TE3gpVvg9cALySQg/j3cDGFoT2MeA8YJDCz/D7JS2rUK9t7SdpMYWhqjuqVGtX+xU18lmsVrfpJL0OuIHq7ZP2c9Esh4HXUxgCuphCW9wzTd22th/wXuAHEfF0lTqtbr+6dGIynwT6y8r6KYwH1qrbD0y24meSpLOAuyiM7V9XqU5EPBURTyfDMXuAGykMLWQqIh6PiKMR8XxE3AE8BrytQtW2tR+FP55Hq/3xtKv9SjTyWaxWt6kkvRr4LvBXETHtkFUdn4umSIZJRyLitxHxCwp/J38sqbydoI3tl3gv1TsWLW+/enViMt8PzJT0mpKyVVT++TiaHKtVr6mSnuvtFC7grYuIkylfGkBLfjWkfN+2tF+i5h9PBa1uv9PaJ7mes4zpP4vnSSrtSWbensnwwIPATRFxV50vb3V7FjsJ030WW95+AJIuBRYB36jzpe36e66o45J5Mi45DNwoaXbS0O+g0Asud
yfw15IGJS0CPgp8rQVhfgl4LfD2iJiarpKktyZj6iQXzTYB38oyMEnzJK2RdI6kmZKupjAWuLNC9ba0n6Q3UvipWm0WS8vaL2mnc4AZwIxi2wHbgQskrUuO3wA8WWlILbnG81/AJ5PX/zmFGULfzCo+SYPAQ8BtEfHlGueo53PRrPgukbRc0lmSzgX+EXgkIsqHU9rSfiVVrgG+WTZeX36OzNqvaSKi4x4UpoHtAI4BPwOuSsovozAMUKwn4DPAr5PHZ0jWm8kwtiUUvpFPUPhpWHxcDSxO/r04qftZ4BfJf8dTFIYJZmUc3wDwYwo/T48APwL+qFPaL3nfrwB3VShvS/tRmKUSZY/NybErgL0UplA+Aiwted2XgS+XPF+a1JkC9gFXZBkf8Mnk36Wfw9L/v38LfLfW5yLD+N5NYabXMeBZCp2H3+2U9kuOnZO0x+UVXteS9mvWwwttmZnlQMcNs5iZWf2czM3McsDJ3MwsB5zMzcxywMnczCwHnMzNzHLAydzMLAeczM3McuD/AdndnL7Vn+NhAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "speed = torch.randn(20)*3 + 0.75*(time-9.5)**2 + 1\n", "plt.scatter(time,speed);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We've added a bit of random noise, since measuring things manually isn't precise. This means it's not that easy to answer the question: what was the roller coaster's speed? Using SGD we can try to find a function that matches our observations. We can't consider every possible function, so let's use a guess that it will be quadratic; i.e., a function of the form `a*(time**2)+(b*time)+c`.\n", "\n", "We want to distinguish clearly between the function's input (the time when we are measuring the coaster's speed) and its parameters (the values that define *which* quadratic we're trying). So, let's collect the parameters in one argument and thus separate the input, `t`, and the parameters, `params`, in the function's signature: " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def f(t, params):\n", " a,b,c = params\n", " return a*(t**2) + (b*t) + c" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In other words, we've restricted the problem of finding the best imaginable function that fits the data, to finding the best *quadratic* function. This greatly simplifies the problem, since every quadratic function is fully defined by the three parameters `a`, `b`, and `c`. Thus, to find the best quadratic function, we only need to find the best values for `a`, `b`, and `c`.\n", "\n", "If we can solve this problem for the three parameters of a quadratic function, we'll be able to apply the same approach for other, more complex functions with more parameters—such as a neural net. 
Let's find the parameters for `f` first, and then we'll come back and do the same thing for the MNIST dataset with a neural net.\n", "\n", "First, we need to define what we mean by \"best.\" We define this precisely by choosing a *loss function*, which will return a value based on a prediction and a target, where lower values of the function correspond to \"better\" predictions. It is important for loss functions to return _lower_ values when predictions are more accurate, as the SGD procedure we defined earlier will try to _minimize_ this loss. For continuous data, it's common to use *mean squared error*:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def mse(preds, targets): return ((preds-targets)**2).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's work through our seven-step process." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Step 1: Initialize the parameters" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we initialize the parameters to random values, and tell PyTorch that we want to track their gradients, using `requires_grad_`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params = torch.randn(3).requires_grad_()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "orig_params = params.clone()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Step 2: Calculate the predictions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we calculate the predictions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "preds = f(time, params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's create a little function to see how close our predictions are to our targets, and take a look:" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [], "source": [ "def show_preds(preds, ax=None):\n", " if ax is None: ax=plt.subplots()[1]\n", " ax.scatter(time, speed)\n", " ax.scatter(time, to_np(preds), color='red')\n", " ax.set_ylim(-300,100)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY8AAAEACAYAAABLfPrqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAc1ElEQVR4nO3dfbBU9Z3n8fcHsYAAd8B4hyxMAaurMHV1iOtNOZuYxIxuTDKbLUf2D81NHGY2koxFaqusMjErRCtKaeLU/pHZyQOWBHWZncQUWuM86M5UxESzk9pmLExuBHeJXBKfckkQuYD4wHf/OKe1bbvvvX3POd2nuz+vqi66z+/Xp78cDv3t83s6igjMzMxaMavTAZiZWfdx8jAzs5Y5eZiZWcucPMzMrGVOHmZm1jInDzMza5mTh5mZtSzX5CFpg6SKpBOSttWVXSxpj6Rjkh6WtKKmbI6krZJekvS8pGvzjMvMzPKV95XHs8AtwNbajZJOB3YAm4DTgArwnZoqNwFnASuADwGfl/SRnGMzM7OcqIgZ5pJuAX4nItalr9cD6yLivenr+cBB4LyI2CPpGeBPIuJ/peU3A2dFxBW5B2dmZpnNbtPnDAG7qy8i4qikfcCQpBeApbXl6fPLGu0oTUTrAebPn3/+6tWrCwvazKwX7dq162BEDGbZR7uSxwJgvG7bYWBhWlZ9XV/2NhGxBdgCMDw8HJVKJd9Izcx6nKSxrPto12irCWCgbtsAcCQto668WmZmZiXUruQxCqypvkj7PM4ERiPiEPBcbXn6fLRNsZmZWYvyHqo7W9Jc4BTgFElzJc0G7gPOkbQ2Lf8S8ERE7EnfejewUdJiSauBq4FtecZmZmb5yfvKYyNwHLge+GT6fGNEjANrgc3AIeACoHYk1Y3APmAMeAS4PSIezDk2MzPLSSFDddvFHeZmZq2TtCsihrPsw8uTmJlZy5w8zMysZU4eZmbWMicPMzNrWbtmmJfS/Y8/w+0P7eXZF4+zdNE8rrt0FZedt6zTYZmZlV7fJo/7H3+GL+74CcdffR2AZ148zhd3/ATACcTMbAp922x1+0N730gcVcdffZ3bH9rboYjMzLpH3yaPZ1883tJ2MzN7U982Wy1dNI9nGiSKpYvmdSAaM7PGyto327dXHtdduop5p57ylm3zTj2F6y5d1aGIzMzeqto3+8yLxwne7Ju9//FnOh1a/yaPy85bxq2Xn8uyRfMQsGzRPG69/NxSZHQzMyh332zfNltBkkCcLMysrMrcN9vXySOrsrZFmllvKHPfbN82W2VV5rZIMyuX+x9/hvfd9n3+9fV/x/tu+/60vyfK3Dfr5DFDZW6LNLPyyPJDs8x9s262mqEyt0WaWXlM9kNzOkmgrH2zvvKYoWZtjmVoizSz8ujVH5ptTR6Sdkp6WdJE+thbU/YJSWOSjkq6X9Jp7YytVWVuizSz8ujVH5qduPLYEBEL0scqAElDwLeATwFLgGPA1zsQ27SVuS3SzMqjV39olqXPYwR4ICJ+ACBpE/CkpIURcaSzoTVX1rZIMyuP6ndErw3r70TyuFXSbcBe4IaI2AkMAT+qVoiIfZJeAc4GdnUgxsJ5johZ/+jFH5rtTh5fAH4GvAJcATwg6d3AAuBwXd3DwML6H
UhaD6wHWL58eaHBFsX3EjGzbtfWPo+I+HFEHImIExFxF/AY8DFgAhioqz4AvK3JKiK2RMRwRAwPDg4WH3QBPEfErLvMdJJfL+t0n0cAAkaBNdWNks4A5gBPdSiuQvXq0D2zXuSWgsbaduUhaZGkSyXNlTRb0gjwAeAhYDvwcUnvlzQf+DKwo8yd5Vn06tA9s17kloLG2tlsdSpwCzAOHAQ+B1wWEXsjYhT4LEkS+RVJX8c1bYytrXp16J5ZL3JLQWNta7aKiHHgPZOU/xXwV+2Kp5N6deieWS8q88q2ndTpPo++1YtD98zKKsvQ+OsuXfWWPg9wSwE4eZhZj8va4e2WgsacPMysp2Vd1RbcUtCIk0eX8gx16zczPefd4V0MJ48u5HHn1m+ynPPu8C6G7+fRhTzu3PpNlnPeQ+OL4SuPLpTHZbibvaybZDnn3eFdDCePLpT1MjyPZi8nH2unrOe8O7zz52arLpT1Mjxrs1c1+Tzz4nGCN5OPF4uzorjpqXx85dGFsl6GZ232ymPoo/WfLFerbnoqHyePLpXlMjxrE4CHPlqr8mgqddNTubjZqg9lbQLIuiqw743QfzxCsPf4yqMPZW0CyLLWjzvru1eW4+6r1d7j5NGnsjQBZEk+WftLPEGyM7Ied0/U6z1OHjYjM00+7qzvTlmPu1em7T1OHtZW7qzvnE42O3m0VO9x8rC2yvoLNI/mj27uM5lp7GVodvJoqd7S36Ottm+HlSth1qzkz+3bOx1Rz7vsvGXcevm5LFs0DwHLFs3j1svPbamzPstIsW6e4Jgl9qyjnTxJz+qV5spD0mnAncCHSe5x/sX01rTF2L4d1q+HY8eS12NjyWuAkZHp7+OGG+DAAVi+HDZvnv57+1inOuur78vaYZ/lqiXL+7PE7mYny1tpkgfwl8ArwBLg3cDfSdodEaOFfNoNN7yZOKqOHUu2TycBZE0+TjwzliX5ZPkSzdr0k/X9WWJ3s5PlrRTNVpLmA2uBTRExERGPAn8DfKqwDz1woLXt9SZLPlOpJp6xMYh4M/G42axwWSY4Zm36yfr+LLG72cnyVorkAZwNvB4RT9Vs2w0M1VeUtF5SRVJlfHx85p+4fHlr2+tlST5ZEo9lkuVLNGvTT9b3Z4k9a1+TWb2yJI8FwOG6bYeBhfUVI2JLRAxHxPDg4ODMP3HzZnjHO9667R3vSLZPR5bkk/WqB9zZP0NZvkSzLsuS9f1ZE8Bl5y3jsev/gKdv+0Meu/4PnDgsk7L0eUwAA3XbBoAjhX1itX9hpv0Omze/tc8Dpp98li9PmqoabZ+OPDr7+9hM2+6zDjPOY6Kc+x2sLMpy5fEUMFvSWTXb1gDFdJZXjYzA/v1w8mTyZytfvCMjsGULrFgBUvLnli3T20fWqx43e3VEHr/83XRkvUIR0ekYAJD010AAnyYZbfX3wHsnG201PDwclUqlTRHmLMtoq1mzko72elKSCIv+fDPrapJ2RcRwln2UpdkK4BpgK/Ar4NfAnxU2TLcMRkZm/mXtZi8z67CyNFsREb+JiMsiYn5ELC90gmC3c7OXmXVYaZKHtSBLfwvkM9rLzPqak0e3ytLZn3WOC3iosFmfc/LoR1mbvTxD3qzvOXn0o6zNXu4zMet7pRmqOxNdPVS3m+UxVNjMOiaPobq+8rDW5dFnYmZdzcnDWpdHn4k72826mpOHtS5Ln4k72816gvs8rL1Wrmw8O37FimTIsZkVzn0e1n08QdGsJzh5WHu5s92sJzh5WHtl7Ww3s1Jw8rD2yjpBETxay6wEyrQku/WLLMvRezl5s1LwlYd1Fy+NYlYKTh7WXTxay6wUnDysu3i0llkptCV5SNop6WVJE+ljb135JySNSToq6X5Jp7UjLutCHq1lVgrtvPLYEBEL0seq6kZJQ8C3gE8BS4BjwNfbGJd1kzxGa5lZZmVothoBHoiIH0TEBLAJuFzSwg7HZWWV5S6K4KG+ZjloZ/K4VdJBSY9Juqhm+
xCwu/oiIvYBrwBnN9qJpPWSKpIq4+PjhQZsPcgLM5rlol3J4wvAGcAyYAvwgKQz07IFwOG6+oeBhlceEbElIoYjYnhwcLCoeK1XeaivWS4yJ4+0MzyaPB4FiIgfR8SRiDgREXcBjwEfS3cxAQzU7XYAOJI1NrO38VBfs1xknmEeERfN5G2A0uejwJpqgaQzgDnAU1ljM3ub5csbLwnvob5mLSm82UrSIkmXSporabakEeADwENple3AxyW9X9J84MvAjojwlYflz0N9zXLRjj6PU4FbgHHgIPA54LKI2AsQEaPAZ0mSyK9I+jquaUNc1o881NcsF76ToFkrtm9POtcPHEiaujZvduKxrpPHnQS9qq7ZdHlFX7M3lGGSoFl38DBfszc4eZhNl4f5mr3BycNsuryir9kbnDzMpsvDfM3e4ORhNl0e5mv2Bo+2MmtFlvuvm/UQX3mYmVnLnDzM2sn3ErEe4WYrs3bxJEPrIb7yMGsXTzK0HuLkYdYunmRoPcTJw6xdPMnQeoiTh1m7eJKh9RAnD7N28SRD6yEebWXWTp5kaD3CVx5mZtayXJKHpA2SKpJOSNrWoPxiSXskHZP0sKQVNWVzJG2V9JKk5yVdm0dMZj3JkwytJPK68niW5D7lW+sLJJ0O7AA2AacBFeA7NVVuAs4CVgAfAj4v6SM5xWXWO6qTDMfGIOLNSYZOINYBuSSPiNgREfcDv25QfDkwGhH3RsTLJMlijaTVaflVwM0RcSgingTuANblEZdZT/EkQyuRdvR5DAG7qy8i4iiwDxiStBhYWluePh9qtjNJ69Mmssr4+HhBIZuVkCcZWom0I3ksAA7XbTsMLEzLqCuvljUUEVsiYjgihgcHB3MN1KzUPMnQSmTK5CFpp6Ro8nh0Gp8xAQzUbRsAjqRl1JVXy8yslicZWolMmTwi4qKIUJPHhdP4jFFgTfWFpPnAmST9IIeA52rL0+ejrf01zPqAJxlaieQ1VHe2pLnAKcApkuZKqk5AvA84R9LatM6XgCciYk9afjewUdLitBP9amBbHnGZ9ZyREdi/H06eTP504rAOyavPYyNwHLge+GT6fCNARIwDa4HNwCHgAuCKmvfeSNKBPgY8AtweEQ/mFJeZmRVAEdHpGGZseHg4KpVKp8MwM+sqknZFxHCWfXh5ErN+4dnpliMvjGjWD3wLXMuZrzzM+oFnp1vOnDzM+oFnp1vOnDzM+oFnp1vOnDzM+oFnp1vOnDzM+oFnp1vOPNrKrF/4FriWI195mJlZy5w8zMysZU4eZmbWMicPM5seL29iNdxhbmZT8/ImVsdXHmY2NS9vYnWcPMxsal7exOo4eZjZ1Ly8idVx8jCzqXl5E6uT1z3MN0iqSDohaVtd2UpJIWmi5rGppnyOpK2SXpL0vKRr84jJzHLk5U2sTl6jrZ4FbgEuBeY1qbMoIl5rsP0m4CxgBfAu4GFJP/N9zM1KxsubWI1crjwiYkdE3A/8egZvvwq4OSIORcSTwB3AujziMjOzYrSzz2NM0i8lfVvS6QCSFgNLgd019XYDQ812Iml92kRWGR8fLzZiMzNrqB3J4yDwHpJmqfOBhUB1auqC9M/DNfUPp3UaiogtETEcEcODg4MFhGtmZlOZMnlI2pl2eDd6PDrV+yNiIiIqEfFaRLwAbAA+LGkAmEirDdS8ZQA4MpO/jJmVmJc36SlTdphHxEU5f2akfyoiDkl6DlgD/GO6fQ0wmvNnmlkneXmTnpPXUN3ZkuYCpwCnSJoraXZadoGkVZJmSXon8DVgZ0RUm6ruBjZKWixpNXA1sC2PuMysJLy8Sc/Jq89jI3AcuB74ZPp8Y1p2BvAgSVPUT4ETwJU1770R2AeMAY8At3uYrlmP8fImPUcRMXWtkhoeHo5KpdLpMMxsKitXJk1V9VasgP372x1N35O0KyKGs+zDy5OYWfG8vEnPcfIws+J5eZOe45tBmVl7eHmTnuIrDzMza5mTh5mZtczJw8zMW
ubkYWZmLXPyMDOzljl5mFl38MKKpeKhumZWfl5YsXR85WFm5eeFFUvHycPMys8LK5aOk4eZld/y5a1tt8I5eZhZ+XlhxdJx8jCz8vPCiqXj0VZm1h28sGKp+MrDzMxaljl5SJoj6U5JY5KOSHpc0kfr6lwsaY+kY5IelrSi7v1bJb0k6XlJ12aNyczMipXHlcds4BfAB4HfAjYB35W0EkDS6cCOdPtpQAX4Ts37bwLOAlYAHwI+L+kjOcRlZmYFyZw8IuJoRNwUEfsj4mRE/C3wNHB+WuVyYDQi7o2Il0mSxRpJq9Pyq4CbI+JQRDwJ3AGsyxqXmZkVJ/c+D0lLgLOB0XTTELC7Wh4RR4F9wJCkxcDS2vL0+dAk+18vqSKpMj4+nnf4ZmY2DbkmD0mnAtuBuyJiT7p5AXC4ruphYGFaRl15tayhiNgSEcMRMTw4OJhP4GbW27yoYu6mTB6SdkqKJo9Ha+rNAu4BXgE21OxiAhio2+0AcCQto668WmZmll11UcWxMYh4c1FFJ5BMpkweEXFRRKjJ40IASQLuBJYAayPi1ZpdjAJrqi8kzQfOJOkHOQQ8V1uePh/FzCwPXlSxEHk1W30D+F3g4xFxvK7sPuAcSWslzQW+BDxR06x1N7BR0uK0E/1qYFtOcZlZv/OiioXIY57HCuAzwLuB5yVNpI8RgIgYB9YCm4FDwAXAFTW7uJGkA30MeAS4PSIezBqXmRngRRULknl5kogYAzRFnX8CVjcpOwH8afowM8vX5s1vvZEUeFHFHHh5EjPrbV5UsRBeGNHMep8XVcydrzzMzKxlTh5mZtYyJw8zM2uZk4eZmbXMycPMzFrm5GFmZi1z8jAzm4pX5X0bz/MwM5tMdVXe6gz16qq80NdzR3zlYWY2Ga/K25CTh5nZZLwqb0NOHmZmk/GqvA05eZiZTWbz5mQV3lpeldfJw8xsUl6VtyGPtjIzm4pX5X0bX3mYmVnL8rgN7RxJd0oak3RE0uOSPlpTvlJS1NyedkLSprr3b5X0kqTnJV2bNSYzMytWHs1Ws4FfAB8EDgAfA74r6dyI2F9Tb1FEvNbg/TcBZwErgHcBD0v6me9jbmZWXpmvPCLiaETcFBH7I+JkRPwt8DRw/jR3cRVwc0QciogngTuAdVnjMjOz4uTe5yFpCXA2MFpXNCbpl5K+Len0tO5iYCmwu6bebmAo77jMzCw/uSYPSacC24G7ImJPuvkg8B6SZqnzgYVpHYAF6Z+Ha3ZzOK3T7DPWS6pIqoyPj+cZvpmZTdOUyUPSzrTDu9Hj0Zp6s4B7gFeADdXtETEREZWIeC0iXkjLPixpAJhIqw3UfOQAcKRZPBGxJSKGI2J4cHCwpb+smZnlY8rkEREXRYSaPC4EkCTgTmAJsDYiXp1sl+mfiohDwHPAmpryNby9ycvMrHv14JLueU0S/Abwu8AlEXG8tkDSBcCLwP8FFgNfA3ZGRLWp6m5go6QKSfK5GviTnOIyM+usHl3SPY95HiuAzwDvBp6vmctRPSpnAA+SNEX9FDgBXFmzixuBfcAY8Ahwu4fpmlnP6NEl3RURU9cqqeHh4ahUKp0Ow8ysuVmzoNH3rAQnT7Y/HkDSrogYzrIPL09iZlakHl3S3cnDzKxIPbqku5OHmVmRenRJdy/JbmZWtB5c0t1XHmZm1jInDzMza5mTh5mZtczJw8zMWubkYWZmLXPyMDOzljl5mJlZy5w8zMysZU4eZmZlVtJ7gXiGuZlZWZX4XiC+8jAzK6sS3wvEycPMrKwOHGhtexs5eZiZlVWJ7wXi5GFmVlYlvhdILslD0v+Q9JyklyQ9JenTdeUXS9oj6Zikh9P7nlfL5kjamr73eUnX5hGTmVnXK/G9QHK5h7mkIeD/RcQJSauBncAfRsQuSacD+4BPAw8ANwPvj4jfT997K3Ah8B+BdwEPA+si4sGpPtf3MDcza11p7mEeEaMRcaL6Mn2cmb6+HBiNiHsj4mXgJmBNmmQAr
gJujohDEfEkcAewLo+4zMysGLnN85D0dZIv/XnA48Dfp0VDwO5qvYg4KmkfMCTpBWBpbXn6/LJJPmc9kA50ZkLS3hzCPx04mMN+ilDm2KDc8Tm2mSlzbFDu+LolthWTVZyO3JJHRFwj6XPAvwMuAqpXIguA8brqh4GFaVn1dX1Zs8/ZAmzJIeQ3SKpkvYQrSpljg3LH59hmpsyxQbnj66fYpmy2krRTUjR5PFpbNyJej4hHgd8B/izdPAEM1O12ADiSllFXXi0zM7OSmjJ5RMRFEaEmjwubvG02b/Z5jAJrqgWS5qdloxFxCHiutjx9PjqTv4yZmbVH5g5zSb8t6QpJCySdIulS4Erg+2mV+4BzJK2VNBf4EvBEROxJy+8GNkpanHaiXw1syxpXi3JtBstZmWODcsfn2GamzLFBuePrm9gyD9WVNAh8j+SKYRYwBnwtIu6oqXMJ8N9JOml+TDIUd39aNgf4BvCfgOPAVyLiv2UKyszMCpXLPA8zM+svXp7EzMxa5uRhZmYt64vkIek0SfdJOippTNInmtSTpK9I+nX6+KokFRzbHEl3pnEdkfS4pI82qbtO0uuSJmoeFxUc305JL9d8XsNJme0+dnXHYCI9Ln/RpG7hx03SBkkVSSckbasra7q2W4P9rEzrHEvfc0lRsUn6fUn/KOk3ksYl3SvpX02yn2mdCznGtzKdElD777Zpkv2089iN1MV1LI31/Cb7yf3YTfXdUfR51xfJA/hL4BVgCTACfEPJelz11pPMbl8D/B7wH4DPFBzbbOAXwAeB3wI2Ad+VtLJJ/f8dEQtqHjsLjg9gQ83nrWpSp63HrvYYkPy7HgfuneQtRR+3Z4FbgK21G5Ws7baD5N/1NKACfGeS/fxPkhUa3gncAHxPyaCU3GMDFpOMwFlJMpjlCPDtKfY1nXMhr/iqFtV85s2T7Kdtxy4ittedg9cAPwf+ZZJ95X3smn53tOW8i4iefgDzSRLH2TXb7gFua1D3R8D6mtf/GfjnDsT8BLC2wfZ1wKNtjmUn8Olp1OvYsQP+mOQ/rpqUt+24kXzRbKt5vR74Uc3r+SSJbnWD955NsjLDwpptPwQ+W0RsDcr/LXAk67mQ47FbSbJO3uxpvLfTx+5h4MZOHbuaz3kCWNuO864frjzOBl6PiKdqtu0mWXOr3lvW4ZqkXmEkLSGJudlEyfMkHVSy9P0mSe24D/2t6Wc+NklzTyeP3R8Dd0d61jfRieMGDdZ2I1llutn59/OIqF1hoZ3H8QNMPUF3OudC3sYk/VLSt9Nf1I107NilzUEfIJmzNplCj13dd0fh510/JI8FvHXtLGi+flZ93cPAgiLb7mtJOhXYDtwVb06irPUD4Bzgt0l+XVwJXFdwWF8AzgCWkTRxPCDpzAb1OnLsJC0nuWy/a5JqnThuVVnOv8nq5krS75FM4J3suEz3XMjLQeA9JE1q55Mch+1N6nbs2JGsDP7DiHh6kjqFHrsG3x2Fn3f9kDwmW1trqroDwMQUv2hzIWkWSXPaK8CGRnUi4ucR8XREnIyInwBfJplcWZiI+HFEHImIExFxF/AY8LEGVTt17K4iaZJq+h+3E8etRpbzb7K6uZH0b4B/AP5LRPywWb0WzoVcRMRERFQi4rWIeIHk/8WHJdUfI+jQsUtdxeQ/Xgo9dk2+Owo/7/oheTwFzJZ0Vs22ZutnvWUdrknq5Sr9dX4nScfv2oh4dZpvDaAtV0XT+MyOHDum8R+3gXYet6ZruzWpe4ak2l98hR7HtMnln0juqXNPi29v9/lX/SHS7Pxr67EDkPQ+kttKfK/Ft+Zy7Cb57ij+vCu6A6cMD+CvSUYTzAfeR3JJNtSg3meBJ0kuLZemBy+XDrcp4vsm8M/AginqfRRYkj5fDfyUSTrpcohrEXApMJdkZMcIcBRYVYZjB7w3jWfhFPUKP27p8ZkL3EryK7B6zAbT821tuu0rTDKQID0P/jyt+0fAi8BgQbEtI2kHvy7Pc
yHH+C4AVpH8yH0nyWihh8tw7GrKt5D0t3Xq2DX87mjHeZfbf54yP0iGqt2f/oMdAD6Rbn8/SdNKtZ6ArwK/SR9fpckInhxjW0HyK+RlksvH6mMEWJ4+X57W/XPghfTv8XOS5pdTC4xtEPg/JJevL6Yn2L8v0bH7FnBPg+1tP24kd8iMusdNadklwB6S0S47gZU17/sm8M2a1yvTOseBvcAlRcUG3Jg+rz3vav9N/yvwD1OdCwXGdyXwdPrv9hxJh/S7ynDs0rK56bG4uMH7Cj92TPLd0Y7zzmtbmZlZy/qhz8PMzHLm5GFmZi1z8jAzs5Y5eZiZWcucPMzMrGVOHmZm1jInDzMza5mTh5mZtez/A8bIaIXveo0oAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "show_preds(preds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This doesn't look very close—our random parameters suggest that the roller coaster will end up going backwards, since we have negative speeds!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Step 3: Calculate the loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We calculate the loss as follows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(25823.8086, grad_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss = mse(preds, speed)\n", "loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our goal is now to improve this. To do that, we'll need to know the gradients." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Step 4: Calculate the gradients" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next step is to calculate the gradients. In other words, calculate an approximation of how the parameters need to change:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-53195.8594, -3419.7146, -253.8908])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss.backward()\n", "params.grad" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-0.5320, -0.0342, -0.0025])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params.grad * 1e-5" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use these gradients to improve our parameters. 
We'll need to pick a learning rate (we'll discuss how to do that in practice in the next chapter; for now we'll just use 1e-5, or 0.00001):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-0.7658, -0.7506, 1.3525], requires_grad=True)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Step 5: Step the weights" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we need to update the parameters based on the gradients we just calculated:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr = 1e-5\n", "params.data -= lr * params.grad.data\n", "params.grad = None" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> a: Understanding this bit depends on remembering recent history. To calculate the gradients we call `backward` on the `loss`. But this `loss` was itself calculated by `mse`, which in turn took `preds` as an input, which was calculated using `f` taking as an input `params`, which was the object on which we originally called `requires_grad_`. That original call is what now allows us to call `backward` on `loss`. This chain of function calls represents the mathematical composition of functions, which enables PyTorch to use calculus's chain rule under the hood to calculate these gradients." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see if the loss has improved:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(5435.5366, grad_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds = f(time,params)\n", "mse(preds, speed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And take a look at the plot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY8AAAEACAYAAABLfPrqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAcYElEQVR4nO3df4wc5Z3n8ffHGGFje9YmzJKzVx4vHNgrwzqcJ+IO8sMJuZCwx4nF9wdkEta3G5wscnQSJxJy2AEFkElY3R/Zu/wwgpgfXjYhMmhJduGyG0yCcxvdEGSSCTaSA0P4mXFijMc25tf3/qjq0HS6Z6amqrpruj8vqTTd9TxV/Z1navrbz/NUVSsiMDMzy2JWpwMwM7OZx8nDzMwyc/IwM7PMnDzMzCwzJw8zM8vMycPMzDJz8jAzs8wKTR6SNkgalnRU0taGsnMl7ZZ0WNKDkgbqyo6TdKuklyW9IOmKIuMyM7NiFd3zeA64Hri1fqWkE4HtwCbgBGAY+FZdlWuBU4EB4APAZyV9pODYzMysICrjCnNJ1wN/FBHr0ufrgXURcXb6fB6wDzgzInZLehb4rxHxf9Ly64BTI+LiwoMzM7PcZrfpdVYCu2pPIuKQpL3ASkkvAovry9PHFzbbUZqI1gPMmzdv9YoVK0oL2sysGz3yyCP7IqI/zz7alTzmA2MN6w4AC9Ky2vPGst8TEVuALQCDg4MxPDxcbKRmZl1O0mjefbTrbKtxoK9hXR9wMC2jobxWZmZmFdSu5DECrKo9Sec8TgFGImI/8Hx9efp4pE2xmZlZRkWfqjtb0hzgGOAYSXMkzQbuAU6XtDYt/wLwWETsTje9HdgoaZGkFcBlwNYiYzMzs+IU3fPYCBwBrgI+nj7eGBFjwFrgBmA/cBZQfybVNcBeYBR4CLgpIu4vODYzMytIKafqtosnzM3MspP0SEQM5tmHb09iZmaZOXmYmVlmTh5mZpaZk4eZmWXWrivMK+neR5/lpgf28NxLR1i8cC5XnrecC89c0umwzMwqr2eTx72PPsvnt/+MI6+9AcCzLx3h89t/BuAEYmY2iZ4dtrrpgT2/Sxw1R157g5se2NOhiMzMZo6eTR7PvXQk03ozM3tLzw5bLV44l2ebJIrFC+d2IBozs+aqOjfbsz2PK89bztxjj3nburnHHsOV5y3vUERmZm9Xm5t99qUjBG/Nzd776LOdDq13k8eFZy5h80VnsGThXAQsWTiXzRedUYmMbmYG1Z6b7dlhK0gSiJOFmVVVledmezp55FXVsUgz6w5Vnpvt2WGrvKo8Fmlm1XLvo89yzo0/4I+v+h7n3PiDKb9PVHlu1sljmqo8Fmlm1ZHng2aV52Y9bDVNVR6LNLPqmOiD5lSSQFXnZt3zmKZWY45VGIs0
s+ro1g+abU0eknZIekXSeLrsqSv7mKRRSYck3SvphHbGllWVxyLNrDq69YNmJ3oeGyJifrosB5C0EvgG8AngJOAw8NUOxDZlVR6LNLPq6NYPmlWZ8xgC7ouIHwJI2gQ8LmlBRBzsbGitVXUs0syqo/Ye0W2n9XcieWyWdCOwB7g6InYAK4Ef1ypExF5JrwKnAY90IMbS+RoRs97RjR802508Pgf8AngVuBi4T9K7gPnAgYa6B4AFjTuQtB5YD7B06dJSgy2Lv0vEzGa6ts55RMRPIuJgRByNiNuAncD5wDjQ11C9D/i9IauI2BIRgxEx2N/fX37QJfA1ImYzy3Qv8utmnZ7zCEDACLCqtlLSycBxwBMdiqtU3Xrqnlk38khBc23reUhaKOk8SXMkzZY0BLwPeADYBlwg6b2S5gFfBLZXebI8j249dc+sG3mkoLl2DlsdC1wPjAH7gM8AF0bEnogYAT5NkkR+TTLXcXkbY2urbj11z6wbeaSgubYNW0XEGPDuCcr/Dvi7dsXTSd166p5ZN6rynW07qdNzHj2rG0/dM6uqPKfGX3ne8rfNeYBHCsDJw8y6XN4Jb48UNOfkYWZdLe9dbcEjBc04ecxQvkLdes10j3lPeJfDyWMG8nnn1mvyHPOe8C6Hv89jBvJ559Zr8hzzPjW+HO55zEBFdMM97GUzSZ5j3hPe5XDymIHydsOLGPZy8rF2ynvMe8K7eB62moHydsPzDnvVks+zLx0heCv5+GZxVhYPPVWPex4zUN5ueN5hryJOfbTek6e36qGn6nHymKHydMPzDgH41EfLqoihUg89VYuHrXpQ3iGAvHcF9ncj9B6fIdh93PPoQXmHAPLc68eT9TNXnnZ3b7X7OHn0qDxDAHmST975El8g2Rl5290X6nUfJw+blukmH0/Wz0x52913pu0+Th7WVp6s75xODjv5bKnu4+RhbZX3E2gRwx8zec5kurFXYdjJZ0t1l94+22rbNli2DGbNSn5u29be7XvQhWcuYfNFZ7Bk4VwELFk4l80XnZFpsj7PmWIz+QLHPLHnPdvJF+lZo8r0PCSdANwCfJjkO84/n341bTm2bYP16+Hw4eT56GjyHGBoqD3bX301PP00LF0KN9wwte26QKcm62vb5Z2wz9NrybN9ntg97GRFU0R0OgYAJN1F0hP6K+BdwPeAsyNipNU2g4ODMTw8PL0XXLYsecNvNDAATz1V7vaNiQfg+ONhy5apJ5AeTj55/PFV36PZES/gyRv/bMJtG4d+IPn0PdWeU97t88R+zo0/aDrstGThXHZe9cFJX9u6i6RHImIwzz4qMWwlaR6wFtgUEeMR8TDwD8AnSnvRp5/Otr7I7a+++u2JA5LnV189tdeuJZ/RUYh4q9fjYbNJ5bnAMe/QT97t88TuYScrWiWSB3Aa8EZEPFG3bhewsrGipPWShiUNj42NTf8Vly7Ntr7I7fMmrrzJB3p2vibPm2jeoZ+82+eJPe9ck1mjqiSP+cCBhnUHgAWNFSNiS0QMRsRgf3//9F/xhhuSoaJ6xx+frC97+7yJK2/yKaLnMkOTT5430by3Zcm7fd4EcOGZS9h51Qd58sY/Y+dVH3TisFwqMech6UxgZ0QcX7fuvwNrIuKCVtvlmvOA/PMG090+75xHJ+droJg5mxmo03MeZkUpYs6DiOj4AswDXgVOrVt3O3DjRNutXr06Zqw774wYGIiQkp933plt2+OPj0j6Dcly/PFT34f09m1rizS17QcGmm8/MDD1+Kf7u3fYPT99Js7e/C+x7HPfjbM3/0vc89Nn2rq9WRGA4cj5vl2JngeApL8HAvgkydlW/0iZZ1vNdHl6TXl7HrNmJemikQRvvjnxtj3aazGrkq452yp1OTAX+DVwF/DXEyWOnjc0lLzRv/lm8jPLG2/e+Z48czae7DfrCpVJHhHx24i4MCLmRcTSKPMCwV43NJR80h8YSHoLAwPZPvnnST5VmOw3s9wqkzyszfL0
XPIkn7xnmrnnYlYJTh42PdNNPnmHzNxzMasEJw9rr7xDZlXouZiZk4d1QCcn+4vouXjIy8zJw2aYTvZcPORl9juVuc5jOnr6Og+bnjzXmeS9PsasIrrtOg+z8uXpueQd8gIPe1nXqMyXQZm1zdDQ9K5mX7q0ec9jqpP1eb9AzKxC3PMwm6q8k/U+08u6iJOH2VTlnaz3sJd1EQ9bmWUx3SEv8LCXdRX3PMzaxcNe1kWcPMzapdPDXh7ysgJ52MqsnTo17OUhLyuYex5mM0WeYS8PeVnBnDzMZopOX+BoVsfJw2wmme5NJfPejRg8Z2Jv05bkIWmHpFckjafLnobyj0kalXRI0r2STmhHXGY9I++ZXr4ppDVoZ89jQ0TMT5fltZWSVgLfAD4BnAQcBr7axrjMul/eM708Z2INqjBsNQTcFxE/jIhxYBNwkaQFHY7LrLvk+R4VXx1vDdqZPDZL2idpp6Q1detXArtqTyJiL/AqcFqznUhaL2lY0vDY2FipAZtZKu+ciYe9uk67ksfngJOBJcAW4D5Jp6Rl84EDDfUPAE17HhGxJSIGI2Kwv7+/rHjNrJ6vjrcGuZNHOhkeLZaHASLiJxFxMCKORsRtwE7g/HQX40Bfw277gIN5YzOzgnT66nirnNzJIyLWRIRaLO9ptRmg9PEIsKpWIOlk4DjgibyxmVmB8syZ+FThrlP6sJWkhZLOkzRH0mxJQ8D7gAfSKtuACyS9V9I84IvA9ohwz8OsW/hU4a7TjjmPY4HrgTFgH/AZ4MKI2AMQESPAp0mSyK9J5joub0NcZtYuPlW46ygiOh3DtA0ODsbw8HCnwzCzss2alfQ4GknJMJplIumRiBjMs48qXOdhZjaxIuZMrFBOHmZWfUXMmXiyvVBOHmZWfXnmTDzZXgrPeZhZd1u2rPmXaA0MJKcc9yDPeZiZTcYXKJbCycPMupsvUCyFk4eZdTdfoFgKJw8z626+QLEUnjA3M5tIF16g6AlzM7Oy+QLFppw8zMwmknfOBLpywt3Jw8xsInnnTLp0wt1zHmZmZargRYqe8zAzq7ouvUjRycPMrExdOuHu5GFmVqYuvSOwk4eZWZm69I7AhSQPSRskDUs6Kmlrk/JzJe2WdFjSg5IG6sqOk3SrpJclvSDpiiJiMjOrjKGhZHL8zTeTn11wdXtRPY/nSL6n/NbGAkknAtuBTcAJwDDwrboq1wKnAgPAB4DPSvpIQXGZmc1cFZ5sLyR5RMT2iLgX+E2T4ouAkYi4OyJeIUkWqyStSMsvBa6LiP0R8ThwM7CuiLjMzGa0Ck+2t2POYyWwq/YkIg4Be4GVkhYBi+vL08crW+1M0vp0iGx4bGyspJDNzCqgiKvbS9KO5DEfONCw7gCwIC2jobxW1lREbImIwYgY7O/vLzRQM7NKyXt1e4lmT1ZB0g7g/S2Kd0bEeybZxTjQ17CuDziYltWev9JQZmZmQ0OVSBaNJu15RMSaiFCLZbLEATACrKo9kTQPOIVkHmQ/8Hx9efp4JNuvYWZm7VTUqbqzJc0BjgGOkTRHUq1Xcw9wuqS1aZ0vAI9FxO60/HZgo6RF6ST6ZcDWIuIyM7NyFDXnsRE4AlwFfDx9vBEgIsaAtcANwH7gLODium2vIZlAHwUeAm6KiPsLisvMzErgu+qamfUY31XXzMw6wsnDzMwyc/IwM7PMnDzMzCwzJw8zM8vMycPMzDJz8jAzs8ycPMzMLDMnDzMzy8zJw8zMMnPyMDOzzJw8zMwsMycPMzPLzMnDzMwyc/IwM7PMnDzMzCwzJw8zM8usqO8w3yBpWNJRSVsbypZJCknjdcumuvLjJN0q6WVJL0i6ooiYzMysPLML2s9zwPXAecDcFnUWRsTrTdZfC5wKDADvBB6U9At/j7mZWXUV0vOIiO0RcS/wm2lsfilwXUTsj4jHgZuBdUXEZWZm5WjnnMeopGckfVPSiQCSFgGLgV119XYBK1vtRNL6dIhseGxsrNyIzcys
qXYkj33Au0mGpVYDC4Btadn89OeBuvoH0jpNRcSWiBiMiMH+/v4SwjUzs8lMmjwk7UgnvJstD0+2fUSMR8RwRLweES8CG4APS+oDxtNqfXWb9AEHp/PLmJlZe0w6YR4Rawp+zUh/KiL2S3oeWAV8P12/Chgp+DXNzKxARZ2qO1vSHOAY4BhJcyTNTsvOkrRc0ixJ7wC+AuyIiNpQ1e3ARkmLJK0ALgO2FhGXmZmVo6g5j43AEeAq4OPp441p2cnA/SRDUT8HjgKX1G17DbAXGAUeAm7yabpmZtWmiJi8VkUNDg7G8PBwp8MwM5tRJD0SEYN59uHbk5iZWWZOHmZmlpmTh5mZZebkYWZmmTl5mJlZZk4eZmaWmZOHmZll5uRhZmaZOXmYmVlmTh5mZpaZk4eZmWXm5GFmZpk5eZiZWWZOHmZmlpmTh5mZZebkYWZmmTl5mJlZZrmTh6TjJN0iaVTSQUmPSvpoQ51zJe2WdFjSg5IGGra/VdLLkl6QdEXemMzMrFxF9DxmA78C3g/8AbAJ+LakZQCSTgS2p+tPAIaBb9Vtfy1wKjAAfAD4rKSPFBCXmZmVJHfyiIhDEXFtRDwVEW9GxHeBJ4HVaZWLgJGIuDsiXiFJFqskrUjLLwWui4j9EfE4cDOwLm9cZmZWnsLnPCSdBJwGjKSrVgK7auURcQjYC6yUtAhYXF+ePl45wf7XSxqWNDw2NlZ0+GZmNgWFJg9JxwLbgNsiYne6ej5woKHqAWBBWkZDea2sqYjYEhGDETHY399fTOBmZpbJpMlD0g5J0WJ5uK7eLOAO4FVgQ90uxoG+ht32AQfTMhrKa2VmZlZRkyaPiFgTEWqxvAdAkoBbgJOAtRHxWt0uRoBVtSeS5gGnkMyD7Aeery9PH49gZmaVVdSw1deAPwEuiIgjDWX3AKdLWitpDvAF4LG6Ya3bgY2SFqWT6JcBWwuKy8zMSlDEdR4DwKeAdwEvSBpPlyGAiBgD1gI3APuBs4CL63ZxDckE+ijwEHBTRNyfNy4zMyvP7Lw7iIhRQJPU+WdgRYuyo8BfpouZmc0Avj2JmZll5uRhZmaZOXmYmVlmTh5mZpaZk4eZmWXm5GFmZpk5eZiZWWZOHmZmlpmTh5mZZebkYWZmmTl5mJlZZk4eZmaWmZOHmZll5uRhZmaZOXmYmVlmTh5mZpaZk4eZmWVWxNfQHifpFkmjkg5KelTSR+vKl0mKuq+nHZe0qWH7WyW9LOkFSVfkjcnMzMqV+2to0338Cng/8DRwPvBtSWdExFN19RZGxOtNtr8WOBUYAN4JPCjpF/4eczOz6srd84iIQxFxbUQ8FRFvRsR3gSeB1VPcxaXAdRGxPyIeB24G1uWNy8zMylP4nIekk4DTgJGGolFJz0j6pqQT07qLgMXArrp6u4CVRcdlZmbFKTR5SDoW2AbcFhG709X7gHeTDEutBhakdQDmpz8P1O3mQFqn1WuslzQsaXhsbKzI8M3MbIomTR6SdqQT3s2Wh+vqzQLuAF4FNtTWR8R4RAxHxOsR8WJa9mFJfcB4Wq2v7iX7gIOt4omILRExGBGD/f39mX5ZMzMrxqQT5hGxZrI6kgTcApwEnB8Rr020y9pmEbFf0vPAKuD76fpV/P6Ql5mZVUhRw1ZfA/4EuCAijtQXSDpL0nJJsyS9A/gKsCMiakNVtwMbJS2StAK4DNhaUFxmZlaCIq7zGAA+BbwLeKHuWo6htMrJwP0kQ1E/B44Cl9Tt4hpgLzAKPATc5NN0zcyqLfd1HhExCmiC8ruAuyYoPwr8ZbqYmdkM4NuTmJlZZk4eZmaWmZOHmZll5uRhZmaZOXmYmVlmTh5mZpaZk4eZmWXm5GFmZpk5eZiZWWZOHmZmlpmTh5mZZebkYWZmmTl5mJlZZk4eZmaWmZOHmZll5uRhZmaZOXmYmVlmTh5mZpZZIclD0p2Snpf0sqQnJH2yofxcSbslHZb0YPq957Wy4yTdmm77gqQriojJzMzK
U1TPYzOwLCL6gP8MXC9pNYCkE4HtwCbgBGAY+FbdttcCpwIDwAeAz0r6SEFxmZlZCQpJHhExEhFHa0/T5ZT0+UXASETcHRGvkCSLVZJWpOWXAtdFxP6IeBy4GVhXRFxmZlaO2UXtSNJXSd705wKPAv+YFq0EdtXqRcQhSXuBlZJeBBbXl6ePL5zgddYD69On45L2FBD+icC+AvZThirHBtWOz7FNT5Vjg2rHN1NiG5io4lQUljwi4nJJnwH+A7AGqPVE5gNjDdUPAAvSstrzxrJWr7MF2FJAyL8jaTgiBovcZ1GqHBtUOz7HNj1Vjg2qHV8vxTbpsJWkHZKixfJwfd2IeCMiHgb+CPjrdPU40New2z7gYFpGQ3mtzMzMKmrS5BERayJCLZb3tNhsNm/NeYwAq2oFkualZSMRsR94vr48fTwynV/GzMzaI/eEuaQ/lHSxpPmSjpF0HnAJ8IO0yj3A6ZLWSpoDfAF4LCJ2p+W3AxslLUon0S8DtuaNK6NCh8EKVuXYoNrxObbpqXJsUO34eiY2RUS+HUj9wHdIegyzgFHgKxFxc12dDwH/i2SS5ifAuoh4Ki07Dvga8F+AI8CXIuJ/5grKzMxKlTt5mJlZ7/HtSczMLDMnDzMzy6wnkoekEyTdI+mQpFFJH2tRT5K+JOk36fJlSSo5tuMk3ZLGdVDSo5I+2qLuOklvSBqvW9aUHN8OSa/UvV7TizLb3XYNbTCetsvftqhbertJ2iBpWNJRSVsbylre263JfpaldQ6n23yorNgk/XtJ35f0W0ljku6W9G8m2M+UjoUC41uWXhJQ/3fbNMF+2tl2Qw1xHU5jXd1iP4W33WTvHWUfdz2RPID/DbwKnAQMAV+TtLJJvfUkV7evAv4U+E/Ap0qObTbwK+D9wB+Q3APs25KWtaj/fyNift2yo+T4ADbUvd7yFnXa2nb1bUDydz0C3D3BJmW323PA9cCt9Ss1+b3dGt1FcoeGdwBXA99RclJK4bEBi0jOwFlGcjLLQeCbk+xrKsdCUfHVLKx7zesm2E/b2i4itjUcg5cDvwR+OsG+im67lu8dbTnuIqKrF2AeSeI4rW7dHcCNTer+GFhf9/yvgH/tQMyPAWubrF8HPNzmWHYAn5xCvY61HfAXJP+4alHetnYjeaPZWvd8PfDjuufzSBLdiibbnkZyZ4YFdet+BHy6jNialP874GDeY6HAtltGcp+82VPYttNt9yBwTafaru51HgPWtuO464Wex2nAGxHxRN26XST33Gr0tvtwTVCvNJJOIom51YWSZ0rap+TW95skFXaLmQlsTl9z5wTDPZ1su78Abo/0qG+hE+0GTe7tBuyl9fH3y4iov8NCO9vxfUx+ge5UjoWijUp6RtI300/UzXSs7dLhoPeRXLM2kVLbruG9o/TjrheSx3zefu8saH3/rMa6B4D5ZY7d15N0LLANuC3euoiy3g+B04E/JPl0cQlwZclhfQ44GVhCMsRxn6RTmtTrSNtJWkrSbb9tgmqdaLeaPMffRHULJelPSS7gnahdpnosFGUf8G6SIbXVJO2wrUXdjrUdyZ3BfxQRT05Qp9S2a/LeUfpx1wvJY6J7a01Wtw8Yn+QTbSEkzSIZTnsV2NCsTkT8MiKejIg3I+JnwBdJLq4sTUT8JCIORsTRiLgN2Amc36Rqp9ruUpIhqZb/uJ1otzp5jr+J6hZG0r8F/gn4bxHxo1b1MhwLhYiI8YgYjojXI+JFkv+LD0tqbCPoUNulLmXiDy+ltl2L947Sj7teSB5PALMlnVq3rtX9s952H64J6hUq/XR+C8nE79qIeG2KmwbQll7RFF6zI23HFP5xm2hnu7W8t1uLuidLqv/EV2o7pkMu/0zynTp3ZNy83cdf7YNIq+OvrW0HIOkckq+V+E7GTQtpuwneO8o/7sqewKnCAvw9ydkE84BzSLpkK5vU+zTwOEnXcnHaeIVMuE0S39eBfwXmT1Lvo8BJ6eMVwM+ZYJKugLgW
AucBc0jO7BgCDgHLq9B2wNlpPAsmqVd6u6XtM4fkWzXvqGuz/vR4W5uu+xITnEiQHgd/k9b9c+AloL+k2JaQjINfWeSxUGB8ZwHLST7kvoPkbKEHq9B2deVbSObbOtV2Td872nHcFfbPU+WF5FS1e9M/2NPAx9L17yUZWqnVE/Bl4Lfp8mVanMFTYGwDJJ9CXiHpPtaWIWBp+nhpWvdvgBfT3+OXJMMvx5YYWz/w/0i6ry+lB9h/rFDbfQO4o8n6trcbyTdkRsNybVr2IWA3ydkuO0i+srm23deBr9c9X5bWOQLsAT5UVmzANenj+uOu/m/6P4B/muxYKDG+S4An07/b8yQT0u+sQtulZXPStji3yXaltx0TvHe047jzva3MzCyzXpjzMDOzgjl5mJlZZk4eZmaWmZOHmZll5uRhZmaZOXmYmVlmTh5mZpaZk4eZmWX2/wExahUKOb/yYgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "show_preds(preds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to repeat this a few times, so we'll create a function to apply one step:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def apply_step(params, prn=True):\n", " preds = f(time, params)\n", " loss = mse(preds, speed)\n", " loss.backward()\n", " params.data -= lr * params.grad.data\n", " params.grad = None\n", " if prn: print(loss.item())\n", " return preds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Step 6: Repeat the process " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we iterate. By looping and performing many improvements, we hope to reach a good result:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "5435.53662109375\n", "1577.4495849609375\n", "847.3780517578125\n", "709.22265625\n", "683.0757446289062\n", "678.12451171875\n", "677.1839599609375\n", "677.0025024414062\n", "676.96435546875\n", "676.9537353515625\n" ] } ], "source": [ "for i in range(10): apply_step(params)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "params = orig_params.detach().requires_grad_()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The loss is going down, just as we hoped! But looking only at these loss numbers disguises the fact that each iteration represents an entirely different quadratic function being tried, on the way to finding the best possible quadratic function. We can see this process visually if, instead of printing out the loss function, we plot the function at every step. 
Then we can see how the shape is approaching the best possible quadratic function for our data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA1QAAADMCAYAAAB0vOLuAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3df4zc9Z3f8dfbLIe9BsemGFJMdn1CHD4Z5CA2ihroBUJVuOS4InxSaCbiSARbtUcvVw6DqY0hgIuJo2suIuW0BMQB2wuXE6Y6LgdVjx8VrlplCYHUxERK8RpsI0xrG/Can/vuH98ZdnY83+/MfOczM98fz4c0Ws/3M7P+7tf78nfe38/n+/mYuwsAAAAA0LkFg94BAAAAAMgrCioAAAAASImCCgAAAABSoqACAAAAgJQoqAAAAAAgJQoqAAAAAEiJggoAAAAAUgpaUJnZtWY2ZWbvm9kDDW0XmdlOM5sxs6fNbLSu7Tgzu9/M3jazN8zsupD7BeQReQLCIlNAWGQKiITuodor6Q5J99dvNLOTJD0q6WZJJ0qakvRI3UtulXSGpFFJF0q6wcwuCbxvQN6QJyAsMgWERaYASebu4b+p2R2STnP3q6rPxyVd5e5fqD5fLOktSee4+04z2yPpG+7+X6vtt0s6w92vCL5zQM6QJyAsMgWERaZQdv26h2q1pBdrT9z9sKRfS1ptZssknVrfXv3z6j7tG5A35AkIi0wBYZEplMpQn/6e4yXtb9h2SNIJ1bba88a2o1SveoxL0uLFi89dtWpV2D0Fmnj++effcvflg96PqmB5ksgUBqOomSJPGBQyBYTVSab6VVC9K2lJw7Ylkt6pttWev9fQdhR3n5A0IUljY2M+NTUVfGeBRmY2Peh9qBMsTxKZwmAUNVPkCYNCpoCwOslUv4b87ZC0pvakOpb2dEk73P2ApH317dU/7+jTvgF5Q56AsMgUEBaZQqmEnjZ9yMwWSjpG0jFmttDMhiRtk3SWma2ttm+S9JK776y+9UFJG81smZmtknSNpAdC7huQN+QJCItMAWGRKSASuodqo6QjktZL+nr1zxvdfb+ktZI2Szog6fOS6mdyuUXRzYrTkp6VtNXdnwi8b0DekCcgLDIFhEWmAPVo2vR+YSwt+sXMnnf3sUHvR6+RKfRLGTJFntBPZAoIq5NM9eseKgAAAAAonH7N8jcQj72wR1uffEV7Dx7RqUsXad3FZ+qyc1YMereA3CJTQDjkCQiLTGFQCltQPfbCHt306C905MOPJUl7Dh7RTY/+QpIIF5ACmQLCIU9AWGQKg1TYIX9bn3zlk1DVHPnwY2198pUB7RGQb2QKCIc8AWGRKQxSYXuo9h480tF2lAvDAjpHppCETHWGPKEVMtUZMoUkvc5TYXuoTl26qKPtKI/asIA9B4/INTcs4LEX9gx61zKNTCEOmeoceUISMtU5MoU4/chTYQuqdRefqUXHHjNv26Jjj9G6i88c0B4hKxgWkA6ZQhwy1TnyhCRkqnNkCnH6kafCDvmrdePFde/RlV5eDAtIh0whDpnqXKs8SWSqzMhU5zhHIU4/8lTYgkqKwtUsLMwEUw5x/3meunSR9jQJEcMCWiNT5UamworLk0SmyoJMhcU5qtwGmafCDvlLQld68SWNl2VYQHhkqvjIVH+RqeIjU/1Dnopv0HkqZUFFV3rxJf3nedk5K3Tn5WdrxdJFMkkrli7SnZefzVWqLpCp4iNT/UWmio9M9Q95Kr5B56nQQ/7i0JVefK3+80waaoP
OkaniI1P9RaaKj0z1D3kqvkHnqZQ9VHSlFx/Tp/YXmSo+MtVfZKr4yFT/kKfiG3SeSllQter6e+yFPTpvy1P6zfV/p/O2PMW6DznEf579RaaKj0z1F5kqPjLVP+Sp+Aadp1IO+ZOYCabo2pmSGGGRqWIjU/1HpoqNTPUXeSq2QeeptAVVnFY3tSFbktaVYPx5NpCpfCFT2Uem8qPV2kdkavDIU75k9RxFQdWAmWDyg6tK+UCm8oNM5QOZygfylA/kKT+ynKlS3kOVZNA3taF9rCuRD2QqP8hUPpCpfCBP+UCe8iPLmaKgajDom9rQPq4q5QOZyg8ylQ9kKh/IUz6Qp/zIcqYY8tdg0De14Whx42VZVyIfyFT2kKl8I1PZQp7yjTxlTx4zRUHVBDeJZkfSeNl1F585r03iqlJWkansIFPFQKaygTwVA3nKjrxmioKqQ61m7EE6ccc1abzs9vVfksRVpbwjU+ElHVMyVXxkKjzOUeVGpsIrWqYoqDqQ5dlF8izpuLYaL8tVpXwjU+G1OqZkqtjIVHico8qNTIU3kExNTkobNki7d0sjI9LmzVKlku4HaIJJKTrQanYRVtpOJ+m4MvtOsZGp8FodUzJVbGQqPM5R5UamwutZpiYnpZUrpQULoq+Tk3Pbx8el6WnJPfo6Pj7XHgAFVQeSquZatb3n4BG55qrt+mARuuaSjiuz7xRbN5kiT821urpHpoqNTIXHOarcyFR4PclUUtG0YYM0MzP/9TMz0fZAKKg6kFQ1t3MFo1XBVVZJx/Wyc1bozsvP1oqli2SSVixdpDsvP5tu9oJImynyFK/V1T0yVWxkKjzOUeVGpsLrKlNxvVBJRdPu3c13JG57ChRUHUiqmltdFc7yYmT9EnelptXViMvOWaHt67+kV7d8RdvXf4kTVYGkzRR5Sp8niUwVGZlKj3MUmiFT6aXO1MvPaPtffFOvfudSbf+Lb+qyl5+JXpTUC5VUNI2MNG+L254Ck1J0IGmtgq1PvpI4N36rgqvoM8i0c1NnkX9+NJc2U+SJPKE5MpUOmUIcMpVOq0yt+Mk2febu23Xywf16c+lyvXb9zfrcOZfMFU213qZa0SQl90KNjESvbVSbgKL+e0rS8HC0PRBz92DfrN/GxsZ8ampq0Lsh6ehfHCmqtmvdlOdteapp6FZUQ5T03ryELmk/k37+2jSYWWZmz7v72KD3o9fykqm4k1g7eap97zxnKu95ksqRqSzlSepdpvKSJ4lM5R2Zyp6fbr5bn/luQ2G04Vqdt+Upnbv9J7rhvz+oU99+S3uXnKTv/M6Vev68L2v7Z/Y1L3AmJqLiqFlhNDoa9TY1q1vMpIceiv+elUqqWf46yRQ9VIG0uoKVtBhZq27hPEzX2e1UzUCjVplKk6faiSrvmSJPSKMXmWp8X1bzJJEphFeaTMUUIz/dfLfO+vb1WvTh+5KkTx98U5/69vX6qaSx/7lLdz5xt4Y/itpOe3u/tjxxt26SpF/8Tbr7nZJ6oWrFUVzRVKkEnSa9ET1UfRR3xeE31/+dmv0rmKJu46SrZv2+ipH26l7er/6V4cqflK9MpcnTq1u+0vJ3MQ+ZkpTrPEnlyFSe8iQV4xyV9HOQqfwjU11kKqmHpkXbR1dfo6H35vbpo4WLNPTDe/XGtdfp0wffPOqvemPpyZIU2/bpQ/vje5riiqbR0fihe7VeqMBy2UNlZidKuk/SP5f0lqSb3P0/d/VNe7yIV6fiFiOLC0+rMbi9utIeF9Zuru4l9dChN4qeqTR5ktqbAjfrmfqPX/0seRqA4JnKUJ6k/JyjJDJVBEU/R0npM/X7O54+aqjc366+UI+9sEfP3fbneuSpBz5p+97/ukra9K3o74n7+ZPuS5LmF0zT09FzSapUNLPuRg2/N39fh947opl1N+rkg/ub/twnH9wvs+bH5JRD+9Pf79SqF2qAsjTL3w8kfSDpFEkVSfeY2erU363VIl5x0y4OQNJsJ91O1R63/kF
cW9I0n90sxMbUsgORnUz1MW+tZg8qQqbI08CEy1Q7C02mzVTgvPXqHCWRqZLLzzkqcBbXXXym/uCVZ/XcPd/Q/7nrUj13zzf0B688q3UXn6k/fHW7tjxxt057e78WyD8ZKveHr27Xz7f8QLc9/v15bbc9/n39fMsPPulJqv/5P7r6mpbrMM2su3Fe75M0VzBJ0sJ9e5se7oX79urNpcubtr25dLksZgY9qxVBw8PzG+qLpomJqEfKLPpa3wNVqUi7dkmzs9HXDBRTUkaG/JnZYkkHJJ3l7r+qbntI0h53Xx/3vsSu35Ur03cZDuAKR7tX26S5Gxf/3SM/j+0yjrvaduflZ0tqPq631U2Ue6snr07/viKckPI2lCJTmZL6nrekIRFkKhvKkKnUedq16+grylJ7mUpqa5W3hLa4G8/T5unVLV9JfK9EpjqVp0zl7hwVOotS7DC6mXU3anjf0WtVzfzjFfp/hz/QaW8f3Sv0+pLlOnHxb8S+b9Ebe2VNPu+7mdylBU1SMyvTAp/V6586Ofbv3HfDpnn3UEnSkWOP0/++5bv63Mplmfus3alOMpWVguocSf/D3RfVbbte0hfd/dK49yUGa8GCXIzPbEfo8eBJbUknoyyOle+XPJ2opIxlSsrcxQ0yNXhlyFTqPM3OJn84lMLnTUr9gTPVLF/Ve2vj2iXFtpGp5vKUqdyco3bt6k0Wk9oSZrKbTSh+pPi2N5cuj72f6aNZjy2YTjv0pm796r/XDY/+2SeTS0jSzNBx+s7l1+nWR/5DbP4l5aJoSpLHguqfSvqxu3+6bts1kirufkHDa8cljUvSyMjIudPNfiGl5AAkTbuYFLraVcOM/HKkvTIoKdXJqJ3pqIsqTycqKWOZknpzcaNHPVtkqj+Kmqkgedq1K/nDoRQ+b1LwIu2nuw7EX73ecK2+den182YBk6IPajddEn0Yi2ubOu/LscXWuovP1HO3/bn+pP4eky9dpfNb3WNSAHnKVG7OUbOzvcliUltCTmc++ChV79XW37kyNk/Lhn8jsWCq3bcVm6kC6yRTWbmH6l1JSxq2LZH0TuML3X3C3cfcfWz58uZjNyUlj89MWjE5abrGdsa891HSmO+kseJJbUlj5RljnivZyVTavCUt4NejLJIpJGgrU0HyJKXPVNq8pW1LyOnn7v3uvGJKkhZ9+L4+d+93JUk3PffQvA9xkjT80fu66bmHEtu+9/HLuuvJ+feY3PXk3frexy/rspefaXr/yWUvP9O7e2wydE92juTjHFX/tdP3pm1L+DmGt96ljxbOP998tHCRhrfepR9ecrVmho6b1zYzdJx+eMnVmjrvy1p/ybV6fclyzcr0+pLlWl+9OPHZ9X+kTb/3x/PaNv3eH+uz6/9IUnRePH/Tt/TVm36k02/8W331ph+VopjqmLsP/CFpsaIbE8+o2/agpC1J7zv33HM90cMPu4+OuptFXx9+eG778LB79F9q9Bgennt9/fbaY3Q0uS1jtv3sdV+18e999MbHP3ms2vj3vu1nrye21d77hTv/wVfe+Lh/4c5/+GR7mUma8gxkpd1HpjKVNm9mzdtqf3dSFuP2swtkKqwyZCp1nmptaTLVi/Nb2pwmtbn7bEz7rFliW09+jl4c71b/xu38DnTw/1ieMpWbc1Q37+3V701M27afve5/etk6f23Jcv9Y5q8tWe5/etk6zlFd6CRTAw/VJzsi/UjSX1VDdp6kQ5JWJ72nZbCSpAldi5NDLz7EdSMpIISnM3k6UdUemclUUlsviq1WJ6pW+5qATIVThkx1lSf39B+2Q3+o7NVFyH4XcWkv0vSiSOvm3yNG3jKVi3NUt+8NWDC3g3NUWHktqE6U9Jikw5J2S/paq/d0fbKKE/dL3s1/nMi1vJ2oPGuZStLvnmSymgllyNRA8tRKVoq0bt6bpZ62bnrSA4+IyVumcnOOQmnlsqBK8+h7sNL+B4/cy9uJKu0jcyerXvQkD2C4II5WhkxlLk+
90u8r/1nqaUtbpLl3NVyyGTIFhEVB1Utx/8G3858fH9RyqwwnKs/bySpNT7J798MFEUQZMpWrPOVNVnraurnQWvIeqjQPMoV+oqAaBIYZFVoZTlSetUyl1Spr3Q4X5KJIEGXIVCHyVCb9LNJatZfgHqo0DzKFfqKgGoRuPsQh88pwovKsZaobrT78pBkuyEWRoMqQqcLkCen1arKDJsgUEFYnmcrEwr5pJa6YPQhJCwYmLQw3O9vf/UTH8rRgYjcyl6leictq0sKQUm4W/c6DMmSqNHlCJpApIKw8LuxbDJVK9MFqdjb6Wv9hqtXCcSwMCPRPXFaTFobM0aLfAACgfyio+iXpgxofxoBsqFSkiYmo18ks+joxEW1PuiiyYYM0MzN/+8xMtF3iggkAAAVGQdUvSR/UWn0YA9A/9F4BAIAOUFD1U9wHtaQPYwCyoVe9VwAAINcoqLKA+6uAfAjdeyWRbwAAco6CKgu4vwrIt7S9V+QbAIDco6DKAu6vAvIvTe8V+QYAIPcoqLKC+6uAYkq6YMJwQAAAcm9o0DuAFkZGmi8mGjeMCED2VCrNF/lNyndtOGCtB6s2HLD2/QAAQCbQQ5V1ScOFJK5gA3nWzXBAsg8AQCZQUGVd0nAhbmgH8i3tcECyDwBAZlBQ5UHc/VXc0A7kX1y+WdsKAIBcoKDKMyasAIqrm7WtAABA31BQ5VmrBYEB5Ffata0k7q8CAKCPKKjyrNWEFQDyLc3aVtxfBQBAX1FQ5VnSFWyJq9RAUbEYOAAAmUFBlXdxV7C5Sg0UWzeLgXOxBQCAYCioioqr1EA5tXN/FRdbAAAIhoKqqJgFDCinVvdWcrEFAICgKKiKihkAgXJqdW9lq4stDAcEAKAjFFRFxQyAQHnF3V8lJV9sYTggAAAdo6AqKmYABNBM0sUWhgMCANAxCqoiYwZAAI2SLrZw7yUAAB2joCojrkID5RZ3saWdGQLp2QYAYB4KqjLiKjSAZpKGA9KzDQBAUxRUZcQMgACaSRoOSM82AABNUVCVETMAAogTNxyQnm0AAJqioCqjVjMAAkAj7q8CAKCpIAWVmV1rZlNm9r6ZPdCk/SIz22lmM2b2tJmN1rUdZ2b3m9nbZvaGmV0XYp/QQtI6NXwwGjgyhczJ+f1VZAoIi0wBc0L1UO2VdIek+xsbzOwkSY9KulnSiZKmJD1S95JbJZ0haVTShZJuMLNLAu0XOpWDD0YlQaaQLfm/v4pMAWGRKaAqSEHl7o+6+2OS/m+T5ssl7XD3H7v7e4pCtMbMVlXbr5R0u7sfcPdfSrpX0lUh9gsp5OODUeGRKWRSju+vIlNAWGQKmNOPe6hWS3qx9sTdD0v6taTVZrZM0qn17dU/r477ZmY2Xu1intq/f3+PdrnEcvDBCGQKGZP/mUODZYo8AZLIFEqmHwXV8ZIONWw7JOmEapsa2mttTbn7hLuPufvY8uXLg+4oVIQPRmVAppAt+Z85NFimyBMgiUyhZFoWVGb2jJl5zOO5Nv6OdyUtadi2RNI71TY1tNfaMAj5/2CUeWQKhTPgmUPJFBAWmQI607KgcvcL3N1iHue38XfskLSm9sTMFks6XdHY2gOS9tW3V/+8o7MfA8EwpXrPkSkUUtLMoT1GpoCwyBTQmVDTpg+Z2UJJx0g6xswWmtlQtXmbpLPMbG31NZskveTuO6vtD0raaGbLqjcrXiPpgRD7hZQG+MEIETIFhEWmgLDIFDAn1D1UGyUdkbRe0terf94oSe6+X9JaSZslHZD0eUlX1L33FkU3Kk5LelbSVnd/ItB+ITTWqOoXMgWERaaAsMgUUGXuPuh9SG1sbMynpqYGvRvlUVujqn5a9eHhUgwJNLPn3X1s0PvRa2QK/VKGTJEn9BOZAsLqJFP9mOUPRcEaVQAAAMA8FFRoH2tUAQAAAPNQUKF9rFEFAAAAzENBhfaxRhUAAAAwDwUV2scaVQAAAMA8Q61fAtSpVCigAAAAgCp6qBAOa1Q
BAACgZOihQhiNa1RNT0fPJXq0AAAAUFj0UCEM1qgCAABACVFQIQzWqAIAAEAJUVAhDNaoAgAAQAlRUCEM1qgCAABACVFQIQzWqAIAAEAJMcsfwmGNKgAAAJQMPVQAAAAAkBIFFfqHhX8BAABQMAz5Q3+w8C8AAAAKiB4q9AcL/wIAAKCAKKjQHyz8CwAAgAKioEJ/sPAvAAAACoiCCv3Bwr8AAAAoIAoq9AcL/wIAAKCAmOUP/cPCvwAAACgYeqgAAAAAICUKKmQDi/4CAAAghxjyh8Fj0V8AAADkFD1UGDwW/QUAAEBOUVBh8Fj0FwAAADlFQYXBY9FfAAAA5BQFFQaPRX8BAACQUxRUGDwW/QUAAEBOdV1QmdlxZnafmU2b2Ttm9oKZ/W7Day4ys51mNmNmT5vZaMP77zezt83sDTO7rtt9Qg5VKtKuXdLsbPS1xMUUmQLCIlNAWGQKmC9ED9WQpNckfVHSpyTdLOmvzWylJJnZSZIerW4/UdKUpEfq3n+rpDMkjUq6UNINZnZJgP0C8opMAWGRKSAsMgXU6bqgcvfD7n6ru+9y91l3f1zSq5LOrb7kckk73P3H7v6eohCtMbNV1fYrJd3u7gfc/ZeS7pV0Vbf7BeQVmQLCIlNAWGQKmC/4PVRmdoqk35K0o7pptaQXa+3ufljSryWtNrNlkk6tb6/+eXXo/UKOTU5KK1dKCxZEXycnB71HfUWmgLDIFBAWmULZBS2ozOxYSZOS/tLdd1Y3Hy/pUMNLD0k6odqmhvZaW9zfMW5mU2Y2tX///jA7juyanJTGx6Xpack9+jo+XpqiikwBYfU6U+QJZUOmgDYKKjN7xsw85vFc3esWSHpI0geSrq37Fu9KWtLwbZdIeqfapob2WltT7j7h7mPuPrZ8+fJWu4+827BBmpmZv21mJtqeU2QKCCtLmSJPKAIyBXSmZUHl7he4u8U8zpckMzNJ90k6RdJad/+w7lvskLSm9sTMFks6XdHY2gOS9tW3V/+8Q4Ak7d7d2fYcIFNAWGQKCItMAZ0JNeTvHkm/LelSdz/S0LZN0llmttbMFkraJOmlum7hByVtNLNl1ZsVr5H0QKD9Qt6NjHS2vTjIFBAWmQLCIlNAVYh1qEYl/StJn5X0hpm9W31UJMnd90taK2mzpAOSPi/pirpvcYuiGxWnJT0raau7P9HtfqEgNm+WhofnbxsejrYXFJkCwiJTQFhkCphvqNtv4O7TkqzFa/6bpFUxbe9L+mb1AcxXW+B3w4ZomN/ISFRMFXjhXzIFhEWmgLDIFDBf1wUV0HOVSqELKAAAAORX8HWoAAAAAKAsKKgAAAAAICUKKgAAAABIiYIK+TY5Ka1cKS1YEH2dnBz0HgEAAKBEmJQC+TU5KY2PSzMz0fPp6ei5xCQWAAAA6At6qJBfGzbMFVM1MzPRdgAAAKAPKKiQX7t3d7YdAAAACIyCCvk1MtLZdgAAACAwCirk1+bN0vDw/G3Dw9F2AAAAoA8oqJBflYo0MSGNjkpm0deJCSakAAAAQN8wyx/yrVKhgAIAAMDA0EMFAAAAAClRUAEAAABAShRUAAAAAJASBRWKa3JSWrlSWrAg+jo5Oeg9AgAAQMEwKQWKaXJSGh+XZmai59PT0XOJSSwAAAAQDD1UKKYNG+aKqZqZmWg7AAAAEAgFFYpp9+7OtgMAAAApUFChmEZGOtsOAAAApEBBhWLavFkaHp6/bXg42g4AAAAEQkGFYqpUpIkJaXRUMou+TkwwIQUAAACCYpY/FFelQgEFAACAnqKHCgAAAABSoqACAAAAgJQoqAAAAAAgJQoqAAAAAEiJggrlNDkprVwpLVgQfZ2cHPQeAQAAIIeY5Q/lMzkpjY9LMzPR8+np6LnErIAAAADoCD1UKJ8NG+aKqZqZmWg7AAAA0AEKKpTP7t2dbQcAAABiBCmozOxhM9tnZm+b2a/M7OqG9ovMbKeZzZjZ02Y2Wtd2nJndX33vG2Z2XYh9AmK
NjHS2fQDIFBAWmQLCIlPAnFA9VHdKWunuSyT9vqQ7zOxcSTKzkyQ9KulmSSdKmpL0SN17b5V0hqRRSRdKusHMLgm0X8DRNm+WhofnbxsejrZnB5kCwiJTQFhkCqgKUlC5+w53f7/2tPo4vfr8ckk73P3H7v6eohCtMbNV1fYrJd3u7gfc/ZeS7pV0VYj9ApqqVKSJCWl0VDKLvk5MZGpCCjIFhEWmgLDIFDAn2D1UZvafzGxG0k5J+yT9pNq0WtKLtde5+2FJv5a02syWSTq1vr3659Wh9gtoqlKRdu2SZmejrxkqpmrIFBAWmQLCIlNAJNi06e7+b8zs30r6J5IukFS7anG8pP0NLz8k6YRqW+15Y1tTZjYuqTrHtd41s1fa2L2TJL3VxuvKiGMTr/7YjCa9sBfIVG5xbOIVPlPkqSc4PvHIVHP8zsTj2CRLlamWBZWZPSPpizHN2939/NoTd/9Y0nNm9nVJ/1rS9yW9K2lJw/uWSHqn2lZ7/l5DW1PuPiFpotV+N/wMU+4+1sl7yoJjE69Xx4ZMFRvHJl4ZMkWewuP4xCNTsfvP70wMjk2ytMen5ZA/d7/A3S3mcX7M24Y0N452h6Q1dTu6uNq2w90PKOoiXlP33jXV9wCFRKaAsMgUEBaZAjrT9T1UZnaymV1hZseb2TFmdrGkfynpqepLtkk6y8zWmtlCSZskveTuO6vtD0raaGbLqjcrXiPpgW73C8grMgWERaaAsMgUMF+ISSlcURfv65IOSPqupD9x9/8iSe6+X9JaSZur7Z+XdEXd+29RdKPitKRnJW119ycC7Fe9jrqKS4ZjE29Qx4ZM5RvHJh6Zao7fmWQcn3hkqjl+Z+JxbJKlOj7m7qF3BAAAAABKIdi06QAAAABQNhRUAAAAAJBSoQsqMzvRzLaZ2WEzmzazrw16nwbFzK41sykze9/MHmhou8jMdprZjJk9bWZ9X8tikMzsODO7r/o78o6ZvWBmv1vXXurjU49MzSFTzZGn9pGnOeQpHplqH5maQ6bi9SJThS6oJP1A0geSTpFUkXSPmZV1Je69ku6QdH/9RjM7SdKjkm6WdKKkKUmP9H3vBmtI0muK1tz4lKJj8ddmtpLjcxQyNYdMNUee2kee5pCneGSqfWRqDpmKFzxThZ2UwqI1Dw5IOsvdf1Xd9pCkPe6+fqA7N0Bmdoek09z9qurzcUlXufsXqs8XK1oh+py66U1Lx8xekvRtSf9IHB9JZCoOmWqNPB2NPDVHntpDpo5GppojU+3pNlNF7qH6LUkf10JV9aKksl6piLNa0XGRJLn7YUVTmZb2OJnZKZhPqP0AAAH6SURBVIp+f3aI41OPTLWH35k65CkWeWoPvzMNyFQsMtUefmcahMhUkQuq4yUdath2SNIJA9iXLOM41TGzYyVNSvrL6pUIjs8cjkV7OE5V5CkRx6I9HKc6ZCoRx6I9HKc6oTJV5ILqXUlLGrYtkfTOAPYlyzhOVWa2QNJDisZfX1vdzPGZw7FoD8dJ5KkNHIv2cJyqyFRLHIv2cJyqQmaqyAXVryQNmdkZddvWKOrOw5wdio6LpE/Gip6ukh0nMzNJ9ym6kXWtu39YbeL4zCFT7Sn97wx5agt5ag+/MyJTbSJT7eF3RuEzVdiCqjrm8VFJt5nZYjM7T9K/UFSJlo6ZDZnZQknHSDrGzBaa2ZCkbZLOMrO11fZNkl4q4Y2J90j6bUmXuvuRuu0cnyoyNR+ZSkSeWiBP85GnlshUC2RqPjLVUthMuXthH4qmO3xM0mFJuyV9bdD7NMBjcaskb3jcWm37Z5J2Sjoi6RlJKwe9v30+NqPV4/Geoq7e2qPC8TnqWJGpuWNBppofF/LU/rEiT3PHgjzFHxsy1f6xIlNzx4JMxR+b4Jkq7LTpAAAAANBrhR3yBwAAAAC9RkEFAAAAAClRUAEAAABAShRUAAAAAJASBRUAAAAApERBBQAAAAApUVABAAAAQEoUVAAAAACQEgUVAAA
AAKT0/wFx0txpoz32QAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "_,axs = plt.subplots(1,4,figsize=(12,3))\n", "for ax in axs: show_preds(apply_step(params, False), ax)\n", "plt.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Step 7: stop" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We just decided to stop after 10 epochs arbitrarily. In practice, we would watch the training and validation losses and our metrics to decide when to stop, as we've discussed." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Summarizing Gradient Descent" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "G\n", "\n", "\n", "\n", "init\n", "\n", "init\n", "\n", "\n", "\n", "predict\n", "\n", "predict\n", "\n", "\n", "\n", "init->predict\n", "\n", "\n", "\n", "\n", "\n", "loss\n", "\n", "loss\n", "\n", "\n", "\n", "predict->loss\n", "\n", "\n", "\n", "\n", "\n", "gradient\n", "\n", "gradient\n", "\n", "\n", "\n", "loss->gradient\n", "\n", "\n", "\n", "\n", "\n", "step\n", "\n", "step\n", "\n", "\n", "\n", "gradient->step\n", "\n", "\n", "\n", "\n", "\n", "step->predict\n", "\n", "\n", "repeat\n", "\n", "\n", "\n", "stop\n", "\n", "stop\n", "\n", "\n", "\n", "step->stop\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#hide_input\n", "#id gradient_descent\n", "#caption The gradient descent process\n", "#alt Graph showing the steps for Gradient Descent\n", "gv('''\n", "init->predict->loss->gradient->step->stop\n", "step->predict[label=repeat]\n", "''')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To summarize, at the beginning, the weights of our model can be random (training *from scratch*) or come from a pretrained model (*transfer learning*). 
In the first case, the output we will get from our inputs won't have anything to do with what we want, and even in the second case, it's very likely the pretrained model won't be very good at the specific task we are targeting. So the model will need to *learn* better weights.\n", "\n", "We begin by comparing the outputs the model gives us with our targets (we have labeled data, so we know what result the model should give) using a *loss function*, which returns a number that we want to make as low as possible by improving our weights. To do this, we take a few data items (such as images) from the training set and feed them to our model. We compare the model's outputs with the corresponding targets using our loss function, and the score we get tells us how wrong our predictions were. We then change the weights a little bit to make the loss slightly better.\n", "\n", "To find how to change the weights to make the loss a bit better, we use calculus to calculate the *gradients*. (Actually, we let PyTorch do it for us!) Let's consider an analogy. Imagine you are lost in the mountains with your car parked at the lowest point. To find your way back to it, you might wander in a random direction, but that probably wouldn't help much. Since you know your vehicle is at the lowest point, you would be better off going downhill. By always taking a step in the direction of the steepest downward slope, you should eventually arrive at your destination. We use the magnitude of the gradient (i.e., the steepness of the slope) to tell us how big a step to take; specifically, we multiply the gradient by a number we choose called the *learning rate* to decide on the step size. We then *iterate* until we have reached the lowest point, which will be our parking lot, and then we can *stop*.\n", "\n", "Everything we just saw can be applied directly to the MNIST dataset, except for the loss function. Let's now see how we can define a good training objective. 
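
Before moving on, the loop summarized above can be sketched in a few lines of plain Python. This is only a minimal illustration under stated assumptions, not the code used elsewhere in this chapter: the one-weight loss `(w - 3)**2`, the helper names `loss_fn` and `grad_fn`, the starting weight, the learning rate, and the fixed 50 iterations are all invented for the example, and the derivative is written out by hand where PyTorch would normally compute it for us.

```python
# Hypothetical toy loss: a one-dimensional valley whose lowest point is at w = 3.
def loss_fn(w):
    return (w - 3.0) ** 2

def grad_fn(w):
    # Derivative of (w - 3)^2, worked out by hand for this sketch.
    return 2.0 * (w - 3.0)

w = -4.0     # initialize: an arbitrary starting weight
lr = 0.1     # learning rate, chosen by hand
losses = []  # loss recorded before each update

for _ in range(50):            # stop: here, simply after a fixed number of steps
    losses.append(loss_fn(w))  # predict and measure the loss
    w -= lr * grad_fn(w)       # gradient and step: move a little downhill
```

Each pass moves `w` a little further downhill, so the recorded losses shrink and `w` ends up close to 3, the bottom of this toy valley. A larger learning rate would take bigger steps and could overshoot; a smaller one would need more iterations.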
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The MNIST Loss Function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We already have our independent variables `x`—these are the images themselves. We'll concatenate them all into a single tensor, and also change them from a list of matrices (a rank-3 tensor) to a list of vectors (a rank-2 tensor). We can do this using `view`, which is a PyTorch method that changes the shape of a tensor without changing its contents. `-1` is a special parameter to `view` that means \"make this axis as big as necessary to fit all the data\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need a label for each image. We'll use `1` for 3s and `0` for 7s:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([12396, 784]), torch.Size([12396, 1]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)\n", "train_x.shape,train_y.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A `Dataset` in PyTorch is required to return a tuple of `(x,y)` when indexed. 
Python provides a `zip` function which, when combined with `list`, provides a simple way to get this functionality:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([784]), tensor([1]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dset = list(zip(train_x,train_y))\n", "x,y = dset[0]\n", "x.shape,y" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)\n", "valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)\n", "valid_dset = list(zip(valid_x,valid_y))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we need an (initially random) weight for every pixel (this is the *initialize* step in our seven-step process):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def init_params(size, std=1.0): return (torch.randn(size)*std).requires_grad_()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "weights = init_params((28*28,1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The function `weights*pixels` won't be flexible enough—it is always equal to 0 when the pixels are equal to 0 (i.e., its *intercept* is 0). You might remember from high school math that the formula for a line is `y=w*x+b`; we still need the `b`. We'll initialize it to a random number too:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bias = init_params(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In neural networks, the `w` in the equation `y=w*x+b` is called the *weights*, and the `b` is called the *bias*. Together, the weights and bias make up the *parameters*." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> jargon: Parameters: The _weights_ and _biases_ of a model. The weights are the `w` in the equation `w*x+b`, and the biases are the `b` in that equation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now calculate a prediction for one image:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([20.2336], grad_fn=<AddBackward0>)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(train_x[0]*weights.T).sum() + bias" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "While we could use a Python `for` loop to calculate the prediction for each image, that would be very slow. Because Python loops don't run on the GPU, and because Python is a slow language for loops in general, we need to represent as much of the computation in a model as possible using higher-level functions.\n", "\n", "In this case, there's an extremely convenient mathematical operation that calculates `w*x` for every row of a matrix: it's called *matrix multiplication*. The following figure shows what matrix multiplication looks like." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Figure: Matrix multiplication]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This image shows two matrices, `A` and `B`, being multiplied together. Each item of the result, which we'll call `AB`, is computed by multiplying each item of its corresponding row of `A` by the matching item of its corresponding column of `B`, then adding the products together. For instance, row 1, column 2 (the yellow dot with a red border) is calculated as $a_{1,1} * b_{1,2} + a_{1,2} * b_{2,2}$. 
If you need a refresher on matrix multiplication, we suggest you take a look at the [Intro to Matrix Multiplication](https://youtu.be/kT4Mp9EdVqs) on *Khan Academy*, since this is the most important mathematical operation in deep learning.\n", "\n", "In Python, matrix multiplication is represented with the `@` operator. Let's try it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[20.2336],\n", " [17.0644],\n", " [15.2384],\n", " ...,\n", " [18.3804],\n", " [23.8567],\n", " [28.6816]], grad_fn=<AddBackward0>)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def linear1(xb): return xb@weights + bias\n", "preds = linear1(train_x)\n", "preds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first element is the same as we calculated before, as we'd expect. This equation, `batch@weights + bias`, is one of the two fundamental equations of any neural network (the other one is the *activation function*, which we'll see in a moment)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check our accuracy. To decide if an output represents a 3 or a 7, we can just check whether it's greater than 0.0, so our accuracy for each item can be calculated (using broadcasting, so no loops!) 
with:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ True],\n", " [ True],\n", " [ True],\n", " ...,\n", " [False],\n", " [False],\n", " [False]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "corrects = (preds>0.0).float() == train_y\n", "corrects" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.4912068545818329" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "corrects.float().mean().item()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's see what the change in accuracy is for a small change in one of the weights (note that we have to ask PyTorch not to calculate gradients as we do this, which is what `with torch.no_grad()` is doing here):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with torch.no_grad(): weights[0] *= 1.0001" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.4912068545818329" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds = linear1(train_x)\n", "((preds>0.0).float() == train_y).float().mean().item()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we've seen, we need gradients in order to improve our model using SGD, and in order to calculate gradients we need some *loss function* that represents how good our model is. That is because the gradients are a measure of how that loss function changes with small tweaks to the weights.\n", "\n", "So, we need to choose a loss function. The obvious approach would be to use accuracy, which is our metric, as our loss function as well. 
In this case, we would calculate our prediction for each image, collect these values to calculate an overall accuracy, and then calculate the gradient of that overall accuracy with respect to each weight.\n", "\n", "Unfortunately, we have a significant technical problem here. The gradient of a function is its *slope*, or its steepness, which can be defined as *rise over run*, that is, how much the value of the function goes up or down, divided by how much we changed the input. We can write this mathematically as: `(y_new - y_old) / (x_new - x_old)`. This gives us a good approximation of the gradient when `x_new` is very similar to `x_old`, meaning that their difference is very small. But accuracy only changes at all when a prediction changes from a 3 to a 7, or vice versa. The problem is that a small change in weights from `x_old` to `x_new` isn't likely to cause any prediction to change, so `(y_new - y_old)` will almost always be 0. In other words, the gradient is 0 almost everywhere." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A very small change in the value of a weight will often not actually change the accuracy at all. This means it is not useful to use accuracy as a loss function: if we do, most of the time our gradients will actually be 0, and the model will not be able to learn from that number.\n", "\n", "> S: In mathematical terms, accuracy is a function that is constant almost everywhere (except at the threshold, 0.5), so its derivative is zero almost everywhere (and infinite at the threshold). This then gives gradients that are 0 or infinite, which are useless for updating the model.\n", "\n", "Instead, we need a loss function which, when our weights result in slightly better predictions, gives us a slightly better loss. So what does a \"slightly better prediction\" look like, exactly? 
Well, in this case, it means that if the correct answer is a 3 the score is a little higher, or if the correct answer is a 7 the score is a little lower.\n", "\n", "Let's write such a function now. What form does it take?\n", "\n", "The loss function receives not the images themselves, but the predictions from the model. Let's make one argument, `prds`, of values between 0 and 1, where each value is the prediction that an image is a 3. It is a vector (i.e., a rank-1 tensor), indexed over the images.\n", "\n", "The purpose of the loss function is to measure the difference between predicted values and the true values — that is, the targets (aka labels). Let's make another argument, `trgts`, with values of 0 or 1 which tells whether an image actually is a 3 or not. It is also a vector (i.e., another rank-1 tensor), indexed over the images.\n", "\n", "So, for instance, suppose we had three images which we knew were a 3, a 7, and a 3. And suppose our model predicted with high confidence (`0.9`) that the first was a 3, with slight confidence (`0.4`) that the second was a 7, and with fair confidence (`0.2`), but incorrectly, that the last was a 7. This would mean our loss function would receive these values as its inputs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trgts = tensor([1,0,1])\n", "prds = tensor([0.9, 0.4, 0.2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's a first try at a loss function that measures the distance between `predictions` and `targets`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def mnist_loss(predictions, targets):\n", " return torch.where(targets==1, 1-predictions, predictions).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We're using a new function, `torch.where(a,b,c)`. 
This is the same as running the list comprehension `[b[i] if a[i] else c[i] for i in range(len(a))]`, except it works on tensors, at C/CUDA speed. In plain English, this function will measure how distant each prediction is from 1 if it should be 1, and how distant it is from 0 if it should be 0, and then it will take the mean of all those distances.\n", "\n", "> note: Read the Docs: It's important to learn about PyTorch functions like this, because looping over tensors in Python performs at Python speed, not C/CUDA speed! Try running `help(torch.where)` now to read the docs for this function, or, better still, look it up on the PyTorch documentation site." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's try it on our `prds` and `trgts`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.1000, 0.4000, 0.8000])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.where(trgts==1, 1-prds, prds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can see that this function returns a lower number when predictions are more accurate, when accurate predictions are more confident (higher absolute values), and when inaccurate predictions are less confident. In PyTorch, we always assume that a lower value of a loss function is better. 
Since we need a scalar for the final loss, `mnist_loss` takes the mean of the previous tensor:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.4333)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mnist_loss(prds,trgts)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For instance, if we change our prediction for the one incorrectly predicted image from `0.2` to `0.8`, the loss will go down, indicating that this is a better prediction:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.2333)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mnist_loss(tensor([0.9, 0.4, 0.8]),trgts)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One problem with `mnist_loss` as currently defined is that it assumes that predictions are always between 0 and 1. We need to ensure, then, that this is actually the case! As it happens, there is a function that does exactly that, so let's take a look." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sigmoid" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `sigmoid` function always outputs a number between 0 and 1. It's defined as follows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def sigmoid(x): return 1/(1+torch.exp(-x))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch defines an accelerated version for us, so we don’t really need our own. This is an important function in deep learning, since we often want to ensure values are between 0 and 1. 
This is what it looks like:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEMCAYAAAA/Jfb8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXyU5b3+8c8XCCQkJGwh7DvIpoBEEBStVetyatViq1VxrxW1bqeeWv1prdra2uOx1VqXUxTFXStuVG3VqrhUWQOEfd9JIGTfk+/vjwk9MSZmgCTPzOR6v17zknnmnuEyzFw83M8z92PujoiIxJY2QQcQEZGmp3IXEYlBKncRkRikchcRiUEqdxGRGKRyFxGJQSp3iTlmdqeZrQs6x35mNsvM3mtkzCVmVtlSmST2qdwlqphZgpndbWZrzazEzPaa2Xwzu67WsP8Gjg4qYz2uB34QdAhpXdoFHUDkAD0CnECoMDOAZGA80H//AHcvBAoDSVcPd88LOoO0Ptpzl2hzFvB7d3/N3Te6e4a7z3L3u/YPqG9axsxuMLNtZlZsZu+a2XQzczPrW/P4JWZWaWYnmNmymn8VfGRmvc3sODNbbGZFZvaemfWp89oXm9kKMyur+T3uMbN2tR7/yrSMhdxtZllmVmhmLwBdmusHJq2Tyl2izU7gVDPrGu4TzOz7hKZqfg+MBZ4HflfP0DbAL4ErgGOA3sCLwF3ADOBYoC/wP7Ve+z+AJ4DZwOHAfwLX1LxOQ64DbgJuBo4EFjUyXuTAubtuukXNjVDpbgaqgKXA48CZgNUacyewrtb9T4HZdV7nt4ADfWvuX1Jzf1ytMTfXbJtQa9uNwJ5a9+cBL9V57euBEqB9zf1ZwHu1Ht8G/LrOc14BKoP++eoWOzftuUtUcfdPgSHAVOApIA34K/CGmVkDTxsF/KvOts/re3lgWa37u2r+u7TOtm5m1rbm/mjg4zqv8xEQX5PzK8wsGegDfFbnoU8ayC5yUFTuEnXcvdLdP3P3+939TEJ73d8Fjvump4Xx0tXuXlX3Oe5eUc/rWD3bqPNYfb/nNz0m0mRU7hILVtb8t0cDj68AJtfZ1lSnSmYCx9fZdhyhaZkNdQd76MyZ7YSml2qre1/kkOhUSIkqZvYRoQOiC4BsYCjwGyAX+GcDT7sfeNHMvgTeBqYAF9U8dqh70PcCb5rZLcCrwDhCc/73u3v5N+S528xWEZou+h5w0iHmEPkK7blLtHkbuAD4G7AaeBJYCxzj7nvqe4K7vwr8F3ALoTn1C4Bf1Txceihh3P1vwGXAxcBy4AHgz7Vevz5/BB6sGbuE0L8q7vqG8SIHzNw19Setj5ndAVzv7t2CziLSHDQtIzHPzOIInX/+N6CI0DdcbwYeDjKXSHPSnrvEvJpvi74FTAA6ARuBpwl901WLdUlMUrmLiMQgHVAVEYlBETHn3r17dx84cGDQMUREosrChQv3uHtqfY9FRLkPHDiQBQsWBB1DRCSqmNnmhh7TtIyISAwKq9zN7FozW1CzXvWsRsbeaGa7zCzPzJ4wsw5NklRERMIW7p77DuAeQutWN8jMTiH0LcATgYHAYL75m3oiItIMwip3d3/V3V8D9jYy9GJgprtnuvs+4G5CK/aJiEgLauo599GErmu5XwaQZmb6ireISAtq6nJPAmpfDHj/rzvVHWhmV9bM4y/Izs5u4hgiIq1bU5d7IaGr0e+3/9cFdQe6++Punu7u6amp9Z6mKSIiB6mpz3PPJHQB4pdq7o8Fdrt7Y3P1IiIxzd3JKSpnV34pWfllZBWUsju/jPH9OzN1WNPv4IZV7jULL7UD2gJtzSye0MV86y6
69DQwy8yeJXSV+v9H6OLAIiIxrbyymu25JWzbV8y2fSVs31fCjtwStueWsDOvlF35pZRXVn/teTO+NSS4cidU0r+sdf9C4Fdm9gShS5iNcvct7v6Omd1H6Io4CYQuXPzLr72aiEgUqqiqZktOMRuyi9i4p5CNe4rZtKeILTnF7MwrobrWOoxt2xg9k+Pp3Tmecf0606tzPD2TQ7ceyfH06NSB1E4diI9r2/BveAgiYlXI9PR01/IDIhIpqqqdjXsKWbWrgDW7C1m7u4C1WYVs3ltERdX/dWaXjnEM7J7IgK4d6d8tkf5dO9KvSwJ9u3YkrVMH2rVt3kUAzGyhu6fX91hErC0jIhKU0ooqVu8qYNn2PDJ35JG5I5/Vuwooq5lCaWMwoFsiQ3skcfKoNIamJjE4NZHB3ZNI6RgXcPqGqdxFpNVwd7bkFLNw8z4Wb8klY1suK3fm/3tvPCUhjtG9k5l+9ABG9kpmRK9ODElNarapk+akcheRmFVd7azclc8XG3L4cmMOCzbvY09hGQCJ7dtyRN/OXDF1MEf0SWFMnxT6dknAzAJO3TRU7iISUzbtKeKTdXv4dN0ePlu/l7ySCgD6dklg6rDuTBjQhfSBXRjWoxNt28RGkddH5S4iUa20oorP1+/lw9VZfLgmm817iwHonRLPd0alMXlINyYN7kafzgkBJ21ZKncRiTq5xeX8Y8Vu3lu5m4/X7KGkoor4uDZMGdKdy48dxNRhqQzs1jFmplgOhspdRKLCvqJy3sncxd+W7eTz9XuprHZ6pcRzzoS+nDiyB0cP7haVBz6bi8pdRCJWSXkVf1+xizeW7OCjNdlUVjsDunXkx8cN5rQxPTm8T0qr3jv/Jip3EYko7s6iLft4ZeE23srYSUFZJT2T47ns2EF8b2xvRvdOVqGHQeUuIhEhr7iCvy7axnNfbmFdViEJcW05/fBeTJvQh6MHdaNNDJ/Z0hxU7iISqBU78pn12UZeX7KDsspqxvbrzH3TjuD0I3qR1EEVdbD0kxORFldd7by3cjczP9nIFxtziI9rw/eP7MuFR/dndO+UoOPFBJW7iLSYssoqXlu8ncc+3sCG7CL6dE7g1tNHcG56/4hepyUaqdxFpNmVVlTxwpdbePSjDezKL2V072Qe/NF4Th/Ts9lXTmytVO4i0mxKK6p49ostPPrRerILypg4qCu//8ERHDu0u854aWYqdxFpcpVV1byycBt/fH8tO/NKmTKkGw/9aDxHD+4WdLRWQ+UuIk3G3XlvZRb3vr2SDdlFjOvXmft/MJYpQ7sHHa3VUbmLSJNYvj2Pu99awRcbcxicmsjj0ydw8qg0Tb8EROUuIockp6ic37+7mhfmb6FLx/bcfeZozpvYnzgdKA2Uyl1EDkp1tfPcl1v4/burKSyr5NIpg7jh5GEkx+uUxkigcheRA7ZqVz6/eHUZi7fkMnlwN3515miGp3UKOpbUonIXkbCVVlTx4PtrefzjDSQnxPHAuWM5a1wfzatHIJW7iIRl8ZZ93PzKUtZlFXLOhL7cdvpIuiS2DzqWNEDlLiLfqKyyigf+sZbHP15PWnI8T102keOHpwYdSxqhcheRBq3ZXcD1Lyxh5c58zk3vx23fHakDplFC5S4iX+PuPPXZJu59exVJHdrxl4vSOWlUWtCx5ACo3EXkK3KLy/nZy0t5b+VuTjgslfvOGUtqpw5Bx5IDpHIXkX9buHkf1z2/mKyCUm7/7iguO2agzoSJUip3EcHdefLTTfzmbyvp1TmeV66awth+nYOOJYdA5S7SyhWXV/KLV5fx+pIdnDQyjft/OJaUBB00jXYqd5FWbMveYq6cvYDVuwu4+ZTDmHH8EF2IOkaEtbKPmXU1szlmVmRmm83s/AbGdTCzR81st5nlmNmbZtanaSOLSFP4fP1eznz4E3bmlTLr0olcc8JQFXsMCXfZtoeBciANuAB4xMxG1zPuemAycATQG8gFHmqCnCLShJ77YgvTZ35Bt6Q
OvH7NMfpSUgxqtNzNLBGYBtzu7oXu/gnwBjC9nuGDgHfdfbe7lwIvAPX9JSAiAaiqdu5+awW3zlnGscO68+rVUxjYPTHoWNIMwplzHw5UufuaWtsygOPrGTsT+KOZ7d9rvwB4+5BTisghKymv4oYXF/Nu5m4umTKQ2787iraaholZ4ZR7EpBXZ1seUN/6nmuALcB2oApYBlxb34ua2ZXAlQD9+/cPM66IHIy9hWVc/tQCMrblcsd3R3HZsYOCjiTNLJw590Iguc62ZKCgnrGPAPFANyAReJUG9tzd/XF3T3f39NRUzfeJNJetOcWc8+jnrNqVz6MXTlCxtxLhlPsaoJ2ZDau1bSyQWc/YscAsd89x9zJCB1MnmpmujisSgJU785n2yGfkFJXz7BWTOGV0z6AjSQtptNzdvYjQHvhdZpZoZscAZwKz6xk+H7jIzFLMLA64Gtjh7nuaMrSING7+phx++NjntDHj5asmM2FA16AjSQsK91TIq4EEIAt4Hpjh7plmNtXMCmuN+xlQCqwFsoHTgbObMK+IhGHe2mymz/yC1KQOvDJjsi6B1wqF9Q1Vd88Bzqpn+zxCB1z3399L6AwZEQnI3zN3ce1zixmcmsjsyydpRcdWSssPiMSQNzN2cMOLSxjTJ4WnLj2Kzh11GbzWSuUuEiNeX7KdG19cQvqArsy8JJ1OumJSq6ZyF4kBry3ezk0vLeGogV154pKjSOygj3Zrp3eASJTbX+wTB4WKvWN7faxF5S4S1eYu3clNLy1h0qBuPHHJUSS0bxt0JIkQ4Z4KKSIR5u+Zu7j+hcUc2b8LMy9JV7HLV6jcRaLQR2uyufa5xYzuk8KTl2oqRr5O5S4SZRZsyuEnsxcwpEcST186UWfFSL1U7iJRZMWOfC6dNZ/eKQnMvnwiKR1V7FI/lbtIlNi4p4iLnviSpA7tmH3FJLon6Zun0jCVu0gUyMovZfrML6h2Z/blk+jTOSHoSBLhVO4iES6/tIKLn5xPTlE5sy49iqE9khp/krR6KneRCFZWWcVVsxeydncBj144gSP6dg46kkQJnT8lEqGqq52fvbyUz9bv5YFzx3LccF2xTMKnPXeRCPW7d1fxZsYObjltBGeP7xt0HIkyKneRCPTMvzbz2EcbuPDo/vzkuMFBx5EopHIXiTAfrNrNHa8v59sjenDnGaMxs6AjSRRSuYtEkBU78rn2ucWM6p3MQz8aT7u2+ojKwdE7RyRCZOWXcsVT80lJiGPmxVqTXQ6N3j0iEaCkvIofP72A3JIKXr5qMmnJ8UFHkiinchcJWOiUxwyWbs/jsQsnMLp3StCRJAZoWkYkYA9+sJa5y3Zyy6kj+M7onkHHkRihchcJ0NvLdvKH99Yy7ci+XKlTHqUJqdxFApK5I4+bXspgfP/O/PrsMTrlUZqUyl0kAHsKy7jy6YV07hjHY9MnEB+nS+RJ09IBVZEWVlFVzTXPLmJPYRmvXDWFHp10Zow0PZW7SAv79dyVfLExhwfOHcvhfXVmjDQPTcuItKCXF2xl1mebuPzYQVoMTJqVyl2khSzdlsttry1nypBu/OK0EUHHkRinchdpAXsLy7hq9kJSkzrwp/OP1Jox0uw05y7SzCqrqrnuhcXsKSrnr1dNoWti+6AjSSsQ1u6DmXU1szlmVmRmm83s/G8Ye6SZfWxmhWa228yub7q4ItHnv/++hk/X7eWes8boAKq0mHD33B8GyoE0YBww18wy3D2z9iAz6w68A9wIvAK0B3TUSFqtd5bv4tGP1nP+pP78ML1f0HGkFWl0z93MEoFpwO3uXujunwBvANPrGX4T8K67P+vuZe5e4O4rmzaySHTYkF3Iz17OYGy/zvzyjFFBx5FWJpxpmeFAlbuvqbUtAxhdz9ijgRwz+8zMsszsTTPr3xRBRaJJcXklM55ZRFxb488XHEmHdvoGqrSscMo9Ccirsy0P6FTP2L7AxcD1QH9gI/B8fS9qZlea2QIzW5CdnR1+YpE
I5+7cNmc5a7IK+MN54+nTOSHoSNIKhVPuhUBynW3JQEE9Y0uAOe4+391LgV8BU8zsa0eR3P1xd0939/TU1NQDzS0SsZ79YgtzFm/nhhOHc/xwvbclGOGU+xqgnZkNq7VtLJBZz9ilgNe6v//XWu5OWoVl2/K4680VHDc8lZ9+e2jQcaQVa7Tc3b0IeBW4y8wSzewY4Exgdj3DnwTONrNxZhYH3A584u65TRlaJBLlFVdw9XML6ZbUnj+cO442bbRPI8EJ92tyVwMJQBahOfQZ7p5pZlPNrHD/IHf/ALgVmFszdijQ4DnxIrHC3fnZKxnszC3lT+cfqS8qSeDCOs/d3XOAs+rZPo/QAdfa2x4BHmmSdCJR4i/zNvKPFbu5/bujmDCgS9BxRLS2jMihWrh5H797ZxWnju7JZccMDDqOCKByFzkk+4rK+elzi+jVOZ7fnXOELpUnEUMLh4kcpOpq5z9fzmBPYTl/nTGFlIS4oCOJ/Jv23EUO0v/O28AHq7K47T9GakEwiTgqd5GDsHBzDve9u5rTD+/JRZMHBB1H5GtU7iIHKDTPvpg+nRP47TTNs0tk0py7yAFwd35Wa549OV7z7BKZtOcucgD+Mm8j76/K4tbTR2ieXSKayl0kTIu3hM5nP2V0GhdPGRh0HJFvpHIXCUNecQXXPreYninx3HfOWM2zS8TTnLtII9ydn/91KbvzS3n5qsk6n12igvbcRRrx9OebeSdzFz8/dQTj+2vdGIkOKneRb7B8ex6/nruSb4/owRVTBwUdRyRsKneRBhSUVnDtc4voltSe+3+geXaJLppzF6mHu3PrnOVs3VfCC1ceTRetzy5RRnvuIvV4cf5W3szYwU0nD+eogV2DjiNywFTuInWs3lXAL9/I5Nih3Zlx/JCg44gcFJW7SC3F5ZVc89wiOsXH8YCugypRTHPuIrXc8Xom67MLeebySaR26hB0HJGDpj13kRp/XbiNVxZu46ffHsYxQ7sHHUfkkKjcRYB1WQX8v9eWM3FQV64/cVjQcUQOmcpdWr2S8iqueXYxCe3b8uB542mreXaJAZpzl1bvzjcyWb27gKcum0jPlPig44g0Ce25S6v22uLtvLhgK1d/awjHD08NOo5Ik1G5S6u1LquQW+cs46iBXbjp5OFBxxFpUip3aZVC8+yLiI9ry4M/Gk+7tvooSGzRnLu0Svvn2WddehS9UhKCjiPS5LS7Iq3Oq4u28eKCrVxzwhC+dViPoOOINAuVu7Qqa3cXcNuc0PnsN56keXaJXSp3aTWKyiqZ8ewiEju05U+aZ5cYpzl3aRXcndvmLGNDzboxPZJ1PrvEtrB2Xcysq5nNMbMiM9tsZuc3Mr69ma0ys21NE1Pk0Dz35RZeW7KDG08azhStGyOtQLh77g8D5UAaMA6Ya2YZ7p7ZwPibgSwg6dAjihyapdty+dUbKzh+eCrXnDA06DgiLaLRPXczSwSmAbe7e6G7fwK8AUxvYPwg4ELg3qYMKnIwcovLmfHMIlI7ddD67NKqhDMtMxyocvc1tbZlAKMbGP8QcCtQcojZRA5JdbVzw4tLyC4o488XHElXXQdVWpFwyj0JyKuzLQ/oVHegmZ0NtHP3OY29qJldaWYLzGxBdnZ2WGFFDsRDH6zjw9XZ3HHGKMb26xx0HJEWFU65FwLJdbYlAwW1N9RM39wH/DSc39jdH3f3dHdPT03Vgk3StD5cncUf3l/D2eP7cMGk/kHHEWlx4RxQXQO0M7Nh7r62ZttYoO7B1GHAQGCemQG0B1LMbBdwtLtvapLEIo3YsreY619YwmFpnfjN2YdT834UaVUaLXd3LzKzV4G7zOwKQmfLnAlMqTN0OdCv1v0pwJ+AIwHNu0iLKCmv4qpnFuLuPDZ9Agnt2wYdSSQQ4X5F72oggdDpjc8DM9w908ymmlkhgLtXuvuu/TcgB6iuuV/VLOlFanF3bnttGSt35fPH88YzoFti0JFEAhPWee7ungOcVc/2eTR
wLru7fwj0PZRwIgfi6c838+qi7dxw0jBOGKEFwaR10+IaEhM+X7+Xu95awUkj07ju27rAtYjKXaLe9twSrnluEQO7deSBc8fqi0oiqNwlypVWVPGT2QuoqKzm8YvS6RQfF3QkkYigVSElark7N7+ylMwd+fzlonSGpGopI5H9tOcuUevPH67nzYwd3HzKYZw4Mi3oOCIRReUuUenvmbv4/burOXNcb2YcPyToOCIRR+UuUWfVrnxufHEJY/um8LtpR+gbqCL1ULlLVMkuKOPyWQtIim/HY9PTiY/TN1BF6qMDqhI1SiuquHL2AnKKynn5qsn0TNGl8kQaonKXqLD/zJjFW3J59MIJjOmTEnQkkYimaRmJCg/8Yw1vZuzg56eO4NQxPYOOIxLxVO4S8V6av5UHP1jHuen9uOr4wUHHEYkKKneJaPPWZnPrnGUcNzyVe84eozNjRMKkcpeItXJnPjOeWcTQHkk8fP544trq7SoSLn1aJCJt21fMJU9+SVKHdjx56VFaM0bkAKncJeLsKyrnoie+pKS8iqcvn0ivlISgI4lEHZ0KKRGlpLyKy56az7Z9JTxz+SSGp3UKOpJIVNKeu0SM8spqrn52IRlbc3nwvPFMHNQ16EgiUUt77hIRqqqd/3w5g3+uzube7x+uc9lFDpH23CVw7s4dry/nzYwd3HLaCH40sX/QkUSinspdAuXu3Pfuap79YgtXHT+Eq7R8r0iTULlLoB58fx2PfLie8yf15+enHhZ0HJGYoXKXwDz20XoeeG8N50zoyz1n6tunIk1J5S6BePLTjdz79irOGNub3007gjZtVOwiTUlny0iLe+KTjdz11gpOHd2T//nhWNqq2EWanMpdWtRf5m3gnrkrOXV0Tx7SejEizUafLGkx+4v9tDEqdpHmpj13aXbuzkMfrON//rGG/zi8F384b5yKXaSZqdylWbk7v31nFY99tIFpR/bld9MOp52KXaTZqdyl2VRVO798YznP/GsL048ewK++N1pnxYi0EJW7NIuyyipuejGDuct28pPjB3PLqSN0HrtICwrr38dm1tXM5phZkZltNrPzGxh3s5ktN7MCM9toZjc3bVyJBoVllVw2az5zl+3kttNH8ovTRqrYRVpYuHvuDwPlQBowDphrZhnunllnnAEXAUuBIcDfzWyru7/QVIElsmXll3LZU/NZubOA+38wlmkT+gYdSaRVanTP3cwSgWnA7e5e6O6fAG8A0+uOdff73H2Ru1e6+2rgdeCYpg4tkWnN7gLO/vNnbMgu4i8XpavYRQIUzrTMcKDK3dfU2pYBjP6mJ1no3+FTgbp79xKDPl23h2l//ozyqmpe+slkThjRI+hIIq1aOOWeBOTV2ZYHNHb9sztrXv/J+h40syvNbIGZLcjOzg4jhkSqZ7/YzMVPfEmvzvG8ds0xjOmTEnQkkVYvnDn3QiC5zrZkoKChJ5jZtYTm3qe6e1l9Y9z9ceBxgPT0dA8rrUSUyqpq7n5rBU99vpnjh6fy0PnjSY6PCzqWiBBeua8B2pnZMHdfW7NtLA1Mt5jZZcAtwHHuvq1pYkqkySkq57rnF/PJuj38eOogbjltpBYAE4kgjZa7uxeZ2avAXWZ2BaGzZc4EptQda2YXAL8BTnD3DU0dViLDsm15XPXMQrILy7jvnCP4YXq/oCOJSB3hfg/8aiAByAKeB2a4e6aZTTWzwlrj7gG6AfPNrLDm9mjTRpYgvbRgK9Me/QyAV66arGIXiVBhnefu7jnAWfVsn0fogOv++4OaLppEkuLySu54PZNXFm7j2KHdefBH4+ma2D7oWCLSAC0/II1avauAa55bxPrsQq47cRjXnzhM8+siEU7lLg1yd575Ygu/nruCpA5xPHP5JI4Z2j3oWCISBpW71Cu7oIyf/3UpH6zK4rjhqfz3D46gR6f4oGOJSJhU7vI17yzfxW1zllFQVsmdZ4zioskDtVSvSJRRucu/5RSV88s3MnkzYwejeyf
z/LnjGJ7W2BeRRSQSqdwFd2fusp3c+UYmeSUV3HTycGZ8a4guhScSxVTurdzWnGLueH05/1ydzeF9Uph9+SRG9qq72oSIRBuVeytVVlnFzE828tD76zCD2787iosnD9D1TUVihMq9FfpwdRa/enMFG/cUcfKoNO783mj6dE4IOpaINCGVeyuydncB9769ig9WZTG4eyJPXTaR44enBh1LRJqByr0VyC4o4w/vreGF+Vvp2L4tvzhtBJceM4j27TQFIxKrVO4xLK+4gsfnreeJTzZRUVXN9KMHcN2Jw7QmjEgroHKPQfmlFTz16Sb+d94G8ksr+d7Y3tx48nAGdU8MOpqItBCVewzJLS7nyU838cSnGykoreSkkT246eTDGNVbpzaKtDYq9xiwPbeEmfM28sL8LRSXV3HK6DR++u1hupapSCumco9iS7bm8uSnG3lr6U4MOGNsb648brC+hCQiKvdoU1pRxTvLdzHrs00s2ZpLUod2XDx5IJdPHaRz1UXk31TuUWJDdiHPf7mFVxZuY19xBYO6J3LnGaM4J70fSR30xygiX6VWiGD5pRXMXbqTVxZuY+HmfbRrY5w8Ko0LJg1gypBuWoZXRBqkco8wpRVVfLg6mzcytvP+yizKKqsZ2iOJW04bwffH96FHsi6YISKNU7lHgNKKKj5ek83by3fx3srdFJRW0j2pPecd1Y+zxvdhXL/OmGkvXUTCp3IPSE5ROf9clcV7K3fz8ZpsisqrSEmI45TRPfne2N5MGdJNKzSKyEFTubeQqmpn+fY8PlydzUdrsliyNZdqh7TkDnxvXB9OG9OTyUO66QIZItIkVO7NxN1Zn13Evzbs5dN1e/hs/V7ySiowgyP6pHDtt4dx0sgejOmdogOjItLkVO5NpKKqmpU781mwaR8LNufw5cYc9hSWA9A7JZ7vjErj2GHdOXZod7oldQg4rYjEOpX7QXB3tu0rYdn2PJZszWXJ1lyWbcujpKIKCJX51GGpTBrUlUmDuzGwW0cdEBWRFqVyb0R5ZTXrswtZtSuflTsLWLEjn+U78sgtrgCgfds2jOqdzLlH9SN9YBeO7N+F3vqmqIgETOVeo7Siio17ilifXci6rELWZhWyZlcBG/cUUVntALRv14bhaUmcNqYnY/qkMKZ3CiN7JeuiFyIScVpVueeVVLBtXzFbc4rZvLeYLTnFbNpbxKY9xezIK8FDHY4Z9OvSkeFpSZw0Ko0RPTsxslcyg7sn6vREEYkKMbsXLKMAAAVjSURBVFPuRWWVZBWUsTOvhN35pezKK2NHbgk780rYnlvKtn3FFJRWfuU5nTvGMbBbIhMHdWVgt0QGpyYytEcSg7onEh/XNqD/ExGRQxfV5f7PVVnc9dYKsvJLKSqv+trjKQlx9O6cQO+UeCYO7ELfLh3p0yWB/l070q9rR1IS4gJILSLS/MIqdzPrCswEvgPsAX7h7s/VM86A3wJX1GyaCfzcff+ER9Pq3DGOUb2S+dZhqfToFE+PTh3olRJPz5pbx/ZR/XeXiMhBC7f9HgbKgTRgHDDXzDLcPbPOuCuBs4CxgAP/ADYAjzZN3K8a378LD1/QpTleWkQkqjV6dNDMEoFpwO3uXujunwBvANPrGX4xcL+7b3P37cD9wCVNmFdERMIQzqkfw4Eqd19Ta1sGMLqesaNrHmtsnIiINKNwyj0JyKuzLQ/oFMbYPCDJ6vl6ppldaWYLzGxBdnZ2uHlFRCQM4ZR7IVD3isvJQEEYY5OBwvoOqLr74+6e7u7pqamp4eYVEZEwhFPua4B2Zjas1raxQN2DqdRsGxvGOBERaUaNlru7FwGvAneZWaKZHQOcCcyuZ/jTwE1m1sfMegP/CcxqwrwiIhKGcL9LfzWQAGQBzwMz3D3TzKaaWWGtcY8BbwLLgOXA3JptIiLSgsI6z93dcwidv153+zxCB1H333fgv2puIiISEGumL48eWAizbGDzQT69O6FvzUaaSM0FkZtNuQ6Mch2
YWMw1wN3rPSMlIsr9UJjZAndPDzpHXZGaCyI3m3IdGOU6MK0tl9avFRGJQSp3EZEYFAvl/njQARoQqbkgcrMp14FRrgPTqnJF/Zy7iIh8XSzsuYuISB0qdxGRGKRyFxGJQTFX7mY2zMxKzeyZoLMAmNkzZrbTzPLNbI2ZXdH4s5o9Uwczm2lmm82swMwWm9lpQecCMLNra5aCLjOzWQFn6Wpmc8ysqOZndX6QeWoyRczPp7YIf09F3GewtubqrFi8yOjDwPygQ9RyL3C5u5eZ2QjgQzNb7O4LA8zUDtgKHA9sAU4HXjKzw919U4C5AHYA9wCnEFrPKEjhXl6yJUXSz6e2SH5PReJnsLZm6ayY2nM3s/OAXOD9oLPs5+6Z7l62/27NbUiAkXD3Ine/0903uXu1u78FbAQmBJmrJtur7v4asDfIHAd4eckWEyk/n7oi/D0VcZ/B/Zqzs2Km3M0sGbiL0DLDEcXM/mxmxcAqYCfwt4AjfYWZpRG6nKLW3v8/B3J5Sakj0t5TkfgZbO7OiplyB+4GZrr71qCD1OXuVxO6LOFUQmvjl33zM1qOmcUBzwJPufuqoPNEkAO5vKTUEonvqQj9DDZrZ0VFuZvZh2bmDdw+MbNxwEnAA5GUq/ZYd6+q+ad9X2BGJOQyszaELrpSDlzbnJkOJFeEOJDLS0qNln5PHYiW/Aw2piU6KyoOqLr7t77pcTO7ARgIbKm5FncS0NbMRrn7kUHlakA7mnm+L5xcNRctn0noYOHp7l7RnJnCzRVB/n15SXdfW7NNl438BkG8pw5Ss38Gw/AtmrmzomLPPQyPE/rDGldze5TQVaBOCTKUmfUws/PMLMnM2prZKcCPgA+CzFXjEWAkcIa7lwQdZj8za2dm8UBbQm/2eDNr8Z2QA7y8ZIuJlJ9PAyLuPRXBn8Hm7yx3j7kbcCfwTATkSAU+InQ0PJ/Q5Qd/HAG5BhA6Y6CU0PTD/tsFEZDtTv7vjIb9tzsDytIVeA0oInR63/n6+UTXeypSP4MN/Lk2aWdp4TARkRgUK9MyIiJSi8pdRCQGqdxFRGKQyl1EJAap3EVEYpDKXUQkBqncRURikMpdRCQG/X9A/lu2GB8+RgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_function(torch.sigmoid, title='Sigmoid', min=-4, max=4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, it takes any input value, positive or negative, and smooshes it onto an output value between 0 and 1. It's also a smooth curve that only goes up, which makes it easier for SGD to find meaningful gradients. \n", "\n", "Let's update `mnist_loss` to first apply `sigmoid` to the inputs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def mnist_loss(predictions, targets):\n", " predictions = predictions.sigmoid()\n", " return torch.where(targets==1, 1-predictions, predictions).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can be confident our loss function will work, even if the predictions are not between 0 and 1. All that is required is that a higher prediction corresponds to higher confidence an image is a 3.\n", "\n", "Having defined a loss function, now is a good moment to recapitulate why we did this. After all, we already had a metric, which was overall accuracy. So why did we define a loss?\n", "\n", "The key difference is that the metric is to drive human understanding and the loss is to drive automated learning. To drive automated learning, the loss must be a function that has a meaningful derivative. It can't have big flat sections and large jumps, but instead must be reasonably smooth. This is why we designed a loss function that would respond to small changes in confidence level. This requirement means that sometimes it does not really reflect exactly what we are trying to achieve, but is rather a compromise between our real goal and a function that can be optimized using its gradient. 
The loss function is calculated for each item in our dataset, and then at the end of an epoch the loss values are all averaged and the overall mean is reported for the epoch.\n", "\n", "Metrics, on the other hand, are the numbers that we really care about. These are the values that are printed at the end of each epoch that tell us how our model is really doing. It is important that we learn to focus on these metrics, rather than the loss, when judging the performance of a model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SGD and Mini-Batches" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have a loss function that is suitable for driving SGD, we can consider some of the details involved in the next phase of the learning process, which is to change or update the weights based on the gradients. This is called an *optimization step*.\n", "\n", "In order to take an optimization step we need to calculate the loss over one or more data items. How many should we use? We could calculate it for the whole dataset, and take the average, or we could calculate it for a single data item. But neither of these is ideal. Calculating it for the whole dataset would take a very long time. Calculating it for a single item would not use much information, so it would result in a very imprecise and unstable gradient. That is, you'd be going to the trouble of updating the weights, but taking into account only how that would improve the model's performance on that single item.\n", "\n", "So instead we take a compromise between the two: we calculate the average loss for a few data items at a time. This is called a *mini-batch*. The number of data items in the mini-batch is called the *batch size*. A larger batch size means that you will get a more accurate and stable estimate of your dataset's gradients from the loss function, but it will take longer, and you will process fewer mini-batches per epoch. 
Choosing a good batch size is one of the decisions you need to make as a deep learning practitioner to train your model quickly and accurately. We will talk about how to make this choice throughout this book.\n", "\n", "Another good reason for using mini-batches rather than calculating the gradient on individual data items is that, in practice, we nearly always do our training on an accelerator such as a GPU. These accelerators only perform well if they have lots of work to do at a time, so it's helpful if we can give them lots of data items to work on. Using mini-batches is one of the best ways to do this. However, if you give them too much data to work on at once, they run out of memory—making GPUs happy is also tricky!\n", "\n", "As we saw in our discussion of data augmentation in <>, we get better generalization if we can vary things during training. One simple and effective thing we can vary is what data items we put in each mini-batch. Rather than simply enumerating our dataset in order for every epoch, instead what we normally do is randomly shuffle it on every epoch, before we create mini-batches. PyTorch and fastai provide a class that will do the shuffling and mini-batch collation for you, called `DataLoader`.\n", "\n", "A `DataLoader` can take any Python collection and turn it into an iterator over mini-batches, like so:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor([ 3, 12, 8, 10, 2]),\n", " tensor([ 9, 4, 7, 14, 5]),\n", " tensor([ 1, 13, 0, 6, 11])]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "coll = range(15)\n", "dl = DataLoader(coll, batch_size=5, shuffle=True)\n", "list(dl)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For training a model, we don't just want any Python collection, but a collection containing independent and dependent variables (that is, the inputs and targets of the model). 
A collection that contains tuples of independent and dependent variables is known in PyTorch as a `Dataset`. Here's an example of an extremely simple `Dataset`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#26) [(0, 'a'),(1, 'b'),(2, 'c'),(3, 'd'),(4, 'e'),(5, 'f'),(6, 'g'),(7, 'h'),(8, 'i'),(9, 'j')...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds = L(enumerate(string.ascii_lowercase))\n", "ds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When we pass a `Dataset` to a `DataLoader` we will get back mini-batches which are themselves tuples of tensors representing batches of independent and dependent variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(tensor([17, 18, 10, 22, 8, 14]), ('r', 's', 'k', 'w', 'i', 'o')),\n", " (tensor([20, 15, 9, 13, 21, 12]), ('u', 'p', 'j', 'n', 'v', 'm')),\n", " (tensor([ 7, 25, 6, 5, 11, 23]), ('h', 'z', 'g', 'f', 'l', 'x')),\n", " (tensor([ 1, 3, 0, 24, 19, 16]), ('b', 'd', 'a', 'y', 't', 'q')),\n", " (tensor([2, 4]), ('c', 'e'))]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dl = DataLoader(ds, batch_size=6, shuffle=True)\n", "list(dl)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are now ready to write our first training loop for a model using SGD!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Putting It All Together" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's time to implement the process we saw in <>. 
In code, our process will be implemented something like this for each epoch:\n", "\n", "```python\n", "for x,y in dl:\n", " pred = model(x)\n", " loss = loss_func(pred, y)\n", " loss.backward()\n", " parameters -= parameters.grad * lr\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's re-initialize our parameters:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "weights = init_params((28*28,1))\n", "bias = init_params(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A `DataLoader` can be created from a `Dataset`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([256, 784]), torch.Size([256, 1]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dl = DataLoader(dset, batch_size=256)\n", "xb,yb = first(dl)\n", "xb.shape,yb.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll do the same for the validation set:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "valid_dl = DataLoader(valid_dset, batch_size=256)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's create a mini-batch of size 4 for testing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 784])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch = train_x[:4]\n", "batch.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-11.1002],\n", " [ 5.9263],\n", " [ 9.9627],\n", " [ -8.1484]], grad_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds = linear1(batch)\n", "preds" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { 
"data": { "text/plain": [ "tensor(0.5006, grad_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss = mnist_loss(preds, train_y[:4])\n", "loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can calculate the gradients:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([784, 1]), tensor(-0.0001), tensor([-0.0008]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss.backward()\n", "weights.grad.shape,weights.grad.mean(),bias.grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's put that all in a function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def calc_grad(xb, yb, model):\n", " preds = model(xb)\n", " loss = mnist_loss(preds, yb)\n", " loss.backward()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and test it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(-0.0002), tensor([-0.0015]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "calc_grad(batch, train_y[:4], linear1)\n", "weights.grad.mean(),bias.grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But look what happens if we call it twice:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(-0.0003), tensor([-0.0023]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "calc_grad(batch, train_y[:4], linear1)\n", "weights.grad.mean(),bias.grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The gradients have changed! The reason for this is that `loss.backward` actually *adds* the gradients of `loss` to any gradients that are currently stored. 
So, we have to set the current gradients to 0 first:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "weights.grad.zero_()\n", "bias.grad.zero_();" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> note: Inplace Operations: Methods in PyTorch whose names end in an underscore modify their objects _in place_. For instance, `bias.zero_()` sets all elements of the tensor `bias` to 0." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our only remaining step is to update the weights and biases based on the gradient and learning rate. When we do so, we have to tell PyTorch not to take the gradient of this step too—otherwise things will get very confusing when we try to compute the derivative at the next batch! If we assign to the `data` attribute of a tensor then PyTorch will not take the gradient of that step. Here's our basic training loop for an epoch:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def train_epoch(model, lr, params):\n", " for xb,yb in dl:\n", " calc_grad(xb, yb, model)\n", " for p in params:\n", " p.data -= p.grad*lr\n", " p.grad.zero_()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also want to check how we're doing, by looking at the accuracy of the validation set. To decide if an output represents a 3 or a 7, we can just check whether it's greater than 0. So our accuracy for each item can be calculated (using broadcasting, so no loops!) 
with:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[False],\n", " [ True],\n", " [ True],\n", " [False]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(preds>0.0).float() == train_y[:4]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That gives us this function to calculate our validation accuracy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def batch_accuracy(xb, yb):\n", " preds = xb.sigmoid()\n", " correct = (preds>0.5) == yb\n", " return correct.float().mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can check it works:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.5000)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch_accuracy(linear1(batch), train_y[:4])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and then put the batches together:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def validate_epoch(model):\n", " accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]\n", " return round(torch.stack(accs).mean().item(), 4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.5219" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "validate_epoch(linear1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's our starting point. 
Let's train for one epoch, and see if the accuracy improves:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.6883" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lr = 1.\n", "params = weights,bias\n", "train_epoch(linear1, lr, params)\n", "validate_epoch(linear1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then do a few more:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.8314 0.9017 0.9227 0.9349 0.9438 0.9501 0.9535 0.9564 0.9594 0.9618 0.9613 0.9638 0.9643 0.9652 0.9662 0.9677 0.9687 0.9691 0.9691 0.9696 " ] } ], "source": [ "for i in range(20):\n", " train_epoch(linear1, lr, params)\n", " print(validate_epoch(linear1), end=' ')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looking good! We're already about at the same accuracy as our \"pixel similarity\" approach, and we've created a general-purpose foundation we can build on. Our next step will be to create an object that will handle the SGD step for us. In PyTorch, it's called an *optimizer*." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Creating an Optimizer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Because this is such a general foundation, PyTorch provides some useful classes to make it easier to implement. The first thing we can do is replace our `linear1` function with PyTorch's `nn.Linear` module. A *module* is an object of a class that inherits from the PyTorch `nn.Module` class. Objects of this class behave identically to standard Python functions, in that you can call them using parentheses and they will return the activations of a model.\n", "\n", "`nn.Linear` does the same thing as our `init_params` and `linear` together. It contains both the *weights* and *biases* in a single class. 
Here's how we replicate our model from the previous section:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "linear_model = nn.Linear(28*28,1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Every PyTorch module knows what parameters it has that can be trained; they are available through the `parameters` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 784]), torch.Size([1]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "w,b = linear_model.parameters()\n", "w.shape,b.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use this information to create an optimizer:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class BasicOptim:\n", " def __init__(self,params,lr): self.params,self.lr = list(params),lr\n", "\n", " def step(self, *args, **kwargs):\n", " for p in self.params: p.data -= p.grad.data * self.lr\n", "\n", " def zero_grad(self, *args, **kwargs):\n", " for p in self.params: p.grad = None" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can create our optimizer by passing in the model's parameters:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "opt = BasicOptim(linear_model.parameters(), lr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our training loop can now be simplified to:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def train_epoch(model):\n", " for xb,yb in dl:\n", " calc_grad(xb, yb, model)\n", " opt.step()\n", " opt.zero_grad()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our validation function doesn't need to change at all:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { 
"text/plain": [ "0.4157" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "validate_epoch(linear_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's put our little training loop in a function, to make things simpler:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def train_model(model, epochs):\n", " for i in range(epochs):\n", " train_epoch(model)\n", " print(validate_epoch(model), end=' ')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The results are the same as in the previous section:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.4932 0.8618 0.8203 0.9102 0.9331 0.9468 0.9555 0.9629 0.9658 0.9673 0.9687 0.9707 0.9726 0.9751 0.9761 0.9761 0.9775 0.978 0.9785 0.9785 " ] } ], "source": [ "train_model(linear_model, 20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "fastai provides the `SGD` class which, by default, does the same thing as our `BasicOptim`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.4932 0.852 0.8335 0.9116 0.9326 0.9473 0.9555 0.9624 0.9648 0.9668 0.9692 0.9712 0.9731 0.9746 0.9761 0.9765 0.9775 0.978 0.9785 0.9785 " ] } ], "source": [ "linear_model = nn.Linear(28*28,1)\n", "opt = SGD(linear_model.parameters(), lr)\n", "train_model(linear_model, 20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "fastai also provides `Learner.fit`, which we can use instead of `train_model`. 
To create a `Learner` we first need to create a `DataLoaders`, by passing in our training and validation `DataLoader`s:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dls = DataLoaders(dl, valid_dl)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To create a `Learner` without using an application (such as `vision_learner`) we need to pass in all the elements that we've created in this chapter: the `DataLoaders`, the model, the optimization function (which will be passed the parameters), the loss function, and optionally any metrics to print:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, nn.Linear(28*28,1), opt_func=SGD,\n", " loss_func=mnist_loss, metrics=batch_accuracy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can call `fit`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
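<table>
<thead><tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>batch_accuracy</th><th>time</th></tr></thead>
<tbody>
<tr><td>0</td><td>0.636857</td><td>0.503549</td><td>0.495584</td><td>00:00</td></tr>
<tr><td>1</td><td>0.545725</td><td>0.170281</td><td>0.866045</td><td>00:00</td></tr>
<tr><td>2</td><td>0.199223</td><td>0.184893</td><td>0.831207</td><td>00:00</td></tr>
<tr><td>3</td><td>0.086580</td><td>0.107836</td><td>0.911187</td><td>00:00</td></tr>
<tr><td>4</td><td>0.045185</td><td>0.078481</td><td>0.932777</td><td>00:00</td></tr>
<tr><td>5</td><td>0.029108</td><td>0.062792</td><td>0.946516</td><td>00:00</td></tr>
<tr><td>6</td><td>0.022560</td><td>0.053017</td><td>0.955348</td><td>00:00</td></tr>
<tr><td>7</td><td>0.019687</td><td>0.046500</td><td>0.962218</td><td>00:00</td></tr>
<tr><td>8</td><td>0.018252</td><td>0.041929</td><td>0.965162</td><td>00:00</td></tr>
<tr><td>9</td><td>0.017402</td><td>0.038573</td><td>0.967615</td><td>00:00</td></tr>
</tbody>
</table>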
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit(10, lr=lr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, there's nothing magic about the PyTorch and fastai classes. They are just convenient pre-packaged pieces that make your life a bit easier! (They also provide a lot of extra functionality we'll be using in future chapters.)\n", "\n", "With these classes, we can now replace our linear model with a neural network." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Adding a Nonlinearity" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So far we have a general procedure for optimizing the parameters of a function, and we have tried it out on a very boring function: a simple linear classifier. A linear classifier is very constrained in terms of what it can do. To make it a bit more complex (and able to handle more tasks), we need to add something nonlinear between two linear classifiers—this is what gives us a neural network.\n", "\n", "Here is the entire definition of a basic neural network:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def simple_net(xb): \n", " res = xb@w1 + b1\n", " res = res.max(tensor(0.0))\n", " res = res@w2 + b2\n", " return res" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's it! 
All we have in `simple_net` is two linear classifiers with a `max` function between them.\n", "\n", "Here, `w1` and `w2` are weight tensors, and `b1` and `b2` are bias tensors; that is, parameters that are initially randomly initialized, just like we did in the previous section:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "w1 = init_params((28*28,30))\n", "b1 = init_params(30)\n", "w2 = init_params((30,1))\n", "b2 = init_params(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The key point about this is that `w1` has 30 output activations (which means that `w2` must have 30 input activations, so they match). That means that the first layer can construct 30 different features, each representing some different mix of pixels. You can change that `30` to anything you like, to make the model more or less complex.\n", "\n", "That little function `res.max(tensor(0.0))` is called a *rectified linear unit*, also known as *ReLU*. We think we can all agree that *rectified linear unit* sounds pretty fancy and complicated... But actually, there's nothing more to it than `res.max(tensor(0.0))`—in other words, replace every negative number with a zero. 
This tiny function is also available in PyTorch as `F.relu`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAD7CAYAAABt0P8jAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXhU5fn/8fctW1hlC7iBCAUUkDXivlRtcZcWtbK4tFYURau2Vluh4lK3trZqKZZvtaggLgiudalVvu7aEDaDEJQdBMIWk7AEkvv7x0x+v3FMyAmZJTPzeV3XXMyc85xz7nOYuXPmOWfux9wdERHJHPslOwAREUksJX4RkQyjxC8ikmGU+EVEMowSv4hIhmmY7ACCaN++vXfp0iXZYYiIpJQ5c+Zscvfs6Okpkfi7dOlCbm5ussMQEUkpZrayqunq6hERyTBK/CIiGUaJX0Qkwyjxi4hkGCV+EZEMU2PiN7MmZvaYma00s2Izm2tmZ+6l/Y1mtt7MiszscTNrEjGvi5m9a2bbzWyxmZ0eqx0REZFggpzxNwRWAycD+wPjgefMrEt0QzMbAtwKnAZ0AboCd0Q0mQ7MBdoBtwEzzOw795iKiEj81Jj43b3U3Se4+wp3r3D3V4HlwKAqml8GPObu+e6+FbgLuBzAzHoAA4Hb3X2Hu78ALASGxWhfRETSxleFJfzxzSXsKa+I+bpr3cdvZh2BHkB+FbN7A/MjXs8HOppZu/C8Ze5eHDW/dzXbGW1muWaWW1hYWNswRURS1vayPYyZOoenP1vF5tKymK+/VonfzBoB04An3H1xFU1aAEURryuft6xiXuX8llVty90nu3uOu+dkZ6s3SEQyg7tz26zPWbqxhIcu7k/HVlkx30bgxG9m+wFPAWXA2GqalQCtIl5XPi+uYl7l/GJERASAaZ+uYtbctdx4eg9O7B6fk95Aid/MDHgM6AgMc/fd1TTNB/pFvO4HbHD3zeF5Xc2sZdT8qrqMREQyzoI127jzlUWc0jObsd//Xty2E/SMfxJwBHCuu+/YS7sngSvMrJeZtQHGAVMA3L0AmAfcbmZZZvYjoC/wwr4GLyKSLraWljFmah7ZLZvw54v6s99+FrdtBbmP/1DgKqA/sN7MSsKPkWbWOfy8M4C7vwE8ALwLrAw/bo9Y3cVADrAVuA+4wN115VZEMlpFhXPjc/MoLN7F30YOpE3zxnHdXo1lmd19JbC3Pz0toto/CDxYzbpWAKcED09EJP1NfPdLZi8p5K6hfejXqXXct6eSDSIiSfTB0k08+HYBQ/sfxKijOydkm0r8IiJJ8nXRDq5/Zi7dO7Tgnh8fSeg+mvhT4hcRSYKyPRVcMy2PXbvLmTRqEM0aJ25AxJQYelFEJN3c868vmLtqGxNHDKRbdouaF4ghnfGLiCTYqwvWMeWjFfz0+C6c3ffAhG9fiV9EJIG+3FjCLTMWMLBza35z5hFJiUGJX0QkQUp3hYqvNWnUgIkjB9K4YXJSsPr4RUQSwN357ayFfFlYwlM/O5oD92+atFh0xi8ikgBTP1nJS/PWcdPpPTihe/ukxqLELyISZ/NWb+POVxfx/Z7ZXBvH4mtBKfGLiMTR1tIyrp2WR4eWWfz5J/EtvhaU+vhFROKkosK54dlQ8bUZY46ldbP4Fl8LSmf8IiJx8sg7X/K/BYXcfl4v+h4S/+JrQSnxi4jEwXsFhfzlPwX8eMDBjBicmOJrQSnxi4jE2NptO/jFM3Pp0aElv/9R4oqvBRV06MWxZpZrZrvMbMpe2j0aMVBLSbh9ccT82Wa2M2L+khjsg4hIvV
G2p4Jrp+Wxu9yZNGogTRs3SHZI3xH04u464G5gCFDtrw7c/Wrg6srX4T8SFVHNxrr7P2oXpohIavj9a4uYt3obfxs5kK4JLr4WVKDE7+4zAcwsBzgkyDJm1hwYBpyzz9GJiKSQl+ev44mPV/LzEw7jrCMTX3wtqHj28Q8DCoH3oqbfa2abzOxDMzuluoXNbHS4eym3sFDD8opI/bZ0QzG3vrCAnEPbcMuZhyc7nL2KZ+K/DHjS3T1i2i1AV+BgYDLwipl1q2phd5/s7jnunpOdnR3HMEVE6qZ01x7GTMujWeNQ8bVGDer3fTNxic7MOgEnA09GTnf3T9292N13ufsTwIfAWfGIQUQkEdydW2cuZFlhCQ8PH0DHVlnJDqlG8fqzdCnwkbsvq6GdA/XrPicRkVp48uOVvDJ/Hb/8YU+O65bc4mtBBb2ds6GZZQENgAZmlmVme7swfCkwJWodrc1sSOWyZjYSOAl4cx9jFxFJqrxVW7n7tUWcengHxpxcZa91vRT0jH8csAO4FRgVfj7OzDqH78f/fz9LM7NjCd3583zUOhoRuiW0ENgEXAcMdXfdyy8iKWdLaRljp+XRsVUWf76ofhRfCyro7ZwTgAnVzP7Wjaru/jHQvIp1FAJH1S48EZH6p7zC+cUzc9lUWsbMMcexf7NGyQ6pVur3pWcRkXro4f8s5f2lm7jjvN70OXj/ZIdTa0r8IiK1MHvJRh5+ZynDBh7CxUd1SnY4+0SJX0QkoDVbt3PDs/Po2bEldw/tU++KrwWlxC8iEsCuPeVcOy2P8nJn0qhB9bL4WlAagUtEJIC7X/2C+WuKeHTUIA5r/537V1KKzvhFRGrw0ry1PPXJSkaf1JUz+hyQ7HDqTIlfRGQvCjYUc+sLCxncpS03D+mZ7HBiQolfRKQaJbv2cPXUOTRv0pBHRgyo98XXgkqPvRARiTF355YXFrBiUykPD++fEsXXglLiFxGpwpSPVvDagq+5ecjhKVN8LSglfhGRKHNWbuX3r33B6Ud05OqTuyY7nJhT4hcRibC5ZBdjn87joNZN+dNF/VL2R1p7o/v4RUTCQsXX5rG5svha09QqvhaUzvhFRMIeeruAD77cxF3np2bxtaCU+EVEgHcXb+Thd77kwkGH8JOjOte8QAoLOgLXWDPLNbNdZjZlL+0uN7Py8OAslY9TIuZ3MbN3zWy7mS02s9PrvgsiInWzekuo+NoRB7birqF9kh1O3AXt419HaPSsIUDTGtp+7O4nVDNvOvAxoQHWzwJmmFn38CAtIiIJt2tPOdc+nUdFhTNp5ECyGqVu8bWgAp3xu/tMd38R2LyvGzKzHsBA4HZ33+HuLwALgWH7uk4Rkbq685VFLFhTxB8v6keXFC++FlQ8+vgHmNkmMysws/ERg7L3Bpa5e3FE2/nh6d9hZqPD3Uu5hYX6QiAisTczbw3TPl3FVSd1ZUjv1C++FlSsE/97QB+gA6Ez+eHAzeF5LYCiqPZFQMuqVuTuk909x91zsrOzYxymiGS6xeu/4bezFjL4sPQpvhZUTBO/uy9z9+XuXuHuC4E7gQvCs0uAVlGLtAKKERFJoOKduxkzNY+WWY3464gBNEyT4mtBxXtvHaj82Vs+0NXMIs/w+4Wni4gkhLvz6xkLWLVlO38dPoAOLdOn+FpQQW/nbGhmWUADoIGZZUX03Ue2O9PMOoafHw6MB14CcPcCYB5we3j5HwF9gRdisysiIjV77IPlvP75en49pCdHd22X7HCSIugZ/zhgB3ArMCr8fJyZdQ7fq1/5a4fTgAVmVgr8C5gJ3BOxnouBHGArcB9wgW7lFJFEyV2xhfteX8wPe3Vk9EnpV3wtKHP3ZMdQo5ycHM/NzU12GCKSwjaV7OLsh98nq1EDXh57QtrW4YlkZnPcPSd6uoq0iUjaK69wrp8+l23bdzPrmsEZkfT3RolfRNLeg/9ewkdfbeaBC/rS66DomwszT2bdwyQiGeedxRuY+O5X/CSnEx
fldEp2OPWCEr+IpK3VW7ZzwzPz6HVgK+44v8oiARlJiV9E0tLO3eVcMy0PBx4dNSgjiq8FpT5+EUlLd7yyiIVri/ifS3Po3K5ZssOpV3TGLyJp54U5a5j+2SquPrkbP+jVMdnh1DtK/CKSVhav/4bbXlzIMV3b8qsf9kh2OPWSEr+IpI3K4mutshrxyPCBGVd8LSj18YtIWnB3bn4+VHxt+pXHkN2ySbJDqrf051BE0sJjHyznjfz13HrG4Qw+rG2yw6nXlPhFJOV9tnwL976+mDN6H8DPTzws2eHUe0r8IpLSNhbvZOzTeXRq05QHLuyLmdW8UIZTH7+IpKw95RVcP30u3+zczRM/G0yrrMwuvhZU0IFYxoYHPt9lZlP20u4yM5tjZt+Y2RozeyBywBYzm21mO8M1/EvMbEkM9kFEMtSf/l3AJ8u2cPfQIzniQBVfCypoV8864G7g8RraNQNuANoDRxMamOVXUW3GunuL8COzRjgWkZj596INTJr9FcMHd+KCQYckO5yUEqirx91nAphZDlDtEXb3SREv15rZNOD7dYpQRCTKqs3buem5efQ5uBW3n6via7UV74u7J/HdwdTvNbNNZvahmZ1S3YJmNjrcvZRbWKjRGUUkZOfucsZMm4MBk0aq+Nq+iFviN7OfEhpf948Rk28BugIHA5OBV8ysW1XLu/tkd89x95zs7Ox4hSkiKWbCy/nkr/uGv1zcn05tVXxtX8Ql8ZvZUEKDqZ/p7psqp7v7p+5e7O673P0J4EPgrHjEICLp57nc1Tzz39Vc+/1unHq4iq/tq5jfzmlmZwD/A5zt7gtraO6AbroVkRrlryti/Iufc1y3dtz0A90XUhdBb+dsaGZZQAOggZllRd6mGdHuVGAaMMzdP4ua19rMhlQua2YjCV0DeLPuuyEi6axox26umZZH62aNeHj4ABrsp/PFugja1TMO2AHcCowKPx9nZp3D9+N3DrcbD+wP/CviXv3Xw/MaEboltBDYBFwHDHV33csvItUKFV+bz9qtO5g4YiDtW6j4Wl0FvZ1zAjChmtktItpVe+umuxcCR9UiNhERJr+3jLcWbWDc2UeQ00XF12JBtXpEpN76dNlmHnhzCWcdeQBXnKDia7GixC8i9dLGb3YydvpcDm3bjPuHqfhaLKlIm4jUO3vKK7hu+lxKdu5h6hVH01LF12JKiV9E6p0/vLWET5dv4c8/6UfPA1omO5y0o64eEalX3spfz9//dxkjj+7Mjwao+Fo8KPGLSL2xcnMpv3x+Pn0P2Z/fndsr2eGkLSV+EakXdu4uZ8zUPPYzY+KIgTRpqOJr8aI+fhGpF3730ucs+vob/nn5USq+Fmc64xeRpHv2v6t4LncN1536Pb5/eIdkh5P2lPhFJKk+X1vE+JfyOeF77bnh9B7JDicjKPGLSNJUFl9r26wxD13cX8XXEkR9/CKSFBUVzi+fm8+6bTt49qpjaKfiawmjM34RSYq/v7eMt7/YwG/POoJBh6r4WiIp8YtIwn381Wb+8OZizu57ID89vkuyw8k4SvwiklAbv9nJddPn0qV9cxVfS5KgI3CNNbNcM9tlZlNqaHujma03syIze9zMmkTM62Jm75rZdjNbbGan1zF+EUkhu8srGPv0XEp37eHRUYNo0USXGZMh6Bn/OkKjZz2+t0ZmNoTQKF2nAV2ArsAdEU2mA3OBdsBtwAwzy65dyCKSqv7w5hI+W7GF+4YdSY+OKr6WLIESv7vPdPcXgc01NL0MeMzd8919K3AXcDmAmfUABgK3u/sOd38BWAgM29fgRSR1vPH5eia/t4xLjjmU8/sfnOxwMlqs+/h7A/MjXs8HOppZu/C8Ze5eHDW/d1UrMrPR4e6l3MLCwhiHKSKJtGJTKTc/P59+nVoz7pwjkh1Oxot14m8BFEW8rnzesop5lfOr/L7n7pPdPcfdc7Kz1Rskkqp2lJVz9dQ5NGhgTBwxQMXX6oFYX1kpAVpFvK58XlzFvM
r5xYhIWnJ3xr/0OUs2FPPPy4/ikDYqvlYfxPqMPx/oF/G6H7DB3TeH53U1s5ZR8/NjHIOI1BPP/nc1M+as4bpTu3NKTxVfqy+C3s7Z0MyygAZAAzPLMrOqvi08CVxhZr3MrA0wDpgC4O4FwDzg9vDyPwL6Ai/EYD9EpJ75fG0Rv3s5nxO7t+cXp3VPdjgSIegZ/zhgB6FbNUeFn48zs85mVmJmnQHc/Q3gAeBdYGX4cXvEei4GcoCtwH3ABe6uK7ciaaZo+26unjqHds0b89DFA1R8rZ4xd092DDXKycnx3NzcZIchIgFUVDhXPpnLe0sLee6qYxnQuU2yQ8pYZjbH3XOip6tkg4jE1KPvfcV/Fm9k3Nm9lPTrKSV+EYmZj77axB/fXMK5/Q7i0mMPTXY4Ug0lfhGJifVFO7l++lwOa9+c+358pIqv1WOqkCQidRYqvpbH9rJypl95DM1VfK1e0/+OiNTZ/a8vJnflVh4ePoDuKr5W76mrR0Tq5PWFX/OPD5Zz2bGHcl6/g5IdjgSgxC8i+2xZYQk3z1hA/06tue3sXskORwJS4heRfbKjrJxrpuXRqIExceRAGjdUOkkV6uMXkVpzd257cSFLNhTzxE8Hc3DrpskOSWpBf6JFpNamf7aamXlr+cVp3Tmph8qmpxolfhGplQVrtjHh5XxO6pHN9aeq+FoqUuIXkcC2bS9jzNQ82rdozF9+0p/9VHwtJamPX0QCqahwbnx2HhuLd/L81cfRtnnjZIck+0hn/CISyN9mf8m7SwoZf04v+ndqnexwpA6U+EWkRh9+uYkH/13Aef0O4pJjVHwt1QUdgautmc0ys1IzW2lmI6pp93p4YJbKR5mZLYyYv8LMdkTMfytWOyIi8VFZfK1rdgvuVfG1tBC0j38iUAZ0BPoDr5nZfHf/1ni57n5m5Gszmw28E7Wuc9397X0LV0QSaXd5Bdc+ncfO3eU8OmqQiq+liRrP+M2sOTAMGO/uJe7+AfAycEkNy3UBTgSeqnuYIpIM9/5rMXNWbuX+C/ryvQ4tkh2OxEiQrp4eQHl4sPRK84HeNSx3KfC+uy+Pmj7NzArN7C0z61fdwmY22sxyzSy3sFDD8ook2msLvubxD5dz+XFdOKeviq+lkyCJvwVQFDWtCKip9uqlwJSoaSOBLsChhAZkf9PMqrw9wN0nu3uOu+dkZ+uXgSKJ9FVhCb+eMZ+BnVvz27OOSHY4EmNBEn8J0CpqWiuguLoFzOwE4ABgRuR0d//Q3Xe4+3Z3vxfYRqg7SETqie1lexgzdQ5NGjVQ8bU0FeR/tABoaGaRv83uB+RX0x7gMmCmu5fUsG4HdIuASD3h7tw263OWbizhoYv7c+D+Kr6WjmpM/O5eCswE7jSz5mZ2PHA+1Vy0NbOmwIVEdfOYWWczO97MGptZlpndDLQHPqzjPohIjEz7dBWz5q7lxtN7cGJ3dbGmq6Df4a4BmgIbgenAGHfPN7MTzSz6rH4ooWsA70ZNbwlMArYCa4EzgDPdffO+Bi8isTN/9TbufGURp/TMZuz3v5fscCSOzN2THUONcnJyPDc3N9lhiKStraVlnPPIBwC8et0JtFEdnrRgZnPcPSd6un6NIZLhKiqcG5+bR2HxLp6/+lgl/Qygy/UiGe6v737J7CWFjD+3F/1UfC0jKPGLZLD3lxby57cLGNr/IEYd3TnZ4UiCKPGLZKh123bwi2fm0b1DC+5R8bWMosQvkoHK9lQw9uk8du0uZ9KoQTRrrMt9mUT/2yIZ6J5/fUHeqm1MHDGQbtkqvpZpdMYvkmFemb+OKR+t4GfHH8bZfQ9MdjiSBEr8Ihnky40l3PrCAgYd2obfnHV4ssORJFHiF8kQpbtCxdeyGjVg4oiBNGqgj3+mUh+/SAZwd347ayFfFZbw1BVHc8D+WckOSZJIf/JFMsDUT1by0rx13PSDHhz/vfbJDkeSTIlfJM3NXbWVO19dxKmHd+CaU1R8TZT4RdLaltIyrp
2WR8dWWTx4UT/2208/0hL18YukrfIK54Zn57GppIwZY46ldTMVX5OQQGf8ZtbWzGaZWamZrTSzEdW0m2Bmu82sJOLRNWJ+fzObY2bbw//2j9WOiMi3PfLOUt4rKOT283rR9xAVX5P/L2hXz0SgDOhIaMD0SWbWu5q2z7p7i4jHMgAzawy8BEwF2gBPAC+Fp4tIDP1vQSEP/WcpPx5wMCMGq/iafFuNid/MmgPDgPHuXuLuHwAvA5fUclunEOpa+ou773L3hwmNt3tqLdcjInuxdtsObnhmLj06tOT3P1LxNfmuIGf8PYBydy+ImDYfqO6M/1wz22Jm+WY2JmJ6b2CBf3vIrwXVrcfMRptZrpnlFhYWBghTRHbtKeeaaXnsLncmjRpI08YNkh2S1ENBEn8LQmPoRioiNIZutOeAI4Bs4Ergd2Y2fB/Wg7tPdvccd8/JztagzyJB/P61L5i/eht/vLAvXVV8TaoRJPGXAK2iprUCiqMbuvsid1/n7uXu/hHwEHBBbdcjIrX30ry1PPnxSn5+wmGc0UfF16R6QRJ/AdDQzLpHTOsH5AdY1gn14xNu39e+3eHYN+B6RGQvlm4o5jczF3JUlzbccqaKr8ne1Zj43b0UmAncaWbNzex44Hzgqei2Zna+mbWxkMHA9YTu5AGYDZQD15tZEzMbG57+Tgz2QyRjlezaw9VT59CscQMeGa7ia1KzoO+Qa4CmwEZgOjDG3fPN7EQzK4lodzHwJaHumyeB+939CQB3LwOGApcC24CfAUPD00VkH7g7t76wgOWbSnl4+AAVX5NAAv1y1923EEra0dPfJ3TRtvL18Og2Ue3nAoNqGaOIVOOJj1bw6oKvuXlIT47rpuJrEoy+E4qkqDkrt3L3a19w2uEdGHNyt2SHIylEiV8kBW0u2cXYp/M4sHUWD17UX8XXpFZUpE0kxVQWX9tcWsbMMcexf7NGyQ5JUozO+EVSzEP/Wcr7Szdxx3m96XPw/skOR1KQEr9ICpm9ZCOPvLOUYQMP4eKjOiU7HElRSvwiKWLN1u3c8Ow8enZsyd1D+6j4muwzJX6RFFBZfK283Jk0apCKr0md6OKuSAq469VFLFhTxKOjBnFY++bJDkdSnM74Req5F+euZeonqxh9UlfO6HNAssORNKDEL1KPFYSLrw3u0pabh/RMdjiSJpT4ReqpyuJrzZs05K8jBqj4msSM3kki9ZC7c8uMBazYVMojwwfQoZWKr0nsKPGL1EP//HAFry38mpuHHM6x3dolOxxJM0r8IvVM7oot3POvLzj9iI5cdVLXZIcjaUiJX6Qe2VSyi2ufzuOg1k3500X9VHxN4iJQ4jeztmY2y8xKzWylmY2opt3NZva5mRWb2XIzuzlq/goz22FmJeHHW7HYCZF0UF7h/OKZuWzdvpu/jRzI/k1VfE3iI+gPuCYCZUBHoD/wmpnNd/fo8XKN0AhbC4BuwFtmttrdn4loc667v13HuEXSzp//XcCHX27m/mFHqviaxFWNZ/xm1hwYBox39xJ3/wB4Gbgkuq27P+Duee6+x92XEBpv9/hYBy2Sbt5ZvIG/vvslFw46hJ8c1TnZ4UiaC9LV0wMod/eCiGnzgd57W8hCFaROBKK/FUwzs0Ize8vM+u1l+dFmlmtmuYWFhQHCFElNq7ds58Zn53PEga24a2ifZIcjGSBI4m8BFEVNKwJa1rDchPD6/xkxbSTQBTgUeBd408xaV7Wwu0929xx3z8nOzg4Qpkjq2bk7VHytwp1HRw0kq5GKr0n8BUn8JUCrqGmtgOLqFjCzsYT6+s92912V0939Q3ff4e7b3f1eYBuhbwUiGenOVxexcG0Rf7qwH4e2U/E1SYwgib8AaGhm3SOm9eO7XTgAmNnPgFuB09x9TQ3rdkIXhEUyzsy8NTz96SquPrkbP+yt4muSODUmfncvBWYCd5pZczM7HjgfeCq6rZmNBO4BfuDuy6LmdTaz482ssZllhW/1bA98GIsdEU
kli9d/w29nLeSYrm351Q97JDscyTBBf8B1DdAU2AhMB8a4e76ZnWhmJRHt7gbaAf+NuFf/0fC8lsAkYCuwFjgDONPdN8diR0RSRfHO3YyZmkerrEY8PHwADVV8TRIs0H387r4FGFrF9PcJXfytfH3YXtaRD/TdhxhF0oa78+sZC1i1ZTvTrzyGDi1VfE0ST6caIgn02AfLef3z9dxyRk8GH9Y22eFIhlLiF0mQ/67Ywr2vL2ZI745ceaKKr0nyKPGLJEBh8S6unZZHpzZN+cOF/Qj9vlEkOTTYukic7Smv4PrpcynasZspPx1MqywVX5PkUuIXibMH/13Ax8s284cL+tLroOjfQooknrp6ROLo7UUb+Nvsr7j4qE5cmNMp2eGIAEr8InGzavN2bnpuHr0PasWE8/Za01AkoZT4ReJg5+5yrnl6DgCTRg5S8TWpV9THLxIHd7ySz+drv+Efl+bQuV2zZIcj8i064xeJsRlz1jD9s9Vcc0o3Tu/VMdnhiHyHEr9IDH3x9TfcNmshx3Ztx00/UPE1qZ+U+EVi5JuduxkzdQ77N1XxNanf1McvEgPuzq+fX8DqrTt4ZvQxZLdskuyQRKqlUxKRGPjH+8t5I389vznzcI7qouJrUr8p8YvU0afLNnPfG4s5s88BXHFCtZXJReqNQInfzNqa2SwzKzWzlWY2opp2Zmb3m9nm8OMBi6hGZWb9zWyOmW0P/9s/VjsikgyfLNvMtU/PpXPbZjxwQV8VX5OUEPSMfyJQBnQERgKTzKyqnyKOJjRgSz9Cg66cA1wFYGaNgZeAqUAb4AngpfB0kZRSvHM3t81ayMWTP6FZ4wb8/ZJBtFTxNUkRNV7cNbPmwDCgj7uXAB+Y2cvAJYQGVY90GfCnykHWzexPwJXAo8Ap4e39xd0deNjMfgWcCrwRm935tp8/8V9Wbt4ej1VLhttUsouiHbv5+QmH8csf9qRpY/0yV1JHkLt6egDl7l4QMW0+cHIVbXuH50W26x0xb0E46VdaEJ7+ncRvZqMJfYOgc+fOAcL8rs5tm9O4oS5jSOz1PqgVlx3XhQGd2yQ7FJFaC5L4WwBFUdOKCA2eXlPbIqBFuJ+/NuvB3ScDkwFycnK8qjY1+d25vfZlMRGRtBbkdLgEiC4i3gooDtC2FVASPsuvzXpERCROgiT+AqChmXWPmNYPyK+ibX54XlXt8oG+9u3bHvpWsx4REYmTGhO/u5cCM4E7zay5mR0PnA88VUXzJ4GbzOxgMzsI+CUwJTxvNlAOXG9mTcxsbHj6O9VBRU4AAAUOSURBVHXbBRERqY2gVz6vAZoCG4HpwBh3zzezE82sJKLd34FXgIXA58Br4Wm4exmhWz0vBbYBPwOGhqeLiEiC2LdvsqmfcnJyPDc3N9lhiIikFDOb4+450dN1r6OISIZR4hcRyTBK/CIiGSYl+vjNrBBYuY+Ltwc2xTCcWFFctaO4akdx1U66xnWou2dHT0yJxF8XZpZb1cWNZFNctaO4akdx1U6mxaWuHhGRDKPELyKSYTIh8U9OdgDVUFy1o7hqR3HVTkbFlfZ9/CIi8m2ZcMYvIiIRlPhFRDKMEr+ISIZJq8QfLvf8mJmtNLNiM5trZmfWsMyNZrbezIrM7HEzaxKn2MaaWa6Z7TKzKTW0vdzMys2sJOJxSrLjCrdP1PFqa2azzKw0/P85Yi9tJ5jZ7qjj1TWRcVjI/Wa2Ofx4IGrsiZirRWxxOz5VbKs27/OEvJdqE1ciP3vh7dUqZ8XqmKVV4ic0lORqQuMB7w+MB54zsy5VNTazIYQGjD8N6AJ0Be6IU2zrgLuBxwO2/9jdW0Q8Zic7rgQfr4lAGdARGAlMMrPee2n/bNTxWpbgOEYTKjvej9AAQ+cAV8UohrrGBvE7PtECvZ8S/F4KHFdYoj57UIucFdNj5u5p/SA0oPuwauY9DdwT8fo0YH2c47kbmFJDm8uBDxJ8nILElZDjBTQnlNB6REx7CrivmvYTgK
nJjAP4CBgd8foK4JM4/n/VJra4HJ+6vJ+S8dkLGFfCP3tVxFBlzorlMUu3M/5vMbOOQA+qH96xNzA/4vV8oKOZtYt3bAEMMLNNZlZgZuPNrGGyAyJxx6sHUO7uBVHb2tsZ/7lmtsXM8s1sTBLiqOrY7C3eRMYG8Tk+daHPXhVqyFkxO2Zpm/jNrBEwDXjC3RdX06wFUBTxuvJ5y3jGFsB7QB+gAzAMGA7cnNSIQhJ1vKK3U7mt6rbzHHAEkA1cCfzOzIYnOI6qjk2LOPbz1ya2eB2futBnL0qAnBWzY5ZSid/MZpuZV/P4IKLdfoS+9pYBY6tdIZQArSJeVz4vjkdcQbn7Mndf7u4V7r4QuBO4oLbriXVcJO54RW+ncltVbsfdF7n7Oncvd/ePgIfYh+NVhdrEUdWxKfHwd/I4CBxbHI9PXcTkvRRrsfrs1VbAnBWzY5ZSid/dT3F3q+ZxAoTurgAeI3TBa5i7797LKvMJXYyr1A/Y4O6bYx1XHTlQ6zPHOMSVqONVADQ0s+5R26quy+47m2AfjlcVahNHVccmaLzxji1arI5PXcTkvZQAcT9WtchZMTtmKZX4A5pE6Gvtue6+o4a2TwJXmFkvM2sDjAOmxCMoM2toZllAA6CBmWVV13doZmeG+/ows8MJXel/KdlxkaDj5e6lwEzgTjNrbmbHA+cTOiOqah/ON7M2FjIYuJ4YHK9axvEkcJOZHWxmBwG/JE7vpdrGFq/jU5VavJ8S9tmrTVyJ/OxFCJqzYnfMknn1OtYP4FBCf6F3EvpaVPkYGZ7fOfy6c8QyNwEbgG+AfwJN4hTbhHBskY8JVcUF/DEcUymwjNDXzUbJjivBx6st8GL4GKwCRkTMO5FQN0rl6+nA5nCsi4Hr4x1HFTEY8ACwJfx4gHAtrDi+34PGFrfjE/T9lMz3Um3iSuRnL7y9anNWPI+ZirSJiGSYdOzqERGRvVDiFxHJMEr8IiIZRolfRCTDKPGLiGQYJX4RkQyjxC8ikmGU+EVEMsz/AXPhmYaLqaN6AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_function(F.relu)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> J: There is an enormous amount of jargon in deep learning, including terms like _rectified linear unit_. The vast vast majority of this jargon is no more complicated than can be implemented in a short line of code, as we saw in this example. The reality is that for academics to get their papers published they need to make them sound as impressive and sophisticated as possible. One of the ways that they do that is to introduce jargon. Unfortunately, this has the result that the field ends up becoming far more intimidating and difficult to get into than it should be. You do have to learn the jargon, because otherwise papers and tutorials are not going to mean much to you. But that doesn't mean you have to find the jargon intimidating. Just remember, when you come across a word or phrase that you haven't seen before, it will almost certainly turn out to be referring to a very simple concept." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The basic idea is that by using more linear layers, we can have our model do more computation, and therefore model more complex functions. But there's no point just putting one linear layer directly after another one, because when we multiply things together and then add them up multiple times, that could be replaced by multiplying different things together and adding them up just once! That is to say, a series of any number of linear layers in a row can be replaced with a single linear layer with a different set of parameters.\n", "\n", "But if we put a nonlinear function between them, such as `max`, then this is no longer true. Now each linear layer is actually somewhat decoupled from the other ones, and can do its own useful work. The `max` function is particularly interesting, because it operates as a simple `if` statement." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> S: Mathematically, we say the composition of two linear functions is another linear function. So, we can stack as many linear classifiers as we want on top of each other, and without nonlinear functions between them, it will just be the same as one linear classifier." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Amazingly enough, it can be mathematically proven that this little function can solve any computable problem to an arbitrarily high level of accuracy, if you can find the right parameters for `w1` and `w2` and if you make these matrices big enough. For any arbitrarily wiggly function, we can approximate it as a bunch of lines joined together; to make it closer to the wiggly function, we just have to use shorter lines. This is known as the *universal approximation theorem*. The three lines of code that we have here are known as *layers*. The first and third are known as *linear layers*, and the second line of code is known variously as a *nonlinearity*, or *activation function*.\n", "\n", "Just like in the previous section, we can replace this code with something a bit simpler, by taking advantage of PyTorch:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "simple_net = nn.Sequential(\n", " nn.Linear(28*28,30),\n", " nn.ReLU(),\n", " nn.Linear(30,1)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`nn.Sequential` creates a module that will call each of the listed layers or functions in turn.\n", "\n", "`nn.ReLU` is a PyTorch module that does exactly the same thing as the `F.relu` function. Most functions that can appear in a model also have identical forms that are modules. Generally, it's just a case of replacing `F` with `nn` and changing the capitalization. When using `nn.Sequential`, PyTorch requires us to use the module version. 
Since modules are classes, we have to instantiate them, which is why you see `nn.ReLU()` in this example. \n", "\n", "Because `nn.Sequential` is a module, we can get its parameters, which will return a list of all the parameters of all the modules it contains. Let's try it out! As this is a deeper model, we'll use a lower learning rate and a few more epochs." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, simple_net, opt_func=SGD,\n", " loss_func=mnist_loss, metrics=batch_accuracy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<table>\n",
"  <thead>\n",
"    <tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>batch_accuracy</th><th>time</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><td>0</td><td>0.305828</td><td>0.399663</td><td>0.508341</td><td>00:00</td></tr>\n",
"    <tr><td>1</td><td>0.142960</td><td>0.225702</td><td>0.807655</td><td>00:00</td></tr>\n",
"    <tr><td>2</td><td>0.079516</td><td>0.113519</td><td>0.919529</td><td>00:00</td></tr>\n",
"    <tr><td>3</td><td>0.052391</td><td>0.076792</td><td>0.943081</td><td>00:00</td></tr>\n",
"    <tr><td>4</td><td>0.039796</td><td>0.060083</td><td>0.956330</td><td>00:00</td></tr>\n",
"    <tr><td>5</td><td>0.033368</td><td>0.050713</td><td>0.963690</td><td>00:00</td></tr>\n",
"    <tr><td>6</td><td>0.029680</td><td>0.044797</td><td>0.965653</td><td>00:00</td></tr>\n",
"    <tr><td>7</td><td>0.027290</td><td>0.040729</td><td>0.968106</td><td>00:00</td></tr>\n",
"    <tr><td>8</td><td>0.025568</td><td>0.037771</td><td>0.968597</td><td>00:00</td></tr>\n",
"    <tr><td>9</td><td>0.024233</td><td>0.035508</td><td>0.970559</td><td>00:00</td></tr>\n",
"    <tr><td>10</td><td>0.023149</td><td>0.033714</td><td>0.972031</td><td>00:00</td></tr>\n",
"    <tr><td>11</td><td>0.022242</td><td>0.032243</td><td>0.972522</td><td>00:00</td></tr>\n",
"    <tr><td>12</td><td>0.021468</td><td>0.031006</td><td>0.973503</td><td>00:00</td></tr>\n",
"    <tr><td>13</td><td>0.020796</td><td>0.029944</td><td>0.974485</td><td>00:00</td></tr>\n",
"    <tr><td>14</td><td>0.020207</td><td>0.029016</td><td>0.975466</td><td>00:00</td></tr>\n",
"    <tr><td>15</td><td>0.019683</td><td>0.028196</td><td>0.976448</td><td>00:00</td></tr>\n",
"    <tr><td>16</td><td>0.019215</td><td>0.027463</td><td>0.976448</td><td>00:00</td></tr>\n",
"    <tr><td>17</td><td>0.018791</td><td>0.026806</td><td>0.976938</td><td>00:00</td></tr>\n",
"    <tr><td>18</td><td>0.018405</td><td>0.026212</td><td>0.977920</td><td>00:00</td></tr>\n",
"    <tr><td>19</td><td>0.018051</td><td>0.025671</td><td>0.977920</td><td>00:00</td></tr>\n",
"    <tr><td>20</td><td>0.017725</td><td>0.025179</td><td>0.977920</td><td>00:00</td></tr>\n",
"    <tr><td>21</td><td>0.017422</td><td>0.024728</td><td>0.978410</td><td>00:00</td></tr>\n",
"    <tr><td>22</td><td>0.017141</td><td>0.024313</td><td>0.978901</td><td>00:00</td></tr>\n",
"    <tr><td>23</td><td>0.016878</td><td>0.023932</td><td>0.979392</td><td>00:00</td></tr>\n",
"    <tr><td>24</td><td>0.016632</td><td>0.023580</td><td>0.979882</td><td>00:00</td></tr>\n",
"    <tr><td>25</td><td>0.016400</td><td>0.023254</td><td>0.979882</td><td>00:00</td></tr>\n",
"    <tr><td>26</td><td>0.016181</td><td>0.022952</td><td>0.979882</td><td>00:00</td></tr>\n",
"    <tr><td>27</td><td>0.015975</td><td>0.022672</td><td>0.980864</td><td>00:00</td></tr>\n",
"    <tr><td>28</td><td>0.015779</td><td>0.022411</td><td>0.980864</td><td>00:00</td></tr>\n",
"    <tr><td>29</td><td>0.015593</td><td>0.022168</td><td>0.981845</td><td>00:00</td></tr>\n",
"    <tr><td>30</td><td>0.015417</td><td>0.021941</td><td>0.981845</td><td>00:00</td></tr>\n",
"    <tr><td>31</td><td>0.015249</td><td>0.021728</td><td>0.981845</td><td>00:00</td></tr>\n",
"    <tr><td>32</td><td>0.015088</td><td>0.021529</td><td>0.981845</td><td>00:00</td></tr>\n",
"    <tr><td>33</td><td>0.014935</td><td>0.021341</td><td>0.981845</td><td>00:00</td></tr>\n",
"    <tr><td>34</td><td>0.014788</td><td>0.021164</td><td>0.981845</td><td>00:00</td></tr>\n",
"    <tr><td>35</td><td>0.014647</td><td>0.020998</td><td>0.982336</td><td>00:00</td></tr>\n",
"    <tr><td>36</td><td>0.014512</td><td>0.020840</td><td>0.982826</td><td>00:00</td></tr>\n",
"    <tr><td>37</td><td>0.014382</td><td>0.020691</td><td>0.982826</td><td>00:00</td></tr>\n",
"    <tr><td>38</td><td>0.014257</td><td>0.020550</td><td>0.982826</td><td>00:00</td></tr>\n",
"    <tr><td>39</td><td>0.014136</td><td>0.020415</td><td>0.982826</td><td>00:00</td></tr>\n",
"  </tbody>\n",
"</table>
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide_output\n", "learn.fit(40, 0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We're not showing the 40 lines of output here to save room; the training process is recorded in `learn.recorder`, with the table of output stored in the `values` attribute, so we can plot the accuracy over training as:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD+CAYAAADBCEVaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAZfElEQVR4nO3df5Dc9X3f8efrfkv3AxA6JLABBZBsLCei9TlxQ3CSxjGxmwwE7GkMxnRSBxfGE6ZOm3gypqW4M9RuO5k2Q7CZIcbGrmwngQSHGjvT4BjsukbEFa5sfOfElnAMdycJ7m5Pd7e3d+/+8d09rVZ7d1+dVtrd7/f1mLnR7Xe/t/fmo9OLz32+n31/FRGYmVm2dDS7ADMzazyHu5lZBjnczcwyyOFuZpZBDnczswxyuJuZZZDD3cwsg1KFu6T3S9onaUHSQ+uc+68lvSRpStIfS+ptSKVmZpaa0ryJSdINwDJwLbApIv7FKuddC3wK+KfAj4FHgW9ExAfXev2tW7fGjh07TqlwM7O8e/bZZw9HxHC957rSvEBEPAIgaQR49Rqn3go8GBEHyud/GPgMsGa479ixg3379qUpxczMyiQdXO25Rq+57wb2Vz3eD2yTdH6Dv4+Zma2h0eE+AExVPa58Plh7oqTbyuv4+yYnJxtchplZvjU63AvAUNXjyucztSdGxAMRMRIRI8PDdZeMzMxsgxod7geAPVWP9wDjEXGkwd/HzMzWkHYrZJekPqAT6JTUJ6nexdhPAf9S0usknQd8CHioYdWamVkqaWfuHwLmSHa9vLv8+YckXSKpIOkSgIh4Avgo8CRwsPzx7xtetZmZrSnVPvczbWRkJLwV0szs1Eh6NiJG6j2Xap+7mVk9EcFCaZnCQonl5eZPFOsJYHFpmWJpmcWloFhapri0RLEUFFeOJ38WS8ssLC2zWFpeea60tHxG6xvZsYU372r8phKHu1mTlZaSIFksBQtLSyshUx1EC+XHp/Ob9tJyUFgoMbuwxOxCiZmFErPlj5mFEsXS6iEWAQul5Osqr1Eof77UoqHeSNKZe+1/9fOXO9zN1lIJybni0gnhUwmkyuel0wij5QgWS0FxaWklfBfKYVysmfEVq2aDi0vLJ8wSF6qONzMb+3s66e/tYqC3i56uDrRGivV2dTDQ28XwYC8Dvd0M9HYy0Ne18vWdHWcwAU9Td0cHPV0ddHcmfyafi97qY1XP9XZ20t0lejo76Opsz/6KDndruNLSclWYLlFYWKRQni0W5k8O3NrHx4pLye/Sq1iOWAnPZoVkZ4fo7lQ5EDrprQqMSlB0d3Yw2Nd1QoBUB0nvSWFTCZaTgyj5PlozfNfTITFQDuL+3k76e7roaOFAttPjcM+ppeWoWmNcOnEZoLTMfGlpJXRnF0rMzJd/nS9WPq8K5/kSs8Xj5y2s8et9tcpMcKCvi/6eJHQuGOxjU08nHWuE
mOCE0KsNyb6uDgb6kpllZVaZBFry0XOaM7Gero6WnqWagcM9MyKCV44t8tL0PC9NzzMxPc+R2SIvzxY5OrvIy8eKHJ0trvw5M1/a0PfZ1F0JzM6VUL7wnL4Tfj3v70kCe7AcpgN9XXWDtrtNf901awcO9xZTWCjxwtFjHDp6jJdniyet31Zf1T8yW2RieoGXpucZn56vO2Pu6+7g/P5ezuvv5rzNPVx6/mbO29zDuZu76evuTJYP6iwF9HZ10N/bxeAJod3ZtuuPZnnjcD+LqmfX49PzvDg1z6Gjx3ih8vHyHEdni2u+RmeHyuu54rz+HrYN9XHVxeey/Zw+tg31sW2ol+1DyedbB3rZ1NN5lv7rzKyVONzPgIjgbw+9wpcPvMSPp+YZn5pfdXbd1SFedd4mLtmymWsvOodLtmzm4i3J4/MHeleu4FfWlb3Wa2ZpONwb6OXZIo986x/43DOHGB0v0NPZwYXnJrPoPRefy/ahXrYN9bH9nD62V/3ppQ4zazSH+2laXg6+8fdH2PvMC3zp/71EcWmZqy4+l/90w0/yq3suYqDXQ2xmZ5+TZ4MWSkt84ms/ZO83D3HwyDHO2dTNTT9zCf/8jRdz5YVD67+AmdkZ5HDfgNHxGe787P/luy9O86bLtvCBX97Ftbu309fti5dm1hoc7qcgIvjk13/IvV98noHeLh68dYRfunJbs8syMzuJwz2liZl5/u2fPMffjE7yi68Z5qPv2MPwYG+zyzIzq8vhnsJffWec3/uz55hdKPHh63bz7jddelo9PszMzjSH+xqOFUt8+C+/y95vHuJ1Fw7x3991FVdcMNjssszM1uVwX8XM/CLv/Nj/5nvjM7zvzZfxgbfuorfLF0zNrD043OtYXg5+5/P7GZso8Me3vpFffO0FzS7JzOyU+K2RdfzRV77Pl78zzu+//UoHu5m1JYd7jSe/N8F//atRrr/qIn7z6h3NLsfMbEMc7lV+eHiWO/d+iyu3D3HvDT/lHTFm1rYc7mWzCyXe9/CzdHSIj9/yBrfKNbO25nAneefp7/7Zc4xNzPCH7/pHXLxlc7NLMjM7LQ534IGv/j2PP/civ/srr+WancPNLsfM7LTlPtyfHjvMR554nn/2kxfyvjdf1uxyzMwaItfh/sLRY7x/79+y84JBPvoOX0A1s+zIdbj/x8e/w9Jy8PFb3kC/b6phZhmS63B/7kdTvOXKbezY2t/sUszMGiq34T49v8iLU/NcccFAs0sxM2u43Ib72HgBgF3b3OXRzLInx+E+A8CubZ65m1n2pAp3SVskPSppVtJBSTetct65kj4paaL8cXdDq22g0fECfd0dXHye37BkZtmTdovIfUAR2AZcBTwuaX9EHKg57w+AzcAO4ALgf0k6GBGfaFC9DTM2McMVFwzQ0eHtj2aWPevO3CX1AzcCd0VEISKeBh4Dbqlz+q8BH42IYxHxQ+BB4DcbWG/DjI0X2OW7KplZRqVZltkFLEXEaNWx/cDuVc5Xzeev32BtZ8zU3CIvTc+z0xdTzSyj0oT7ADBVc2wKqJeMTwAflDQo6QqSWXvdRW1Jt0naJ2nf5OTkqdR82r4/kVxM3eltkGaWUWnCvQAM1RwbAmbqnPvbwBwwBvwFsBf4Ub0XjYgHImIkIkaGh89us65Rb4M0s4xLE+6jQJeknVXH9gC1F1OJiKMRcXNEbI+I3eXX/2ZjSm2c0fEZNnV38urzNjW7FDOzM2Ld3TIRMSvpEeAeSe8l2S1zHfCztedKuhx4pfzxVuA24OcbWnEDjI0XvFPGzDIt7ZuY7gA2ARMkSy23R8QBSddIKlSd9wbg2yRLNvcCN9fZLtl0YxMz7PSbl8wsw1Ltc4+Io8D1dY4/RXLBtfL488DnG1bdGTA1t8j49ILX280s03LXfsBtB8wsD3IX7pWdMjv9BiYzy7AchnuyU+ZV53qnjJllV+7CvXIx1TtlzCzL8hfu4wUvyZhZ5uUq3KeOLTIxs+CLqWaW
ebkK99GJyk4Zz9zNLNvyFe7lbZB+A5OZZV2uwn1svEB/j3fKmFn25SrcR8eTuy9J3iljZtmWs3Av+AYdZpYLuQn3l2eLHC54p4yZ5UNuwn1sotx2wDN3M8uB3IT76Li3QZpZfuQm3MfGZxjo7eKic/qaXYqZ2RmXm3AfLd99yTtlzCwPchPuYxMzvphqZrmRi3A/OlvkcKHohmFmlhu5CPcxtx0ws5zJRbiPlrdBeqeMmeVFLsJ9bHyGwd4uLvROGTPLiVyE++j4DFds804ZM8uPXIT72HiBXb6YamY5kvlwP1JY4Mhs0RdTzSxXMh/u7iljZnmU/XBf6SnjmbuZ5Ufmw310vMBgbxfbh7xTxszyIwfhPsNO75Qxs5zJfLiPTRT85iUzy51Mh/vhwgJHZ4u+mGpmuZPpcB8br7Qd8MVUM8uXbIf7hO++ZGb5lCrcJW2R9KikWUkHJd20ynm9kj4maVzSUUlfkPSqxpac3qEjx+jr7uCCwd5mlWBm1hRpZ+73AUVgG3AzcL+k3XXOuxP4J8BPARcBrwB/2IA6N2R6fpFzN/V4p4yZ5c664S6pH7gRuCsiChHxNPAYcEud038C+FJEjEfEPPBZoN7/BM6K6bkSQ5u6mvXtzcyaJs3MfRewFBGjVcf2Uz+0HwSulnSRpM0ks/wv1ntRSbdJ2idp3+Tk5KnWncr0/CJDfd1n5LXNzFpZmnAfAKZqjk0B9a5SjgKHgH8ApoErgXvqvWhEPBARIxExMjw8nL7iUzA9v8jQJoe7meVPmnAvAEM1x4aAmTrn3g/0AecD/cAjrDJzPxum50oM9XlZxszyJ024jwJdknZWHdsDHKhz7h7goYg4GhELJBdTf1rS1tMv9dR55m5mebVuuEfELMkM/B5J/ZKuBq4DHq5z+jPAeySdI6kbuAP4cUQcbmTRaUQE03NeczezfEq7FfIOYBMwAewFbo+IA5KukVSoOu/fAPPAGDAJvB349QbWm9pscYnlwLtlzCyXUiVfRBwFrq9z/CmSC66Vx0dIdsg03fTcIoBn7maWS5ltPzA9Xw53r7mbWQ5lN9znSoBn7maWTxkO98rM3WvuZpY/2Q33ea+5m1l+ZTfc57zmbmb5ld1wn0/W3Af9DlUzy6HshvvcIpt7OunuzOx/opnZqjKbfO4IaWZ5lt1wdy93M8ux7Ia7Z+5mlmPZDnfvlDGznMpuuLuXu5nlWHbD3TN3M8uxTIa7e7mbWd5lMtzdy93M8i6T4e5e7maWd9kMd/dyN7Ocy2a4u5e7meVcRsPdvdzNLN+yGe7u5W5mOZfNcHcvdzPLuWyGu3u5m1nOZTPc3cvdzHIuk+nnjpBmlneZDPepuUXvlDGzXMtkuCcdIT1zN7P8yma4uyOkmeVcdsPdO2XMLMeyGe5zJc/czSzXMhfuy8vBjHfLmFnOZS7cZ4sl93I3s9xLFe6Stkh6VNKspIOSblrlvC9KKlR9FCV9u7Elr63y7lTP3M0sz9JOb+8DisA24CrgcUn7I+JA9UkR8bbqx5K+Avx1A+pMzX1lzMxSzNwl9QM3AndFRCEingYeA25Z5+t2ANcAD59+men5LkxmZumWZXYBSxExWnVsP7B7na97D/BURPxgo8VtxMqyjNfczSzH0oT7ADBVc2wKGFzn694DPLTak5Juk7RP0r7JyckUZaRTmbmf42UZM8uxNOFeAIZqjg0BM6t9gaSfA7YDf7raORHxQESMRMTI8PBwmlpT8Y06zMzShfso0CVpZ9WxPcCBVc4HuBV4JCIKp1PcRlTun+pe7maWZ+uGe0TMAo8A90jql3Q1cB2rXCiVtAl4J2ssyZxJ0/OL9Pd00uVe7maWY2kT8A5gEzAB7AVuj4gDkq6RVDs7v55kTf7JxpWZ3vScm4aZmaVau4iIoyShXXv8KZILrtXH9pL8D6ApfKMOM7MMth9ImoZ5vd3M8i174e6Zu5lZRsPda+5mlnPZC/e5km/UYWa5l6lwX+nl7pm7
meVcpsJ9pZe719zNLOcyFe5uGmZmlshWuLvdr5kZkNVw95q7meVctsLdt9gzMwOyFu4rM3evuZtZvmUr3N3L3cwMyFq4u5e7mRmQtXB3L3czMyBr4e5e7mZmQNbC3R0hzcyArIW7e7mbmQFZC3fP3M3MgCyGu9fczcwyFu7u5W5mBmQo3N3L3czsuMyEu3u5m5kdl5lwdy93M7PjshPu7uVuZrYie+HuNXczswyFu3u5m5mtyEy4T7mXu5nZisyEu9fczcyOy064l2/U4V7uZmZZCve5knu5m5mVZSYJ3VfGzOy47IT7nDtCmplVpAp3SVskPSppVtJBSTetce4/lvRVSQVJ45LubFy5q0tm7l5vNzOD9DP3+4AisA24Gbhf0u7akyRtBZ4APg6cD1wBfLkxpa4t6QjpmbuZGaQId0n9wI3AXRFRiIingceAW+qc/gHgSxHxmYhYiIiZiPhuY0uuz2vuZmbHpZm57wKWImK06th+4KSZO/Am4Kikr0uakPQFSZc0otD1JGvuXpYxM4N04T4ATNUcmwIG65z7auBW4E7gEuAHwN56LyrpNkn7JO2bnJxMX3Edy8vBzELJM3czs7I04V4AhmqODQEzdc6dAx6NiGciYh74D8DPSjqn9sSIeCAiRiJiZHh4+FTrPrHAYolwL3czsxVpwn0U6JK0s+rYHuBAnXOfA6LqceVzbay8dKbdV8bM7ATrhntEzAKPAPdI6pd0NXAd8HCd0z8B/LqkqyR1A3cBT0fEK40sutb0nDtCmplVS7sV8g5gEzBBsoZ+e0QckHSNpELlpIj4a+D3gcfL514BrLonvlEqfWW85m5mlki1jhERR4Hr6xx/iuSCa/Wx+4H7G1JdSu4IaWZ2oky0H/D9U83MTpSNcPfM3czsBNkId/dyNzM7QTbCfa7EQG+Xe7mbmZVlIg2n5916wMysWjbCfc5Nw8zMqmUj3Od9ow4zs2rZCPe5krdBmplVyUa4e+ZuZnaCbIS719zNzE7Q9uG+0svdu2XMzFa0fbiv9HL3zN3MbEXbh7tbD5iZnSwD4e6mYWZmtdo/3Oc9czczq9X+4T7nG3WYmdVq/3Cf9y32zMxqtX+4++bYZmYnaf9wL6+5D/Q63M3MKto/3N3L3czsJG2fiO7lbmZ2svYPd/eVMTM7SfuHuztCmpmdpP3D3b3czcxO0v7h7pm7mdlJ2j/cveZuZnaStg5393I3M6uvrcN9ZsG93M3M6mnrcHcvdzOz+to73OfdV8bMrJ72Dvc5d4Q0M6unvcN93r3czczqSRXukrZIelTSrKSDkm5a5by7JS1KKlR9XNbYko87v7+Ht71+O8ODvWfqW5iZtaW0i9X3AUVgG3AV8Lik/RFxoM65n4uIdzeqwLWM7NjCyI4tZ+NbmZm1lXVn7pL6gRuBuyKiEBFPA48Bt5zp4szMbGPSLMvsApYiYrTq2H5g9yrn/5qko5IOSLp9tReVdJukfZL2TU5OnkLJZma2njThPgBM1RybAgbrnPt54EpgGPgt4N9Jele9F42IByJiJCJGhoeHT6FkMzNbT5pwLwBDNceGgJnaEyPiOxHx44hYioivA/8NeMfpl2lmZqciTbiPAl2SdlYd2wPUu5haKwBtpDAzM9u4dcM9ImaBR4B7JPVLuhq4Dni49lxJ10k6T4mfBn4b+ItGF21mZmtL+yamO4BNwASwF7g9Ig5IukZSoeq83wC+T7Jk8yngIxHxyUYWbGZm60u1zz0ijgLX1zn+FMkF18rjuhdPzczs7FJENLsGJE0CBzf45VuBww0sp5Fc28a0cm3Q2vW5to1p19oujYi62w1bItxPh6R9ETHS7DrqcW0b08q1QWvX59o2Jou1tXXjMDMzq8/hbmaWQVkI9weaXcAaXNvGtHJt0Nr1ubaNyVxtbb/mbmZmJ8vCzN3MzGo43M3MMqhtwz3t3aGaRdJXJM1X3ZHqe02q4/3l1soLkh6qee6XJD0v6ZikJyVd2gq1
SdohKWru6HXXWa6tV9KD5Z+tGUnfkvS2quebNnZr1dYiY/dpSS9KmpY0Kum9Vc81+2eubm2tMG5VNe4sZ8enq47dVP77npX055LWv0tRRLTlB0kbhM+RvEP250jaEO9udl1V9X0FeG8L1HEDybuL7wceqjq+tTxm7wT6gP8MfKNFattB0nSuq4nj1g/cXa6lA/hVkrYaO5o9duvU1gpjtxvoLX/+WuAl4A3NHrd1amv6uFXV+GXgKeDTVTXPAG8u593/AD673uukvc1eS6m6O9TrI6IAPC2pcneoDza1uBYTEY8ASBoBXl311A3AgYj4k/LzdwOHJb02Ip5vcm1NF0nDvLurDv2lpB+QBMH5NHHs1qnt2TP9/dcTJ95+M8ofl5PU1+yfudVqO3I2vv96JP0G8ArwdeCK8uGbgS9ExFfL59wFfFfSYESc1Hq9ol2XZU717lDNcq+kw5K+JukXml1Mjd0kYwasBMbf0VpjeFDSjyR9QtLWZhYiaRvJz90BWmzsamqraOrYSfojSceA54EXgf9Ji4zbKrVVNG3cJA0B9wC/U/NU7bj9Hck9rXet9XrtGu6ncneoZvk94DLgVST7VL8g6fLmlnSCVh7Dw8AbgUtJZnuDwGeaVYyk7vL3/2R5htkyY1entpYYu4i4o/y9ryFpGb5Ai4zbKrW1wrh9GHgwIl6oOb6hcWvXcE99d6hmiYj/ExEzEbEQSdvjrwFvb3ZdVVp2DCO5Efu+iChFxDjwfuCt5ZnNWSWpg+TeBcVyHdAiY1evtlYau0juyPY0yZLb7bTIuNWrrdnjJukq4C3AH9R5ekPj1pZr7lTdHSoixsrH0t4dqlla7a5UB4BbKw/K1zEupzXHsPJOu7M6fpIEPAhsA94eEYvlp5o+dmvUVqspY1eji+Pj02o/c5Xaap3tcfsFkou6h5K/WgaATkmvA54gybekIOkyoJckB1fX7CvDp3FF+bMkO2b6gatpod0ywLnAtSQ7ArpILojMAq9pQi1d5TruJZnlVWoaLo/ZjeVjH+Hs71xYrbafAV5D8pvl+SS7op5swth9DPgGMFBzvBXGbrXamjp2wAUkN+0ZADrL/w5mSe7e1tRxW6e2Zo/bZmB71cd/Af60PGa7gWmSZaR+4NOk2C1z1n4Yz8BgbAH+vPyXcwi4qdk1VdU2DDxD8mvTK+V/hL/cpFru5viugMrH3eXn3kJyUWmOZOvmjlaoDXgX8IPy3+2LJHf12n6Wa7u0XM88ya/FlY+bmz12a9XW7LEr/+z/Tfnnfhr4NvBbVc83c9xWra3Z41an1rspb4UsP76pnHOzJLcu3bLea7i3jJlZBrXrBVUzM1uDw93MLIMc7mZmGeRwNzPLIIe7mVkGOdzNzDLI4W5mlkEOdzOzDHK4m5ll0P8HcyDWzinP0RMAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(L(learn.recorder.values).itemgot(2));" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And we can view the final accuracy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.982826292514801" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.recorder.values[-1][2]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "At this point we have something that is rather magical:\n", "\n", "1. A function that can solve any problem to any level of accuracy (the neural network) given the correct set of parameters\n", "1. A way to find the best set of parameters for any function (stochastic gradient descent)\n", "\n", "This is why deep learning can do things which seem rather magical, such fantastic things. Believing that this combination of simple techniques can really solve any problem is one of the biggest steps that we find many students have to take. It seems too good to be true—surely things should be more difficult and complicated than this? Our recommendation: try it out! We just tried it on the MNIST dataset and you have seen the results. And since we are doing everything from scratch ourselves (except for calculating the gradients) you know that there is no special magic hiding behind the scenes." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Going Deeper" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There is no need to stop at just two linear layers. We can add as many as we want, as long as we add a nonlinearity between each pair of linear layers. As you will learn, however, the deeper the model gets, the harder it is to optimize the parameters in practice. 
Later in this book you will learn about some simple but brilliantly effective techniques for training deeper models.\n", "\n", "We already know that a single nonlinearity with two linear layers is enough to approximate any function. So why would we use deeper models? The reason is performance. With a deeper model (that is, one with more layers) we do not need to use as many parameters; it turns out that we can use smaller matrices with more layers, and get better results than we would get with larger matrices and fewer layers.\n", "\n", "That means that we can train the model more quickly, and it will take up less memory. In the 1990s researchers were so focused on the universal approximation theorem that very few were experimenting with more than one nonlinearity. This theoretical but not practical foundation held back the field for years. Some researchers, however, did experiment with deep models, and eventually were able to show that these models could perform much better in practice. Eventually, theoretical results were developed which showed why this happens. Today, it is extremely unusual to find anybody using a neural network with just one nonlinearity.\n", "\n", "Here is what happens when we train an 18-layer model using the same approach we saw in <>:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<table>\n",
"  <thead>\n",
"    <tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th><th>time</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><td>0</td><td>0.082089</td><td>0.009578</td><td>0.997056</td><td>00:11</td></tr>\n",
"  </tbody>\n",
"</table>
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dls = ImageDataLoaders.from_folder(path)\n", "learn = vision_learner(dls, resnet18, pretrained=False,\n", " loss_func=F.cross_entropy, metrics=accuracy)\n", "learn.fit_one_cycle(1, 0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Nearly 100% accuracy! That's a big difference compared to our simple neural net. But as you'll learn in the remainder of this book, there are just a few little tricks you need to use to get such great results from scratch yourself. You already know the key foundational pieces. (Of course, even once you know all the tricks, you'll nearly always want to work with the pre-built classes provided by PyTorch and fastai, because they save you having to think about all the little details yourself.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Jargon Recap" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Congratulations: you now know how to create and train a deep neural network from scratch! We've gone through quite a few steps to get to this point, but you might be surprised at how simple it really is.\n", "\n", "Now that we are at this point, it is a good opportunity to define, and review, some jargon and key concepts.\n", "\n", "A neural network contains a lot of numbers, but they are only of two types: numbers that are calculated, and the parameters that these numbers are calculated from. This gives us the two most important pieces of jargon to learn:\n", "\n", "- Activations:: Numbers that are calculated (both by linear and nonlinear layers)\n", "- Parameters:: Numbers that are randomly initialized, and optimized (that is, the numbers that define the model)\n", "\n", "We will often talk in this book about activations and parameters. Remember that they have very specific meanings. They are numbers. They are not abstract concepts, but they are actual specific numbers that are in your model. 
Part of becoming a good deep learning practitioner is getting used to the idea of looking at your activations and parameters, plotting them, and testing whether they are behaving correctly.\n", "\n", "Our activations and parameters are all contained in *tensors*. These are simply regularly shaped arrays (for example, a matrix). Matrices have rows and columns; we call these the *axes* or *dimensions*. The number of dimensions of a tensor is its *rank*. There are some special tensors:\n", "\n", "- Rank zero: scalar\n", "- Rank one: vector\n", "- Rank two: matrix\n", "\n", "A neural network contains a number of layers. Each layer is either *linear* or *nonlinear*. We generally alternate between these two kinds of layers in a neural network. Sometimes people refer to both a linear layer and its subsequent nonlinearity together as a single layer. Yes, this is confusing. Sometimes a nonlinearity is referred to as an *activation function*.\n", "\n", "<<dljargon1>> summarizes the key concepts related to SGD.\n", "\n", "```asciidoc\n", "[[dljargon1]]\n", ".Deep learning vocabulary\n", "[options=\"header\"]\n", "|=====\n", "| Term | Meaning\n", "|ReLU | Function that returns 0 for negative numbers and doesn't change positive numbers.\n", "|Mini-batch | A small group of inputs and labels gathered together in two arrays. 
A gradient descent step is computed on this batch (rather than on a whole epoch).\n", "|Forward pass | Applying the model to some input and computing the predictions.\n", "|Loss | A value that represents how well (or badly) our model is doing.\n", "|Gradient | The derivative of the loss with respect to some parameter of the model.\n", "|Backward pass | Computing the gradients of the loss with respect to all model parameters.\n", "|Gradient descent | Taking a step in the direction opposite to the gradients to make the model parameters a little bit better.\n", "|Learning rate | The size of the step we take when applying SGD to update the parameters of the model.\n", "|=====\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> note: _Choose Your Own Adventure_ Reminder: Did you choose to skip over chapters 2 & 3, in your excitement to peek under the hood? Well, here's your reminder to head back to chapter 2 now, because you'll be needing to know that stuff very soon!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Questionnaire" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1. How is a grayscale image represented on a computer? How about a color image?\n", "1. How are the files and folders in the `MNIST_SAMPLE` dataset structured? Why?\n", "1. Explain how the \"pixel similarity\" approach to classifying digits works.\n", "1. What is a list comprehension? Create one now that selects odd numbers from a list and doubles them.\n", "1. What is a \"rank-3 tensor\"?\n", "1. What is the difference between tensor rank and shape? How do you get the rank from the shape?\n", "1. What are RMSE and L1 norm?\n", "1. How can you apply a calculation on thousands of numbers at once, many thousands of times faster than a Python loop?\n", "1. Create a 3×3 tensor or array containing the numbers from 1 to 9. Double it. Select the bottom-right four numbers.\n", "1. What is broadcasting?\n", "1. 
Are metrics generally calculated using the training set, or the validation set? Why?\n", "1. What is SGD?\n", "1. Why does SGD use mini-batches?\n", "1. What are the seven steps in SGD for machine learning?\n", "1. How do we initialize the weights in a model?\n", "1. What is \"loss\"?\n", "1. Why can't we always use a high learning rate?\n", "1. What is a \"gradient\"?\n", "1. Do you need to know how to calculate gradients yourself?\n", "1. Why can't we use accuracy as a loss function?\n", "1. Draw the sigmoid function. What is special about its shape?\n", "1. What is the difference between a loss function and a metric?\n", "1. What is the function to calculate new weights using a learning rate?\n", "1. What does the `DataLoader` class do?\n", "1. Write pseudocode showing the basic steps taken in each epoch for SGD.\n", "1. Create a function that, if passed two arguments `[1,2,3,4]` and `'abcd'`, returns `[(1, 'a'), (2, 'b'), (3, 'c'), (4, 'd')]`. What is special about that output data structure?\n", "1. What does `view` do in PyTorch?\n", "1. What are the \"bias\" parameters in a neural network? Why do we need them?\n", "1. What does the `@` operator do in Python?\n", "1. What does the `backward` method do?\n", "1. Why do we have to zero the gradients?\n", "1. What information do we have to pass to `Learner`?\n", "1. Show Python or pseudocode for the basic steps of a training loop.\n", "1. What is \"ReLU\"? Draw a plot of it for values from `-2` to `+2`.\n", "1. What is an \"activation function\"?\n", "1. What's the difference between `F.relu` and `nn.ReLU`?\n", "1. The universal approximation theorem shows that any function can be approximated as closely as needed using just one nonlinearity. So why do we normally use more?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Further Research" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1. 
Create your own implementation of `Learner` from scratch, based on the training loop shown in this chapter.\n", "1. Complete all the steps in this chapter using the full MNIST dataset (that is, for all digits, not just 3s and 7s). This is a significant project and will take you quite a bit of time to complete! You'll need to do some of your own research to figure out how to overcome some obstacles you'll meet on the way." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }