{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Mini-batch Training from Foundations\n", "\n", "#### Last Time\n", "[Most recently](http://nbviewer.jupyter.org/github/jamesdellinger/fastai_deep_learning_course_part2_v3/blob/master/02_fully_connected_my_reimplementation.ipynb?flush_cache=true) we saw how to implement from scratch both the forward and backward passes of a neural network.\n", "\n", "After an extended focus on weight initialization, through which we saw how to derive the basic principles that underpin the now widely used [Kaiming weight init](02_fully_connected_my_reimplementation.ipynb), we spent some time refactoring the code of our forward/backward passes.\n", "\n", "We learned that organizing this code into classes was more concise and interpretable (by human readers) than leaving the logic inside sundry, scattered functions. Finally, we wrapped things up by creating our own `Module()` class that's similar to PyTorch's `nn.Module`, and let our custom loss/linear layer/ReLU classes inherit from it.\n", "\n", "According to the rules we set for ourselves at the beginning of this course, we're free to use the PyTorch versions of all classes/functionalities we've thus far implemented from scratch.\n", "\n", "#### Mini-batch Training\n", "Today we'll implement a model that supports another must-have feature of any deep learning model: the ability to train using mini-batches.\n", "\n", "Mini-batches allow us to update our model weights by leveraging the parallel processing capability of Nvidia GPUs to train on several inputs *at the same time*. 
\n", "\n", "This allows our model to complete a single pass through all the training samples in our dataset in *a much shorter amount of time* than if it had to compute gradients and update weights for every single input in the training set, one at a time!\n", "\n", "#### What Components We Implement Here\n", "On the way to a model that can successfully train on mini-batches, we implement from scratch several other crucial components below. These include:\n", "* Cross entropy loss\n", "* Updating and registering model parameters\n", "* Optimizer classes\n", "* Dataset and Dataloader classes\n", "* Random Sampling\n", "* Setting aside a Validation Set\n", "\n", "#### Attribution\n", "Virtually all the code that appears in this notebook is the creation of [Sylvain Gugger](https://www.fast.ai/about/#sylvain) and [Jeremy Howard](https://www.fast.ai/about/#jeremy). The original version of this notebook that they made for the course lecture can be found [here](https://github.com/fastai/course-v3/blob/master/nbs/dl2/03_minibatch_training.ipynb). I simply re-typed, line-by-line, the pieces of logic necessary to implement the functionality that their notebook demonstrated. In some cases I changed the order of code cells and/or variable names so as to fit an organization and style that seemed more intuitive to me. Any and all mistakes are my own.\n", "\n", "On the other hand, all long-form text explanations in this notebook are solely my own creation. Writing extensive descriptions of the concepts and code in plain and simple English forces me to make sure that I actually understand how they work."
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = 'retina'" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "#export\n", "from exports.nb_02 import *\n", "import torch.nn.functional as F" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preparing the Data\n", "We continue to use the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset as a baseline to test the functionality and performance of all the classes we create from scratch." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "mpl.rcParams['image.cmap'] = 'gray'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "x_train, y_train, x_valid, y_valid = get_data()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "n,m = x_train.shape # 50,000 images x 784 pixels per image\n", "c = int(y_train.max()) + 1 # number of classes in dataset\n", "nh = 50 # number of units in the hidden layer" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "class Model(nn.Module):\n", "    def __init__(self, n_in, nh, n_out):\n", "        super().__init__()\n", "        self.layers = [nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh, n_out)]\n", "\n", "    def __call__(self, x):\n", "        for l in self.layers: x = l(x)\n", "        return x" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "model = Model(m, nh, c)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "pred = model(x_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Cross Entropy Loss\n", "\n", "In the previous notebook we used a quick-and-dirty mean squared error loss function simply as a tool we could 
use to test whether our model was correctly calculating weight gradients. Now, however, it's time to implement a loss function that is better suited to the MNIST task: predicting which of ten classes (the digits 0 through 9) a handwritten, single-digit number most likely belongs to.\n", "\n", "#### Log softmax\n", "To build cross entropy loss, we first calculate the softmax of our activations, $$\\textrm{softmax}(x)_{i} = \\frac{e^{x_{i}}}{e^{x_{0}} + e^{x_{1}} + ... + e^{x_{n-1}}}$$ or more concisely, $$\\textrm{softmax}(x)_{i} = \\frac{e^{x_{i}}}{\\sum_{0\\leq{j}\\leq{n-1}}e^{x_{j}}}$$\n", "\n", "Note that in practice, we need to take the log of the softmax in order to calculate the loss." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def log_softmax(x): return (x.exp()/x.exp().sum(-1,keepdim=True)).log()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.0887, -0.0349, -0.0814, ..., 0.0308, -0.0196, -0.0004],\n", " [ 0.0576, -0.0033, 0.0132, ..., 0.0507, -0.0592, 0.0061],\n", " [ 0.0516, -0.0581, 0.0250, ..., 0.0023, -0.0396, -0.0962],\n", " ...,\n", " [ 0.0841, -0.1409, -0.0611, ..., 0.0254, -0.0366, 0.0731],\n", " [-0.0209, -0.0425, 0.0053, ..., -0.0104, 0.0560, 0.1751],\n", " [ 0.0554, -0.1532, 0.0325, ..., 0.0359, -0.1168, 0.0629]],\n", " grad_fn=)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([50000, 10])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred.shape" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 0.0887, -0.0349, -0.0814, -0.0429, 0.0312, 0.1822, 0.0856, 0.0308,\n", " -0.0196, -0.0004], grad_fn=)" ] }, "execution_count": 
12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For each of the 50,000 input images, our model's prediction is a list of 10 raw scores, one per MNIST category. The problem is that these raw scores can be any real numbers and aren't normalized. If we just use them as they are, we have no standardized way of comparing how strongly the model believes that an image belongs to each of the ten categories.\n", "\n", "Thankfully, the softmax function we introduced just above gives us a way to do this. Softmax exponentiates each of an image's ten scores (making them all positive) and divides each one by their sum, turning the list into ten probabilities that all sum to 1." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "softmax_pred = (pred.exp()/pred.exp().sum(-1,keepdim=True))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.1064, 0.0940, 0.0898, 0.0933, 0.1004, 0.1168, 0.1061, 0.1004, 0.0955,\n", " 0.0973], grad_fn=)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "softmax_pred[0]" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(1., grad_fn=)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "softmax_pred[0].sum()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, in practice we use the log of the softmax instead of the softmax itself. Why? Logarithms turn the division into a subtraction. Working in log space also means we never have to take the log of a probability so small that it has underflowed to zero, which would give $-\\infty$. This makes the loss calculation more numerically stable. [The answer here](https://discuss.pytorch.org/t/logsoftmax-vs-softmax/21386/4) gives a nice explanation."
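, "\n", "\n", "The next example makes the difference concrete. It's a minimal sketch (not part of the original fast.ai notebook) that uses two rows of made-up, deliberately extreme activations. The naive route exponentiates, divides, and only then takes the log. PyTorch's built-in `F.log_softmax` stays in log space and, like the LogSumExp trick covered further down, avoids exponentiating huge numbers:\n", "\n", "```python\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "# Hypothetical, deliberately extreme activations (one image per row):\n", "# row 0 contains a tiny probability that underflows, row 1 a huge activation that overflows\n", "x = torch.tensor([[0., -200.], [1000., 0.]])\n", "\n", "# Naive route: softmax first, then take the log\n", "naive = (x.exp() / x.exp().sum(-1, keepdim=True)).log()\n", "print(naive)   # row 0: exp(-200.) underflows to 0 in float32, so its log is -inf\n", "               # row 1: exp(1000.) overflows to inf, and inf/inf gives nan\n", "\n", "# Log-space route using PyTorch's built-in function\n", "stable = F.log_softmax(x, dim=-1)\n", "print(stable)  # all finite: approximately [[0, -200], [0, -1000]]\n", "```\n", "\n", "The exact magnitudes here are arbitrary. In float32, any activation large enough to overflow `exp()` (roughly anything above 88) makes the naive version fail in the same way."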
] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "logsoftmax_pred = log_softmax(pred)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-2.2406, -2.3642, -2.4107, -2.3722, -2.2981, -2.1470, -2.2436, -2.2985,\n", " -2.3489, -2.3297], grad_fn=)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logsoftmax_pred[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### To one-hot encode or not to one-hot encode\n", "The cross entropy loss between a one-hot encoded target $x$ and a prediction $p(x)$ generated by a model is $$-\\sum_{i=0}^{n-1}{x_{i}\\log{\\left(p_{i}(x)\\right)}}$$ where $i$ ranges over the $n$ categories of the one-hot encoded label.\n", "\n", "In other words, it's the negative of a sum taken over all indices $i$. Each term in that sum is the label value $x_{i}$ multiplied by the *log* of the predicted probability $p_{i}(x)$ at the same index. Recall that each value in $x$ is either one or zero, since the label is one-hot encoded.\n", "\n", "Here's a very intuitive [explanation](https://youtu.be/AcA8HAYh7IE?t=1925) using Excel.\n", "\n", "Here's a nice, concrete example. 
Remember that the first training sample's ground truth label is the '5' digit:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(5)" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfoAAAH0CAYAAADVH+85AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAHOlJREFUeJzt3XvMbXV5J/DvU06VgQiIaSVtxyJMgRSLDKgodLgGlWm1IDDRxJa0aNoOjmKVtLHawbY0TFpviKOkthAxERtMtVoqTAQEi6UBioxBQQuUocUiIPeLPZzf/LHXqadv3/dc9trn3e/57c8n2Vlnr7We/XtYrJzvWXuvS7XWAgD06Yfm3QAAsP0IegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDo2Lp5N7A9VNWdSXZLctecWwGAae2d5JHW2gvHfEiXQZ9JyO85vABgYc31q/uq+omq+tOq+qeqerqq7qqqD1bVc0d+9F2z6A8A5uyusR8wtyP6qto3yXVJfjTJ55J8M8nLkrwtyaur6ojW2gPz6g8AejDPI/r/nUnIv7W1dmJr7bdaa8cm+UCS/ZOcM8feAKAL1Vpb/UGr9kny95l8JbFva23DJsuek+TeJJXkR1trj0/x+TcmOWQ23QLA3NzUWjt0zAfM64j+2GF6xaYhnySttUeT/HWSXZK8fLUbA4CezOs3+v2H6e0rLP9Wklcm2S/Jl1b6kOHIfTkHTN8aAPRjXkf0uw/Th1dYvnH+HqvQCwB0a61eR1/DdLMnEKz0u4Xf6AFgYl5H9BuP2HdfYfluS9YDAKYwr6C/bZjut8LynxqmK/2GDwBshXkF/VXD9JVV9W96GC6vOyLJk0n+ZrUbA4CezCXoW2t/n+SKTG7Yf8aSxe9NsmuST0xzDT0A8APzPBnvv2dyC9zzquq4JN9IcliSYzL5yv6359gbAHRhbrfAHY7qX5LkokwC/h1J9k1yXpJXuM89AIw318vrWmv/L8kvz7MHAOjZXB9TCwBsX4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADq2bt4NANPbaaedRtXvvvvuM+pk9b3lLW+ZunaXXXYZNfb+++8/de0ZZ5wxauw/+qM/mrr2DW94w6ixn3rqqalrzz333FFjv/e97x1Vv8jmdkRfVXdVVVvh9Z159QUAPZn3Ef3DST64zPzHVrsRAOjRvIP+odba2XPuAQC65WQ8AOjYvI/on11Vb0zygiSPJ7klyTWttWfm2xYA9GHeQb9XkouXzLuzqn65tfblLRVX1Y0rLDpgdGc
A0IF5fnV/YZLjMgn7XZP8TJILkuyd5K+q6sXzaw0A+jC3I/rW2tKLIr+e5Neq6rEk70hydpKTtvAZhy43fzjSP2QGbQLADm0tnoz3sWF65Fy7AIAOrMWgv2+Y7jrXLgCgA2sx6F8xTO+YaxcA0IG5BH1VHVhVey4z/yeTnD+8/eTqdgUA/ZnXyXinJvmtqroqyZ1JHk2yb5KfS7JzksuSTP/kBgAgyfyC/qok+yf5z5l8Vb9rkoeSfCWT6+ovbq21OfUGAN2YS9APN8PZ4g1xYGu94AUvGFX/rGc9a+raww8/fNTYP/uzPzt17R577DFq7JNPPnlU/aK65557pq4977zzRo190kmbvep4sx599NFRY3/ta1+buvbLX/ZX/rysxZPxAIAZEfQA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdq9bavHuYuaq6Mckh8+6DbXPwwQdPXXvllVeOGnv33XcfVc+OZcOGDaPqf+VXfmXq2scee2zU2GPce++9o+q/973vTV172223jRp7gd3UWjt0zAc4ogeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOjYunk3ABvdfffdU9c+8MADo8b2mNptd/3114+qf+ihh0bVH3PMMVPXfv/73x819sUXXzyqHlaTI3oA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6Jjn0bNmPPjgg1PXnnXWWaPG/vmf//mpa//u7/5u1NjnnXfeqPoxbr755qlrjz/++FFjP/7446PqDzzwwKlr3/a2t40aG3YkjugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6Vq21efcwc1V1Y5JD5t0HO47ddttt6tpHH3101NgXXHDB1LWnn376qLHf+MY3Tl37qU99atTYwFa5qbV26JgPmMkRfVWdUlUfrqprq+qRqmpV9ckt1BxeVZdV1YNV9URV3VJVZ1bVTrPoCQBI1s3oc96d5MVJHktyT5IDNrdyVf1Cks8keSrJp5M8mOQ1ST6Q5Igkp86oLwBYaLP6jf7tSfZLsluSX9/cilW1W5I/TvJMkqNba6e31s5KcnCSryY5papeP6O+AGChzSToW2tXtda+1bbuB/9TkvxIkktaazds8hlPZfLNQLKFfywAAFtnHmfdHztMv7jMsmuSPJHk8Kp69uq1BAB9mkfQ7z9Mb1+6oLW2PsmdmZw7sM9qNgUAPZrVyXjbYvdh+vAKyzfO32NLHzRcRreczZ4MCACLYi3eMKeGaX8X+APAKpvHEf3GI/bdV1i+25L1VrTSTQTcMAcAJuZxRH/bMN1v6YKqWpfkhUnWJ7ljNZsCgB7NI+ivHKavXmbZkUl2SXJda+3p1WsJAPo0j6C/NMn9SV5fVS/ZOLOqdk7y+8Pbj86hLwDozkx+o6+qE5OcOLzda5i+oqouGv58f2vtnUnSWnukqt6cSeBfXVWXZHIL3NdmcundpZncFhcAGGlWJ+MdnOS0JfP2yQ+uhf+HJO/cuKC19tmqOirJbyc5OcnOSb6d5DeSnLeVd9gDALZgJkHfWjs7ydnbWPPXSf7rLMYHAJY3j8vrYM155JFH5jb2ww9v8UrS7ebNb37z1LWf/vS4X9g2bNgwqh7YOmvxhjkAwIwIegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDoWLXW5t3DzFXVjUkOmXcfsDV23XXXqWs///nPjxr7qKOOmrr2hBNOGDX2FVdcMaoeFsRNrbVDx3yAI3oA6JigB4COCXoA6JigB4COCXo
A6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6Jjn0cMObN999x1Vf9NNN01d+9BDD40a+6qrrhpVf8MNN0xd+5GPfGTU2D3+vcma5Xn0AMDKBD0AdEzQA0DHBD0AdEzQA0DHBD0AdEzQA0DHBD0AdEzQA0DHBD0AdEzQA0DHBD0AdEzQA0DHBD0AdMxjamGBnXTSSVPXXnjhhaPGfs5znjOqfox3vetdo+o/8YlPTF177733jhqbheMxtQDAygQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxzyPHpjKi170olH173//+0fVH3fccaPqx7jgggumrj3nnHNGjf2P//iPo+rZ4ayN59FX1SlV9eGquraqHqmqVlWfXGHdvYflK70umUVPAECybkaf8+4kL07yWJJ7khywFTVfS/LZZeZ/fUY9AcDCm1XQvz2TgP92kqOSXLUVNTe31s6e0fgAwDJmEvSttX8N9qqaxUcCADMwqyP6afxYVf1qkucleSDJV1trt8yxHwDozjyD/vjh9a+q6uokp7XW7t6aDxjOrl/O1pwjAADdm8d19E8k+b0khyZ57vDa+Lv+0Um+VFW7zqEvAOjOqh/Rt9buS/I7S2ZfU1WvTPKVJIcleVOSD23FZy17baHr6AFgYs3cGa+1tj7Jx4e3R86zFwDoxZoJ+sF3h6mv7gFgBtZa0L98mN4x1y4AoBOrHvRVdVhVPWuZ+cdmcuOdJFn29rkAwLaZycl4VXVikhOHt3sN01dU1UXDn+9vrb1z+PP/SnLgcCndPcO8g5IcO/z5Pa2162bRFwAsulmddX9wktOWzNtneCXJPyTZGPQXJzkpyUuTnJDkh5P8c5I/S3J+a+3aGfUEAAtvVrfAPTvJ2Vu57p8k+ZNZjAsAbJ7n0QNzsccee4yqf81rXjN17YUXXjhq7DHP9LjyyitHjX388cdveSV6sjaeRw8ArE2CHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA65jG1wMJ5+umnR9WvW7du6tr169ePGvtVr3rV1LVXX331qLGZC4+pBQBWJugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6Nv1DlYGFdtBBB42qP+WUU0bVv/SlL526dszz5Me69dZbR9Vfc801M+qEReGIHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGMeUws7sP33339U/Vve8papa1/3uteNGnuvvfYaVT9PzzzzzNS1995776ixN2zYMKqexeOIHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA65nn0MNLY56q/4Q1vmLp2zPPkk2TvvfceVb+juuGGG0bVn3POOVPX/sVf/MWosWFbjT6ir6rnVdWbqurPq+rbVfVkVT1cVV+pqtOratkxqurwqrqsqh6sqieq6paqOrOqdhrbEwAwMYsj+lOTfDTJvUmuSnJ3kucneV2Sjyc5oapOba21jQVV9QtJPpPkqSSfTvJgktck+UCSI4bPBABGmkXQ357ktUn+srW2YePMqnpXkr9NcnImof+ZYf5uSf44yTNJjm6t3TDMf0+SK5OcUlWvb61dMoPeAGChjf7qvrV2ZWvt85uG/DD/O0k+Nrw9epNFpyT5kSSXbAz5Yf2nkrx7ePvrY/sCALb/Wff/MkzXbzLv2GH6xWXWvybJE0kOr6pnb8/GAGARbLez7qtqXZJfGt5uGur7D9Pbl9a01tZX1Z1JDkyyT5JvbGG
MG1dYdMC2dQsAfdqeR/TnJnlRkstaa5dvMn/3YfrwCnUb5++xvRoDgEWxXY7oq+qtSd6R5JtJfnFby4dp2+xaSVprh64w/o1JDtnGcQGgOzM/oq+qM5J8KMmtSY5prT24ZJWNR+y7Z3m7LVkPAJjSTIO+qs5Mcn6Sr2cS8t9ZZrXbhul+y9SvS/LCTE7eu2OWvQHAIppZ0FfVb2Zyw5ubMwn5+1ZY9cph+upllh2ZZJck17XWnp5VbwCwqGYS9MPNbs5NcmOS41pr929m9UuT3J/k9VX1kk0+Y+ckvz+8/egs+gKARTf6ZLyqOi3J72Zyp7trk7y1qpaudldr7aIkaa09UlVvziTwr66qSzK5Be5rM7n07tJMbosLAIw0i7PuXzhMd0py5grrfDnJRRvftNY+W1VHJfntTG6Ru3OSbyf5jSTnbXpffABgetVjprq8bvE8//nPH1X/0z/901PXnn/++aPGPuCAxby/0/XXXz+q/g//8A+nrv3c5z43auwNGzZseSWYjZtWupR8a23vW+ACAHMk6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADq2bt4N0I8999xzVP0FF1wwde3BBx88aux99tlnVP2O6rrrrpu69n3ve9+osS+//PJR9U8++eSoelgUjugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA65jG1nTnssMNG1Z911llT177sZS8bNfaP//iPj6rfUT3xxBNT15533nmjxv6DP/iDqWsff/zxUWMDq8MRPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0zPPoO3PSSSfNtX5ebr311lH1X/jCF6auXb9+/aix3/e+901d+9BDD40aG+ifI3oA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COVWtt3j3MXFXdmOSQefcBACPd1Fo7dMwHjD6ir6rnVdWbqurPq+rbVfVkVT1cVV+pqtOr6oeWrL93VbXNvC4Z2xMAMLFuBp9xapKPJrk3yVVJ7k7y/CSvS/LxJCdU1ant33918LUkn13m874+g54AgMwm6G9P8tokf9la27BxZlW9K8nfJjk5k9D/zJK6m1trZ89gfABgBaO/um+tXdla+/ymIT/M/06Sjw1vjx47DgCw7WZxRL85/zJM1y+z7Meq6leTPC/JA0m+2lq7ZTv3AwALZbsFfVWtS/JLw9svLrPK8cNr05qrk5zWWrt7e/UFAItkex7Rn5vkRUkua61dvsn8J5L8XiYn4t0xzDsoydlJjknypao6uLX2+JYGGC6jW84B0zYNAD3ZLtfRV9Vbk3woyTeTHNFae3AratYl+UqSw5Kc2Vr70FbUbC7od9n6jgFgTRp9Hf3Mj+ir6oxMQv7WJMdtTcgnSWttfVV9PJOgP3L4jC3VLPsf74Y5ADAx01vgVtWZSc7P5Fr4Y4Yz77fFd4fprrPsCwAW1cyCvqp+M8kHktycScjfN8XHvHyY3rHZtQCArTKToK+q92Ry8t2NmXxdf/9m1j2sqp61zPxjk7x9ePvJWfQFAItu9G/0VXVakt9N8kySa5O8taqWrnZXa+2i4c//K8mBw6V09wzzDkpy7PDn97TWrhvbFwAwm5PxXjhMd0py5grrfDnJRcOfL05yUpKXJjkhyQ8n+eckf5bk/NbatTPoCQCIx9QCwFo2/8fUAgBrl6AHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKA
HgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI71GvR7z7sBAJiBvcd+wLoZNLEWPTJM71ph+QHD9Jvbv5Vu2GbTsd2mY7ttO9tsOmt5u+2dH+TZ1Kq1Nr6VHUxV3ZgkrbVD593LjsI2m47tNh3bbdvZZtNZhO3W61f3AEAEPQB0TdADQMcEPQB0TNADQMcW8qx7AFgUjugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGMLFfRV9RNV9adV9U9V9XRV3VVVH6yq5867t7Vq2EZthdd35t3fvFTVKVX14aq6tqoeGbbHJ7dQc3hVXVZVD1bVE1V1S1WdWVU7rVbf87Yt262q9t7Mvteq6pLV7n8equp5VfWmqvrzqvp2VT1ZVQ9X1Veq6vSqWvbv8UXf37Z1u/W8v/X6PPp/p6r2TXJdkh9N8rlMnj38siRvS/LqqjqitfbAHFtcyx5O8sFl5j+22o2sIe9O8uJMtsE9+cEzrZdVVb+Q5DNJnkry6SQPJnlNkg8kOSLJqduz2TVkm7bb4GtJPrvM/K/PsK+17NQkH01yb5Krktyd5PlJXpfk40lOqKpT2yZ3P7O/JZliuw36299aawvxSnJ5kpbkfyyZ//5h/sfm3eNafCW5K8ld8+5jrb2SHJPkp5JUkqOHfeiTK6y7W5L7kjyd5CWbzN85k398tiSvn/d/0xrcbnsPyy+ad99z3mbHZhLSP7Rk/l6ZhFdLcvIm8+1v0223bve3hfjqvqr2SfLKTELrI0sW/88kjyf5xaradZVbYwfVWruqtfatNvwNsQWnJPmRJJe01m7Y5DOeyuQIN0l+fTu0ueZs43YjSWvtytba51trG5bM/06Sjw1vj95kkf0tU223bi3KV/fHDtMrlvmf/mhV/XUm/xB4eZIvrXZzO4BnV9Ubk7wgk38U3ZLkmtbaM/Nta4excf/74jLLrknyRJLDq+rZrbWnV6+tHcaPVdWvJnlekgeSfLW1dsuce1or/mWYrt9knv1ty5bbbht1t78tStDvP0xvX2H5tzIJ+v0i6JezV5KLl8y7s6p+ubX25Xk0tINZcf9rra2vqjuTHJhknyTfWM3GdhDHD69/VVVXJzmttXb3XDpaA6pqXZJfGt5uGur2t83YzHbbqLv9bSG+uk+y+zB9eIXlG+fvsQq97GguTHJcJmG/a5KfSXJBJr9n/VVVvXh+re0w7H/TeSLJ7yU5NMlzh9dRmZxYdXSSLy34z23nJnlRkstaa5dvMt/+tnkrbbdu97dFCfotqWHqd8MlWmvvHX7r+ufW2hOtta+31n4tk5MY/0OSs+fbYRfsf8tord3XWvud1tpNrbWHhtc1mXz7dn2S/5TkTfPtcj6q6q1J3pHJ1UO/uK3lw3Th9rfNbbee97dFCfqN/4LdfYXluy1Zjy3beDLLkXPtYsdg/5uh1tr6TC6PShZw/6uqM5J8KMmtSY5prT24ZBX72zK2Yrstq4f9bVGC/rZhut8Ky39qmK70Gz7/3n3DdIf8KmuVrbj/Db8XvjCTk4LuWM2mdnDfHaYLtf9V1ZlJzs/kmu5jhjPIl7K/LbGV221zduj9bVGC/qph+spl7ob0nExuIPFkkr9Z7cZ2YK8Ypgvzl8UIVw7TVy+z7MgkuyS5boHPgJ7Gy4fpwux/VfWbmdzw5uZMwuq+FVa1v21iG7bb5uzQ+9tCBH1r7e+TXJHJCWRnLFn83kz+lfaJ1trjq9zamlZVB1bVnsvM/8lM/nWcJJu97StJkkuT3J/k9VX1ko0zq2rnJL8/vP3oPBpby6rqsKp61jLzj03y9uHtQux/VfWeTE4iuzHJca21+zezuv1tsC3bref9rRblvhXL3AL3G0kOy+ROXbcnOby5Be6/UVVnJ/mtTL4RuTPJo0n2TfJzmdxl67IkJ7XWvj+vHuelqk5McuLwdq8kr8rkX/vXDvPub629c8n6l2ZyS9JLMrkl6WszuRTq0iT/bRF
uIrMt2224pOnAJFdncrvcJDkoP7hO/D2ttY3B1a2qOi3JRUmeSfLhLP/b+l2ttYs2qVn4/W1bt1vX+9u8b823mq8k/zGTy8XuTfL9JP+QyckZe867t7X4yuTSkk9lcobqQ5ncZOK7Sf5PJteh1rx7nOO2OTuTs5ZXet21TM0Rmfzj6HuZ/FT0fzM5Uthp3v89a3G7JTk9yRcyuaPlY5nc0vXuTO7d/l/m/d+yhrZZS3K1/W3cdut5f1uYI3oAWEQL8Rs9ACwqQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANCx/w/nY//ADdkdRAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 250, "width": 253 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(x_train[0].view(28,28))\n", "y_train[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The one-hot encoded label, $x$, for this image is:" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.])" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "label = tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0]).float()\n", "label" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we just saw above, the log softmax predictions, $\\log{\\left(p(x)\\right)}$, are:" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-2.2406, -2.3642, -2.4107, -2.3722, -2.2981, -2.1470, -2.2436, -2.2985,\n", " -2.3489, -2.3297])" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "prediction = logsoftmax_pred[0].detach()\n", "prediction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "According to the formula for cross entropy loss, here's how we calculate the loss for the model's prediction for this image:" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(2.1470)" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "-(label * prediction).sum()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now you've probably already noticed that our training labels *aren't* actually one-hot encoded:" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([5, 0, 4, ..., 8, 4, 8])" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train" ] }, { "cell_type": "code", 
"execution_count": 62, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([50000])" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Indeed, each image's label is just a single integer: the digit depicted in that image. Because MNIST's ten classes are exactly the digits 0 through 9, these integers do double duty. Besides naming the digit, each one is also the *index* of the one in that image's one-hot encoded label. For example, our first training image is a '5', and so there is a one at index 5 in its one-hot encoded label:\n", "\n", "```\n", "[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]\n", "```\n", "\n", "When we summed our products in our first attempt at calculating cross entropy loss for the first training image above, you probably noticed that nine of the ten products were *zero*. At this point it should be clear that we might as well not waste time computing products that are going to be zero anyhow.\n", "\n", "Indeed, why don't we just calculate the one product that *we know* will return a non-zero value, and not bother with the nine other products, nor with adding a bunch of zeros to the only non-zero product?\n", "\n", "We can use [numpy-style integer array indexing](https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html) to accomplish this. 
It's as simple as indexing into the prediction list for a single image and grabbing the log softmax value that sits at the index we know represents the ground truth category of the image.\n", "\n", "For example, once again, here's the ground truth label of the first training sample:" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(5)" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "correct_label_index = y_train[0]\n", "correct_label_index" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And here's the log softmax value at index 5 of our model's prediction for this training image. This is the index that corresponds to the correct ground truth label (the number 5, out of the numbers 0 through 9, inclusive)." ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(-2.1470, grad_fn=)" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logsoftmax_pred[0][correct_label_index]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we take this approach, we can rewrite cross-entropy loss in a much simpler way. It's just $$-\\log{\\left(p_{i}(x)\\right)}$$ where $i$ is the index of the target's ground truth class.\n", "\n", "Using this formula, here's the categorical cross entropy loss for the first training image:" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(2.1470, grad_fn=)" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "-logsoftmax_pred[0][correct_label_index]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "How wonderful is that! It's the exact same value we got when we summed the ten products of the log softmax prediction and its one-hot encoded label, but without all the needless arithmetic. 
When you're training for hundreds of epochs over potentially millions of images, any optimizations that speed things up on a per-image basis can substantially decrease overall training time!\n", "\n", "Note that it's also easy to do this for several images at once. Here's how we'd index into our predictions to find the log softmax values at the appropriate indices for the first three training images:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([5, 0, 4])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train[:3]" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-2.3642, -2.2576, -2.2754], grad_fn=)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logsoftmax_pred[[0,1,2], [5,0,4]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The indexing above grabs the first three predictions (each a list of ten log softmax values) from our tensor of predictions. Then, within each of these three lists, it grabs the log softmax value sitting at the index that corresponds to the image's ground truth category. The first training image is a '5' digit, so we want the log softmax value at index 5. The second training image is a '0' digit, so we want the log softmax value at index 0 of its prediction list. And so on.\n", "\n", "#### Negative log-likelihood\n", "Now that we've obtained the log softmax of our predictions, we're ready to compute the actual cross-entropy loss.\n", "\n", "The negative log-likelihood function is how we do that. For each image, it picks out the log softmax value at the ground truth index and negates it. It then averages these values over all the images."
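, "\n", "\n", "Before we define it, here's a quick sanity check. It's a small, self-contained sketch (not from the original notebook; the logits and labels are made up) showing that, over a whole batch, the indexing shortcut gives the same loss as the one-hot formula. It relies on `F.one_hot`, PyTorch's built-in one-hot encoder:\n", "\n", "```python\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "torch.manual_seed(0)\n", "logits = torch.randn(4, 10)           # hypothetical activations: 4 images, 10 classes\n", "targets = torch.tensor([5, 0, 4, 1])  # hypothetical integer labels\n", "logp = logits.log_softmax(-1)         # log softmax over the class dimension\n", "\n", "# One-hot route: -sum_i x_i * log(p_i(x)) for each image, averaged over the batch\n", "one_hot = F.one_hot(targets, num_classes=10).float()\n", "loss_one_hot = -(one_hot * logp).sum(-1).mean()\n", "\n", "# Indexing route: grab log(p) at each row's target index, negate, and average\n", "rows = torch.arange(targets.shape[0])\n", "loss_indexed = -logp[rows, targets].mean()\n", "\n", "print(loss_one_hot, loss_indexed)\n", "print(torch.allclose(loss_one_hot, loss_indexed))  # expected: True\n", "```\n", "\n", "Both routes read the same log probabilities. The only difference is that the one-hot route multiplies nine of them by zero and then adds those zeros up."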
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "def nll(input, target): return -input[range(target.shape[0]), target].mean()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(2.3182, grad_fn=)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss = nll(logsoftmax_pred, y_train)\n", "loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before we proceed further, let's pause and remember a property of logarithms. It lets us rewrite the log softmax function in a more computationally efficient form that gets rid of the division operation: $$\\log{\\left(\\frac{a}{b}\\right)} = \\log{\\left(a\\right)} - \\log{\\left(b\\right)}$$\n", "\n", "Since $\\log{\\left(e^{x_{i}}\\right)} = x_{i}$, applying this property to our softmax formula gives $$\\log{\\left(\\textrm{softmax}(x)_{i}\\right)} = x_{i} - \\log{\\left(\\sum_{0\\leq{j}\\leq{n-1}}e^{x_{j}}\\right)}$$\n", "\n", "Let's write a new version of `log_softmax()` that takes advantage of this." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def log_softmax(x): return x - x.exp().sum(-1,keepdim=True).log()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's ensure that this refactoring is computationally accurate:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "test_near(nll(log_softmax(pred), y_train), loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We're almost done, but there's one more really helpful tweak that we should build in. 
\n", "\n", "The [LogSumExp trick](https://en.wikipedia.org/wiki/LogSumExp) lets us use the following formula to compute the log of a sum of exponentials in a more stable manner: $$\\log{\\left(\\sum_{j=0}^{n-1}e^{x_{j}}\\right)} = \\log{\\left(e^{a}\\sum_{j=0}^{n-1}e^{x_{j}-a}\\right)} = a + \\log{\\left(\\sum_{j=0}^{n-1}e^{x_{j}-a}\\right)}$$ where $a$ is the maximum of all $x_{j}$.\n", "\n", "To calculate log softmax, we have to take the log of a sum of exponential terms (as evidenced by the \n", "```\n", "x.exp().sum(-1,keepdim=True)\n", "``` \n", "portion of the above `log_softmax()` function). A revised version of our `log_softmax()` function that uses this trick will therefore avoid an overflow when it takes the exponential of a big activation. After subtracting $a$, every exponent $x_{j}-a$ is at most zero, so each term $e^{x_{j}-a}$ is at most 1 and can't overflow." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "def logsumexp(x):\n", "    m = x.max(-1)[0] # max along the last dimension (one value per row)\n", "    return m + (x - m[:,None]).exp().sum(-1).log()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch has `logsumexp` as a built-in method, so let's compare it to ours:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "test_near(logsumexp(pred), # ours\n", " pred.logsumexp(-1) # PyTorch's\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's our final refactored `log_softmax()` with `logsumexp`:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "def log_softmax(x): return x - x.logsumexp(-1,keepdim=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Verify that our latest log_softmax refactoring is still correct:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "test_near(nll(log_softmax(pred), y_train), loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Verify the same for PyTorch's own functions:" ] }, { 
"cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "test_near(F.nll_loss(F.log_softmax(pred, -1), y_train), loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch combines `F.log_softmax` and `F.nll_loss` into one optimized function called `F.cross_entropy`." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "test_near(F.cross_entropy(pred, y_train), loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Creating a Basic Training Loop\n", "\n", "To complete one iteration of the training loop, our model must be able to perform the following:\n", "1. Get the output of the model on **a batch** of inputs.\n", "2. Compare the output to the labels and compute a loss.\n", "3. Calculate the gradients of the loss with respect to every parameter in the model.\n", "4. Update model parameters with their gradients in order to make the parameters a little bit better.\n", "\n", "Below we implement each of these steps in successive lines of code. Further down in this notebook we will see how to refactor this code into specific classes that manage tasks like storing a dataset, loading the data, etc."
] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "loss_func = F.cross_entropy" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "#export\n", "def accuracy(out, yb): return (torch.argmax(out, dim=1)==yb).float().mean()" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([ 0.0942, 0.0491, -0.0652, -0.1084, 0.0840, -0.0479, -0.0238, 0.0033,\n", " 0.0379, 0.0908], grad_fn=), torch.Size([64, 10]))" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bs = 64 # batch size\n", " \n", "xb = x_train[0:bs] # a mini-batch from inputs x \n", "preds = model(xb) # predictions on items in the mini-batch\n", "preds[0], preds.shape" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(2.3140, grad_fn=)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "yb = y_train[0:bs]\n", "loss_func(preds, yb)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.1406)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy(preds, yb)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "lr = 0.5 # learning rate\n", "epochs = 1 # number of epochs to train for" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "for epoch in range(epochs):\n", " for i in range((n-1)//bs + 1):\n", " start_i = i*bs\n", " end_i = start_i+bs\n", " xb = x_train[start_i:end_i]\n", " yb = y_train[start_i:end_i]\n", " loss = loss_func(model(xb), yb)\n", " \n", " loss.backward()\n", " with torch.no_grad():\n", " for l in model.layers:\n", " if hasattr(l, 'weight'):\n", " l.weight -= l.weight.grad * lr\n", " l.bias -= l.bias.grad 
* lr\n", "                    l.weight.grad.zero_()\n", "                    l.bias.grad.zero_()" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.2697, grad_fn=), tensor(0.9375))" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss_func(model(xb), yb), accuracy(model(xb), yb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Updating model.parameters\n", "In the training loop that we wrote above, we manually updated each layer's weights and biases and then zeroed out their gradients. Instead, we can store the layers as attributes of our model class (`self.l1` and `self.l2`). PyTorch then keeps track of their weights and biases for us, so after each backward pass we can update all of the model's trainable parameters by looping over `model.parameters()`." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "class Model(nn.Module):\n", "    def __init__(self, n_in, nh, n_out):\n", "        super().__init__()\n", "        self.l1 = nn.Linear(n_in, nh)\n", "        self.l2 = nn.Linear(nh, n_out)\n", "    \n", "    def __call__(self, x): return self.l2(F.relu(self.l1(x)))" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "model = Model(m,nh,10)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "l1: Linear(in_features=784, out_features=50, bias=True)\n", "l2: Linear(in_features=50, out_features=10, bias=True)\n" ] } ], "source": [ "for name, l in model.named_children(): print(f'{name}: {l}')" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Model(\n", " (l1): Linear(in_features=784, out_features=50, bias=True)\n", " (l2): Linear(in_features=50, out_features=10, bias=True)\n", ")" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { 
"data": { "text/plain": [ "Linear(in_features=784, out_features=50, bias=True)" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.l1" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "def fit():\n", " for epoch in range(epochs):\n", " for i in range((n-1)//bs + 1):\n", " start_i = i*bs\n", " end_i = start_i+bs\n", " xb = x_train[start_i:end_i]\n", " yb = y_train[start_i:end_i]\n", " loss = loss_func(model(xb), yb)\n", " \n", " loss.backward()\n", " with torch.no_grad():\n", " for p in model.parameters(): p -= p.grad*lr\n", " model.zero_grad()" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.1160, grad_fn=), tensor(0.9375))" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fit()\n", "loss_func(model(xb), yb), accuracy(model(xb), yb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "How does PyTorch know what the model's parameters are? 
`nn.Module` overrides the `__setattr__` method, so that whenever a submodule (such as `self.l1` or `self.l2`) is assigned as an attribute in the model's class, it gets registered, and its weights and biases become model parameters.\n", "\n", "Here's a dummy module that mimics what's going on inside `nn.Module`:" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "class DummyModule():\n", "    def __init__(self, n_in, nh, n_out):\n", "        self._modules = {}\n", "        self.l1 = nn.Linear(n_in,nh)\n", "        self.l2 = nn.Linear(nh,n_out)\n", "    \n", "    def __setattr__(self,k,v):\n", "        if not k.startswith(\"_\"): self._modules[k] = v\n", "        super().__setattr__(k,v)\n", "    \n", "    def __repr__(self): return f'{self._modules}'\n", "    \n", "    def parameters(self):\n", "        for l in self._modules.values():\n", "            for p in l.parameters(): yield p" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'l1': Linear(in_features=784, out_features=50, bias=True), 'l2': Linear(in_features=50, out_features=10, bias=True)}" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dummy_mdl = DummyModule(m,nh,10)\n", "dummy_mdl" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[torch.Size([50, 784]),\n", " torch.Size([50]),\n", " torch.Size([10, 50]),\n", " torch.Size([10])]" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[o.shape for o in dummy_mdl.parameters()]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Registering Modules\n", "\n", "For deeper models, declaring a separate `self.` attribute for each and every layer quickly becomes a hassle. It's more convenient to pass in a list that contains all the layers, e.g. 
something like\n", "```\n", "layers = [nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10)]\n", "self.layers = layers\n", "```\n", "\n", "However, if we store the layers this way, we have to register them manually, because `nn.Module` does not automatically register modules held inside a plain Python list." ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "layers = [nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10)]" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "class Model(nn.Module):\n", "    def __init__(self, layers):\n", "        super().__init__()\n", "        self.layers = layers\n", "        for i,l in enumerate(self.layers): self.add_module(f'layer_{i}', l)\n", "    \n", "    def __call__(self,x):\n", "        for l in self.layers: x = l(x)\n", "        return x" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "model = Model(layers)" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Model(\n", " (layer_0): Linear(in_features=784, out_features=50, bias=True)\n", " (layer_1): ReLU()\n", " (layer_2): Linear(in_features=50, out_features=10, bias=True)\n", ")" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### nn.ModuleList and nn.Sequential\n", "\n", "Thankfully, both the `nn.ModuleList` and `nn.Sequential` classes can do this registration for us.\n", "\n", "`nn.Sequential` registers its layers and also calls them in order. `nn.ModuleList`, on the other hand, is simply a list-like container for layers. 
This object automatically registers all layers.\n", "\n", "Here's a home-grown clone of `nn.Sequential` that depicts how `nn.ModuleList` is used:" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "class SequentialModel(nn.Module):\n", " def __init__(self, layers):\n", " super().__init__()\n", " self.layers = nn.ModuleList(layers)\n", " \n", " def __call__(self, x):\n", " for l in self.layers: x = l(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "model = SequentialModel(layers)" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SequentialModel(\n", " (layers): ModuleList(\n", " (0): Linear(in_features=784, out_features=50, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=50, out_features=10, bias=True)\n", " )\n", ")" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.0609, grad_fn=), tensor(1.))" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fit()\n", "loss_func(model(xb), yb), accuracy(model(xb), yb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since `nn.Sequential` already does all of the above on its own, we can just use it going forward:" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.2535, grad_fn=), tensor(0.9375))" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fit()\n", "loss_func(model(xb), yb), accuracy(model(xb), yb)" ] }, { "cell_type": "code", "execution_count": 52, 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "Sequential(\n", " (0): Linear(in_features=784, out_features=50, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=50, out_features=10, bias=True)\n", ")" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.0282, grad_fn=), tensor(1.))" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fit()\n", "loss_func(model(xb), yb), accuracy(model(xb), yb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Refactoring the Optimizer\n", "\n", "In our training loops above, we manually coded the optimization step:\n", "```\n", "with torch.no_grad():\n", "    for p in model.parameters(): p -= p.grad*lr\n", "    model.zero_grad()\n", "```\n", "\n", "We can refactor this logic into our own `Optimizer` class, so that our training loop only needs to call:\n", "```\n", "opt.step()\n", "opt.zero_grad()\n", "```" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "class Optimizer():\n", "    def __init__(self, params, lr=0.5):\n", "        self.params, self.lr = list(params), lr\n", "    \n", "    def step(self):\n", "        with torch.no_grad():\n", "            # use the optimizer's own learning rate, not the global lr\n", "            for p in self.params: p -= p.grad*self.lr\n", "    \n", "    def zero_grad(self):\n", "        for p in self.params: p.grad.data.zero_()" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "opt = Optimizer(model.parameters())" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "for epoch in range(epochs):\n", "    for i in range((n-1)//bs + 1):\n", "        start_i = i*bs\n", "        end_i = start_i+bs\n",
"        xb = x_train[start_i:end_i]\n", "        yb = y_train[start_i:end_i]\n", "        pred = model(xb)\n", "        loss = loss_func(pred,yb)\n", "        \n", "        loss.backward()\n", "        opt.step()\n", "        opt.zero_grad()" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.1486, grad_fn=), tensor(0.9375))" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss, acc = loss_func(model(xb), yb), accuracy(model(xb), yb)\n", "loss, acc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With its default settings, PyTorch's own `optim.SGD` class performs the same update as our home-grown `Optimizer` class. Beyond that, `optim.SGD` also supports options such as momentum and weight decay." ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "#export\n", "from torch import optim" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "def get_model():\n", "    model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))\n", "    return model, optim.SGD(model.parameters(), lr=lr)" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(2.3366, grad_fn=)" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model, opt = get_model()\n", "loss_func(model(xb), yb)" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [], "source": [ "for epoch in range(epochs):\n", "    for i in range((n-1)//bs + 1):\n", "        start_i = i*bs\n", "        end_i = start_i+bs\n", "        xb = x_train[start_i:end_i]\n", "        yb = y_train[start_i:end_i]\n", "        pred = model(xb)\n", "        loss = loss_func(pred, yb)\n", "        \n", "        loss.backward()\n", "        opt.step()\n", "        opt.zero_grad()" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.1586, grad_fn=), tensor(0.9375))" ] }, "execution_count": 63, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "loss, acc = loss_func(model(xb), yb), accuracy(model(xb), yb)\n", "loss, acc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Don't be afraid to include tests whose outcome depends on randomness, such as the one right below. There may be times when accuracy dips below `0.7` purely by chance, but on the whole, scattering checks like this throughout your code does much more good than harm." ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [], "source": [ "assert acc>0.7" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dataset\n", "\n", "In our early crack at coding up a training loop, we sliced out mini-batches of `x` and `y` values separately:\n", "```\n", "xb = x_train[start_i:end_i]\n", "yb = y_train[start_i:end_i]\n", "```\n", "If, however, we create a `Dataset` class to hold our inputs and labels together, a single indexing operation returns both:" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Dataset():\n", "    def __init__(self, x, y): self.x, self.y = x,y\n", "    def __len__(self): return len(self.x)\n", "    def __getitem__(self, i): return self.x[i], self.y[i]" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [], "source": [ "train_ds, valid_ds = Dataset(x_train, y_train), Dataset(x_valid, y_valid)\n", "assert len(train_ds)==len(x_train)\n", "assert len(valid_ds)==len(x_valid)" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]]), tensor([5, 0, 4, 1, 9]))" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xb, yb = train_ds[0:5]\n", "assert xb.shape==(5,28*28)\n", "assert yb.shape==(5,)\n", "xb, yb" ] 
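}, { "cell_type": "markdown", "metadata": {}, "source": [ "For comparison, PyTorch provides a similar built-in class, `torch.utils.data.TensorDataset`, which also returns a tuple of tensors when indexed and accepts slices. The cell below is a minimal, self-contained sketch that checks a copy of our `Dataset` against it. It uses random stand-in tensors instead of MNIST, and the names `x_demo`, `y_demo`, `ours` and `theirs` are hypothetical, for illustration only." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.utils.data import TensorDataset\n", "\n", "# Copy of the Dataset class defined above, repeated so this cell runs on its own\n", "class Dataset():\n", "    def __init__(self, x, y): self.x, self.y = x,y\n", "    def __len__(self): return len(self.x)\n", "    def __getitem__(self, i): return self.x[i], self.y[i]\n", "\n", "# Hypothetical stand-in data: 10 random 784-pixel 'images' and 10 integer labels\n", "x_demo, y_demo = torch.randn(10, 784), torch.randint(0, 10, (10,))\n", "ours, theirs = Dataset(x_demo, y_demo), TensorDataset(x_demo, y_demo)\n", "\n", "# Slicing either dataset returns the same (inputs, labels) pair of tensors\n", "xb1, yb1 = ours[2:5]\n", "xb2, yb2 = theirs[2:5]\n", "assert torch.equal(xb1, xb2) and torch.equal(yb1, yb2)\n", "assert len(ours) == len(theirs) == 10" ] 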
}, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "model, opt = get_model()" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [], "source": [ "for epoch in range(epochs):\n", " for i in range((n-1)//bs + 1):\n", " xb,yb = train_ds[i*bs: i*bs+bs]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", " \n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.2065, grad_fn=), tensor(0.9375))" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss,acc = loss_func(model(xb), yb), accuracy(model(xb), yb)\n", "assert acc>0.7\n", "loss, acc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### DataLoader\n", "\n", "Our first crack at coding a training loop explicitly iterated over batches using a for-loop that kept track of specific indices. \n", "```\n", "for i in range((n-1)//bs + 1):\n", " xb,yb = train_ds[i*bs: i*bs+bs]\n", "```\n", "\n", "Creating a `DataLoader` class will allow for a more concise implementation thanks to the inclusion of a generator that automatically yields the next batch as soon as needed." 
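] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the generator idea concrete, here is a minimal plain-Python sketch. The helper `batch_bounds` is hypothetical and not part of the original notebook. Calling it computes nothing up front. Each time the loop asks for the next item, the function resumes just long enough to `yield` the start and end indices of the next mini-batch. This is the same pattern the `DataLoader` class below uses to yield slices of the dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def batch_bounds(n, bs):\n", "    # Lazily yield (start, end) index pairs covering n items in chunks of size bs\n", "    for i in range(0, n, bs): yield i, min(i+bs, n)\n", "\n", "gen = batch_bounds(10, 3)  # creating the generator does not run its body yet\n", "print(next(gen))            # (0, 3)\n", "print(list(gen))            # [(3, 6), (6, 9), (9, 10)]\n", "\n", "assert list(batch_bounds(10, 3)) == [(0, 3), (3, 6), (6, 9), (9, 10)]" 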
] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "class DataLoader():\n", " def __init__(self,ds,bs): self.ds, self.bs = ds,bs\n", " def __iter__(self):\n", " for i in range(0, len(self.ds), self.bs): yield self.ds[i:i+self.bs]" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [], "source": [ "train_dl = DataLoader(train_ds, bs)\n", "valid_dl = DataLoader(valid_ds, bs)" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [], "source": [ "xb,yb = next(iter(valid_dl))\n", "assert xb.shape==(bs,28*28)\n", "assert yb.shape==(bs,)" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(3)" ] }, "execution_count": 74, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfoAAAH0CAYAAADVH+85AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAHGNJREFUeJzt3X2sbWV9J/DvTy+VF+VF+kJoUeRSJKHlCtgiUBGuqYNtqqgw8Q9bYrGpjhkLgmlDsXPVTkKTyYBiB5pae1tJhraQ2nREYSIgb3aaYi1DKqICRSuIwMiLgC36zB973Xp7POfee/be96xznvP5JDvP2WutZ63fXazw3c/e66VaawEA+vScsQsAAHYfQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHdswdgG7Q1Xdm2TfJPeNXAoATOvQJI+31l4yy0q6DPpMQv6FwwsA1q1Rv7qvqp+oqo9W1der6jtVdV9VXVJVB8y46vvmUR8AjOy+WVcw2oi+qjYmuS3Jjyb5qyR3JfnZJL+R5LSqOqm19shY9QFAD8Yc0f+PTEL+Xa2101trv9Va25zk4iQvTfJfR6wNALpQrbWV32jVYUm+kslXEhtba9/bbt4LkjyQpJL8aGvt21Os//Ykx86nWgAYzedaa8fNsoKxRvSbh/a67UM+SVprTyS5NcneSV6x0oUBQE/G+o3+pUN79xLzv5TkNUmOSPLppVYyjNwXc+T0pQFAP8Ya0e83tI8tMX/b9P1XoBYA6NZqvY6+hnaHJxAs9buF3+gBYGKsEf22Eft+S8zfd8FyAMAUxgr6Lw7tEUvM/8mhXeo3fABgF4wV9DcM7Wuq6t/VMFxed1KSp5P8zUoXBgA9GSXoW2tfSXJdJjfsf+eC2e9Lsk+SP53mGnoA4PvGPBnvP2VyC9wPVdWrk3whyfFJTs3kK/vfHrE2AOjCaLfAHUb1L0+yNZOAPy/JxiQfSnKC+9wDwOxGvbyutfbVJG8dswYA6Nmoj6kFAHYvQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0AN
AxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHdswdgGw1m3atGmm/ueee+7UfTdu3DjTtvfee++p+15wwQUzbXu//fabqf8nP/nJqfs+8cQTM20b1hIjegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDoWLXWxq5h7qrq9iTHjl0Ha8fzn//8qfvef//9M217//33n6n/evXP//zPU/c999xzZ9r2VVddNVN/WIbPtdaOm2UFo43oq+q+qmpLvB4cqy4A6MmGkbf/WJJLFpn+5EoXAgA9Gjvov9Va2zJyDQDQLSfjAUDHxh7RP6+q3pLkRUm+neSOJDe11r47blkA0Iexg/6gJB9bMO3eqnpra+0zO+s8nF2/mCNnrgwAOjDmV/d/nOTVmYT9Pkl+OskfJDk0ySeratN4pQFAH0Yb0bfW3rdg0p1J3l5VTyY5L8mWJG/YyToWvbbQdfQAMLEaT8a7fGhPHrUKAOjAagz6h4Z2n1GrAIAOrMagP2Fo7xm1CgDowChBX1VHVdULF5n+4iQfHt5esbJVAUB/xjoZ78wkv1VVNyS5N8kTSTYm+cUkeya5Jsl/G6k2AOjGWEF/Q5KXJjkmk6/q90nyrSS3ZHJd/cdaj4/VA4AV5jG1kOQFL3jB1H2vueaambb9yCOPTN337//+72fa9jHHHDN13xe/+MUzbfuQQw6Zqf9ee+01dd9vfOMbM237hBNO2PlCu2nbrDtr9zG1AMDuJ+gBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6tmHsAmA1eOKJJ6bu+8pXvnKOlawfP/zDPzxT//e85z2j9E2S0047beq+f/InfzLTtmG5jOgBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA65jG1wCgefvjhmfrfeuutU/ed9TG1xxxzzNR9PaaWlWZEDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAd8zx6YBQHHHDATP0vuOCCOVWyfAcffPBo24blMqIHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomMfUAlPZtGnTTP3/4i/+Yqb+hx9++NR977777pm2fd55583UH1bSXEb0VXVGVV1aVTdX1eNV1arqip30ObGqrqmqR6vqqaq6o6rOqarnzqMmAGB+I/oLk2xK8mSSryU5ckcLV9Xrk1yd5Jkkf5bk0SS/lOTiJCclOXNOdQHAujav3+jPTXJEkn2TvGNHC1bVvkn+MMl3k5zSWju7tfaeJC9L8tkkZ1TVm+dUFwCsa3MJ+tbaDa21L7XW2i4sfkaSH0lyZWvt77ZbxzOZfDOQ7OTDAgCwa8Y4637z0H5qkXk3JXkqyYlV9byVKwkA+jRG0L90aH/gtNfW2rNJ7s3k3IHDVrIoAOjRGJfX7Te0jy0xf9v0/Xe2oqq6fYlZOzwZEADWi9V4w5wa2l35vR8A2IExRvTbRuz7LTF/3wXLLam1dtxi04eR/rHLLw0A+jLGiP6LQ3vEwhlVtSHJS5I8m+SelSwKAHo0RtBfP7SnLTLv5CR7J7mttfadlSsJAPo0RtBfleThJG+uqpdvm1hVeyb53eHtZSPUBQ
Ddmctv9FV1epLTh7cHDe0JVbV1+Pvh1tr5SdJae7yqfi2TwL+xqq7M5Ba4r8vk0rurMrktLgAwo3mdjPeyJGctmHZYvn8t/D8lOX/bjNbax6vqVUl+O8mbkuyZ5MtJ3p3kQ7t4hz0AYCfmEvSttS1Jtiyzz61JfmEe2wcAFud59LCOnXXWwi/idt373//+mbZ9yCGHzNT/6aefnrrvO94x2+M0vvrVr87UH1bSarxhDgAwJ4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADrmMbUwsuc///lT9z3//PNn2vaFF144dd/nPGe2ccKjjz46U/+f+7mfm7rvXXfdNdO2YS0xogeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjnkePYxs69atU/d94xvfOL9Clumqq66aqf8ll1wyU3/PlIddY0QPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMY+phZFt3Lhx7BKmctlll83U/7bbbptTJcCOGNEDQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMc8jx5Gdt11103dd9OmTXOsZHlmqTuZ/Xn2F1100dR9v/71r8+0bVhL5jKir6ozqurSqrq5qh6vqlZVVyyx7KHD/KVeV86jJgBgfiP6C5NsSvJkkq8lOXIX+vxDko8vMv3OOdUEAOvevIL+3EwC/stJXpXkhl3o8/nW2pY5bR8AWMRcgr619m/BXlXzWCUAMAdjnox3cFX9epIDkzyS5LOttTtGrAcAujNm0P/88Po3VXVjkrNaa/fvygqq6vYlZu3KOQIA0L0xrqN/KskHkhyX5IDhte13/VOSfLqq9hmhLgDozoqP6FtrDyX5nQWTb6qq1yS5JcnxSd6W5IO7sK7jFps+jPSPnbFUAFjzVs2d8Vprzyb5yPD25DFrAYBerJqgH3xzaH11DwBzsNqC/hVDe8+oVQBAJ1Y86Kvq+Kr6oUWmb87kxjtJsujtcwGA5ZnLyXhVdXqS04e3Bw3tCVW1dfj74dba+cPfv5fkqOFSuq8N045Osnn4+72ttdvmURcArHfzOuv+ZUnOWjDtsOGVJP+UZFvQfyzJG5L8TJLXJtkjyTeS/HmSD7fWbp5TTQCw7s3rFrhbkmzZxWX/KMkfzWO7AMCOVWtt7BrmznX0rCV77bXX1H2vuGK201mOO27RW1Hskhe96EUzbXtWDz744NR93/rWt8607WuvvXam/rAMn1vqnjG7arWddQ8AzJGgB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COeUwtrGF77rnnTP03bNgwdd/HH398pm2P6Zlnnpmp/7vf/e6p+15++eUzbZt1x2NqAYClCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COeR49MJWjjz56pv4XX3zxTP1PPfXUmfrP4v7775+676GHHjq/QlgPPI8eAFiaoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjnlMLSTZe++9p+771FNPzbGS9eOAAw6Yqf9HP/rRqfu+/vWvn2nbs/jxH//xmfo/8MADc6qENcJjagGApQl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjm0YuwCYh40bN87U/5Zbbpm67yc+8YmZtn
3nnXdO3XfWZ5OfffbZU/fdY489Ztr2rM9lP/zww2fqP4uvfOUrU/f1PHlW2swj+qo6sKreVlV/WVVfrqqnq+qxqrqlqs6uqkW3UVUnVtU1VfVoVT1VVXdU1TlV9dxZawIAJuYxoj8zyWVJHkhyQ5L7k/xYkjcm+UiS11bVma21tq1DVb0+ydVJnknyZ0keTfJLSS5OctKwTgBgRvMI+ruTvC7JJ1pr39s2saouSPK3Sd6USehfPUzfN8kfJvluklNaa383TH9vkuuTnFFVb26tXTmH2gBgXZv5q/vW2vWttb/ePuSH6Q8muXx4e8p2s85I8iNJrtwW8sPyzyS5cHj7jlnrAgB2/1n3/zq0z243bfPQfmqR5W9K8lSSE6vqebuzMABYD3bbWfdVtSHJrwxvtw/1lw7t3Qv7tNaerap7kxyV5LAkX9jJNm5fYtaRy6sWAPq0O0f0FyX5qSTXtNau3W76fkP72BL9tk3ff3cVBgDrxW4Z0VfVu5Kcl+SuJL+83O5D23a4VJLW2nFLbP/2JMcuc7sA0J25j+ir6p1JPpjkH5Oc2lp7dMEi20bs+2Vx+y5YDgCY0lyDvqrOSfLhJHdmEvIPLrLYF4f2iEX6b0jykkxO3rtnnrUBwHo0t6Cvqt/M5IY3n88k5B9aYtHrh/a0ReadnGTvJLe11r4zr9oAYL2aS9APN7u5KMntSV7dWnt4B4tfleThJG+uqpdvt449k/zu8PayedQFAOvdzCfjVdVZSd6fyZ3ubk7yrqpauNh9rbWtSdJae7yqfi2TwL+xqq7M5Ba4r8vk0rurMrktLgAwo3mcdf+SoX1uknOWWOYzSbZue9Na+3hVvSrJb2dyi9w9k3w5ybuTfGj7++IDANObOehba1uSbJmi361JfmHW7UOSnHnmbM9BOuigg6bu+6u/+qszbXutWuSbu2UZ8/P8k08+OVP/t7/97XOqBHa/3X0LXABgRIIeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgYzM/jx5WgwMPPHDsElimq6++eqb+H/jAB6bu+9BDD8207QcffHCm/rCSjOgBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6Vq21sWuYu6q6PcmxY9fBytljjz1m6r958+ap+77lLW+ZadsHH3zw1H0fe+yxmbY9i0svvXSm/jfffPNM/Z999tmZ+sMa8bnW2nGzrMCIHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA65nn0ALB6eR49ALA0QQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANCxmYO+qg6sqrdV1V9W1Zer6umqeqyqbqmqs6vqOQuWP7Sq2g5eV85aEwAwsWEO6zgzyWVJHkhyQ5L7k/xYkjcm+UiS11bVma21tqDfPyT5+CLru3MONQEAmU/Q353kdUk+0Vr73raJVXVBkr9N8qZMQv/qBf0+31rbMoftAwBLmPmr+9ba9a21v94+5IfpDya5fHh7yqzbAQCWbx4j+h3516F9dpF5B1fVryc5MMkjST7bWrtjN9cDAOvKbgv6qtqQ5FeGt59aZJGfH17b97kxyVmttft3V10AsJ7szhH9RUl+Ksk1rbVrt5v+VJIPZHIi3j3DtKOTbElyapJPV9XLWmvf3tkGqur2JWYdOW3RANCT+sGT4eew0qp3JflgkruSnNRae3QX+mxIckuS45Oc01r74C702VHQ773rFQPAqvS51tpxs6xg7iP6qnpnJiH/j0levSshnySttWer6iOZBP3Jwzp21mfRf/zwAeDYXS4aADo11zvjVdU5ST6cyb
Xwpw5n3i/HN4d2n3nWBQDr1dyCvqp+M8nFST6fScg/NMVqXjG09+xwKQBgl8wl6KvqvZmcfHd7Jl/XP7yDZY+vqh9aZPrmJOcOb6+YR10AsN7N/Bt9VZ2V5P1Jvpvk5iTvqqqFi93XWts6/P17SY4aLqX72jDt6CSbh7/f21q7bda6AID5nIz3kqF9bpJzlljmM0m2Dn9/LMkbkvxMktcm2SPJN5L8eZIPt9ZunkNNAEB20+V1Y3PWPQCdmPnyOs+jB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6FivQX/o2AUAwBwcOusKNsyhiNXo8aG9b4n5Rw7tXbu/lG7YZ9Ox36Zjvy2ffTad1bzfDs3382xq1VqbvZQ1pqpuT5LW2nFj17JW2GfTsd+mY78tn302nfWw33r96h4AiKAHgK4JegDomKAHgI4JegDo2Lo86x4A1gsjegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDo2LoK+qr6iar6aFV9vaq+U1X3VdUlVXXA2LWtVsM+aku8Hhy7vrFU1RlVdWlV3VxVjw/744qd9Dmxqq6pqker6qmquqOqzqmq565U3WNbzn6rqkN3cOy1qrpypesfQ1UdWFVvq6q/rKovV9XTVfVYVd1SVWdX1aL/H1/vx9ty91vPx1uvz6P/AVW1McltSX40yV9l8uzhn03yG0lOq6qTWmuPjFjiavZYkksWmf7kSheyilyYZFMm++Br+f4zrRdVVa9PcnWSZ5L8WZJHk/xSkouTnJTkzN1Z7CqyrP02+IckH19k+p1zrGs1OzPJZUkeSHJDkvuT/FiSNyb5SJLXVtWZbbu7nznekkyx3wb9HW+ttXXxSnJtkpbkPy+Y/t+H6ZePXeNqfCW5L8l9Y9ex2l5JTk3yk0kqySnDMXTFEsvum+ShJN9J8vLtpu+ZyYfPluTNY/+bVuF+O3SYv3XsukfeZ5szCennLJh+UCbh1ZK8abvpjrfp9lu3x9u6+Oq+qg5L8ppMQuv3F8z+L0m+neSXq2qfFS6NNaq1dkNr7Utt+D/ETpyR5EeSXNla+7vt1vFMJiPcJHnHbihz1VnmfiNJa+361tpft9a+t2D6g0kuH96est0sx1um2m/dWi9f3W8e2usW+Y/+RFXdmskHgVck+fRKF7cGPK+q3pLkRZl8KLojyU2tte+OW9aase34+9Qi825K8lSSE6vqea2176xcWWvGwVX160kOTPJIks+21u4YuabV4l+H9tntpjnedm6x/bZNd8fbegn6lw7t3UvM/1ImQX9EBP1iDkrysQXT7q2qt7bWPjNGQWvMksdfa+3Zqro3yVFJDkvyhZUsbI34+eH1b6rqxiRntdbuH6WiVaCqNiT5leHt9qHueNuBHey3bbo73tbFV/dJ9hvax5aYv236/itQy1rzx0lenUnY75Pkp5P8QSa/Z32yqjaNV9qa4fibzlNJPpDkuCQHDK9XZXJi1SlJPr3Of267KMlPJbmmtXbtdtMdbzu21H7r9nhbL0G/MzW0fjdcoLX2vuG3rm+01p5qrd3ZWnt7Jicx7pVky7gVdsHxt4jW2kOttd9prX2utfat4XVTJt++/Z8khyd527hVjqOq3pXkvEyuHvrl5XYf2nV3vO1ov/V8vK2XoN/2CXa/Jebvu2A5dm7bySwnj1rF2uD4m6PW2rOZXB6VrMPjr6remeSDSf4xyamttUcXLOJ4W8Qu7LdF9XC8rZeg/+LQHrHE/J8c2qV+w+cHPTS0a/KrrBW25PE3/F74kkxOCrpnJYta4745tOvq+Kuqc5J8OJNruk8dziBfyPG2wC7utx1Z08fbegn6G4b2NYvcDekFmdxA4u
kkf7PSha1hJwztuvmfxQyuH9rTFpl3cpK9k9y2js+AnsYrhnbdHH9V9ZuZ3PDm85mE1UNLLOp4284y9tuOrOnjbV0EfWvtK0muy+QEsncumP2+TD6l/Wlr7dsrXNqqVlVHVdULF5n+4kw+HSfJDm/7SpLkqiQPJ3lzVb1828Sq2jPJ7w5vLxujsNWsqo6vqh9aZPrmJOcOb9fF8VdV783kJLLbk7y6tfbwDhZ3vA2Ws996Pt5qvdy3YpFb4H4hyfGZ3Knr7iQnNrfA/XeqakuS38rkG5F7kzyRZGOSX8zkLlvXJHlDa+1fxqpxLFV1epLTh7cHJfkPmXzav3mY9nBr7fwFy1+VyS1Jr8zklqSvy+RSqKuS/Mf1cBOZ5ey34ZKmo5LcmMntcpPk6Hz/OvH3tta2BVe3quqsJFuTfDfJpVn8t/X7Wmtbt+uz7o+35e63ro+3sW/Nt5KvJIdkcrnYA0n+Jck/ZXJyxgvHrm01vjK5tOR/ZnKG6rcyucnEN5P870yuQ62xaxxx32zJ5KzlpV73LdLnpEw+HP2/TH4q+r+ZjBSeO/a/ZzXutyRnJ/lfmdzR8slMbul6fyb3bn/l2P+WVbTPWpIbHW+z7beej7d1M6IHgPVoXfxGDwDrlaAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDo2P8HOhfD/hZD5eMAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 250, "width": 253 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(xb[0].view(28,28))\n", "yb[0]" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [], "source": [ "model,opt = get_model()" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [], "source": [ "def fit():\n", "    for epoch in range(epochs):\n", "        for xb,yb in train_dl:\n", "            pred = model(xb)\n", "            loss = loss_func(pred, yb)\n", "            loss.backward()\n", "            opt.step()\n", "            opt.zero_grad()" ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.0822, grad_fn=), tensor(1.))" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fit()\n", "loss,acc = loss_func(model(xb), yb), accuracy(model(xb), yb)\n", "assert acc>0.7\n", "loss,acc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sampling Should be Random\n", "\n", "When training:\n", "* the training set should be in *random* order\n", "* that order should differ from one epoch to the next\n", "\n", "However, for validation:\n", "* the validation set should *never* be randomized" ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [], "source": [ "class Sampler():\n", "    def __init__(self, ds, bs, shuffle=False):\n", "        self.n, self.bs, self.shuffle = len(ds), bs, shuffle\n", "    \n", "    def __iter__(self):\n", "        self.idxs = torch.randperm(self.n) if self.shuffle else torch.arange(self.n)\n", "        for i in range(0, self.n, self.bs): yield self.idxs[i: i+self.bs]" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9])]" ] }, "execution_count": 79, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_ds = Dataset(*train_ds[:10])\n", "s = Sampler(small_ds, 3, 
False)\n", "[o for o in s]" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor([1, 9, 0]), tensor([8, 2, 3]), tensor([5, 7, 6]), tensor([4])]" ] }, "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "source": [ "s = Sampler(small_ds, 3, True)\n", "[o for o in s]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looks pretty random. Good.\n", "\n", "Let's rewrite our `DataLoader` to take advantage of it." ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [], "source": [ "def collate(b):\n", " xs,ys = zip(*b)\n", " return torch.stack(xs), torch.stack(ys)\n", "\n", "class DataLoader():\n", " def __init__(self, ds, sampler, collate_fn=collate):\n", " self.ds, self.sampler, self.collate_fn = ds, sampler, collate_fn\n", " \n", " def __iter__(self):\n", " for s in self.sampler: yield self.collate_fn([self.ds[i] for i in s])" ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [], "source": [ "train_samp = Sampler(train_ds, bs, shuffle=True)\n", "valid_samp = Sampler(valid_ds, bs, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [], "source": [ "train_dl = DataLoader(train_ds, sampler=train_samp, collate_fn=collate)\n", "valid_dl = DataLoader(valid_ds, sampler=valid_samp, collate_fn=collate)" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(3)" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAfoAAAH0CAYAAADVH+85AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAHGNJREFUeJzt3X2sbWV9J/DvTy+VF+VF+kJoUeRSJKHlCtgiUBGuqYNtqqgw8Q9bYrGpjhkLgmlDsXPVTkKTyYBiB5pae1tJhraQ2nREYSIgb3aaYi1DKqICRSuIwMiLgC36zB973Xp7POfee/be96xznvP5JDvP2WutZ63fXazw3c/e66VaawEA+vScsQsAAHYfQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHdswdgG7Q1Xdm2TfJPeNXAoATOvQJI+31l4yy0q6DPpMQv6FwwsA1q1Rv7qvqp+oqo9W1der6jtVdV9VXVJVB8y46vvmUR8AjOy+WVcw2oi+qjYmuS3Jjyb5qyR3JfnZJL+R5LSqOqm19shY9QFAD8Yc0f+PTEL+Xa2101trv9Va25zk4iQvTfJfR6wNALpQrbWV32jVYUm+kslXEhtba9/bbt4LkjyQpJL8aGvt21Os//Ykx86nWgAYzedaa8fNsoKxRvSbh/a67UM+SVprTyS5NcneSV6x0oUBQE/G+o3+pUN79xLzv5TkNUmOSPLppVYyjNwXc+T0pQFAP8Ya0e83tI8tMX/b9P1XoBYA6NZqvY6+hnaHJxAs9buF3+gBYGKsEf22Eft+S8zfd8FyAMAUxgr6Lw7tEUvM/8mhXeo3fABgF4wV9DcM7Wuq6t/VMFxed1KSp5P8zUoXBgA9GSXoW2tfSXJdJjfsf+eC2e9Lsk+SP53mGnoA4PvGPBnvP2VyC9wPVdWrk3whyfFJTs3kK/vfHrE2AOjCaLfAHUb1L0+yNZOAPy/JxiQfSnKC+9wDwOxGvbyutfbVJG8dswYA6Nmoj6kFAHYvQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHdswdgGw1m3atGmm/ueee+7UfTdu3DjTtvfee++p+15wwQUzbXu//fabqf8nP/nJqfs+8cQTM20b1hIjegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDoWLXWxq5h7qrq9iTHjl0Ha8fzn//8qfvef//9M217//33n6n/evXP//zPU/c999xzZ9r2VVddNVN/WIbPtdaOm2UFo43oq+q+qmpLvB4cqy4A6MmGkbf/WJJLFpn+5EoXAgA9Gjvov9Va2zJyDQDQLSfjAUDHxh7RP6+q3pLkRUm+neSOJDe11r47blkA0Iexg/6gJB9bMO3eqnpra+0zO+s8nF2/mCNnrgwAOjDmV/d/nOTVmYT9Pkl+OskfJDk0ySeratN4pQFAH0Yb0bfW3rdg0p1J3l5VTyY5L8mWJG/YyToWvbbQdfQAMLEaT8a7fGhPHrUKAOjAagz6h4Z2n1GrAIAOrMagP2Fo7xm1CgDowChBX1VHVdULF5n+4iQfHt5esbJVAUB/xjoZ78wkv1VVNyS5N8kTSTYm+cUkeya5Jsl/G6k2AOjGWEF/Q5KXJjkmk6/q90nyrSS3ZHJd/cdaj4/VA4A
V5jG1kOQFL3jB1H2vueaambb9yCOPTN337//+72fa9jHHHDN13xe/+MUzbfuQQw6Zqf9ee+01dd9vfOMbM237hBNO2PlCu2nbrDtr9zG1AMDuJ+gBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6tmHsAmA1eOKJJ6bu+8pXvnKOlawfP/zDPzxT//e85z2j9E2S0047beq+f/InfzLTtmG5jOgBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA65jG1wCgefvjhmfrfeuutU/ed9TG1xxxzzNR9PaaWlWZEDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAd8zx6YBQHHHDATP0vuOCCOVWyfAcffPBo24blMqIHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomMfUAlPZtGnTTP3/4i/+Yqb+hx9++NR977777pm2fd55583UH1bSXEb0VXVGVV1aVTdX1eNV1arqip30ObGqrqmqR6vqqaq6o6rOqarnzqMmAGB+I/oLk2xK8mSSryU5ckcLV9Xrk1yd5Jkkf5bk0SS/lOTiJCclOXNOdQHAujav3+jPTXJEkn2TvGNHC1bVvkn+MMl3k5zSWju7tfaeJC9L8tkkZ1TVm+dUFwCsa3MJ+tbaDa21L7XW2i4sfkaSH0lyZWvt77ZbxzOZfDOQ7OTDAgCwa8Y4637z0H5qkXk3JXkqyYlV9byVKwkA+jRG0L90aH/gtNfW2rNJ7s3k3IHDVrIoAOjRGJfX7Te0jy0xf9v0/Xe2oqq6fYlZOzwZEADWi9V4w5wa2l35vR8A2IExRvTbRuz7LTF/3wXLLam1dtxi04eR/rHLLw0A+jLGiP6LQ3vEwhlVtSHJS5I8m+SelSwKAHo0RtBfP7SnLTLv5CR7J7mttfadlSsJAPo0RtBfleThJG+uqpdvm1hVeyb53eHtZSPUBQDdmctv9FV1epLTh7cHDe0JVbV1+Pvh1tr5SdJae7yqfi2TwL+xqq7M5Ba4r8vk0rurMrktLgAwo3mdjPeyJGctmHZYvn8t/D8lOX/bjNbax6vqVUl+O8mbkuyZ5MtJ3p3kQ7t4hz0AYCfmEvSttS1Jtiyzz61JfmEe2wcAFud59LCOnXXWwi/idt373//+mbZ9yCGHzNT/6aefnrrvO94x2+M0vvrVr87UH1bSarxhDgAwJ4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADrmMbUwsuc///lT9z3//PNn2vaFF144dd/nPGe2ccKjjz46U/+f+7mfm7rvXXfdNdO2YS0xogeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjnkePYxs69atU/d94xvfOL9Clumqq66aqf8ll1wyU3/PlIddY0QPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMY+phZFt3Lhx7BKmctlll83U/7bbbptTJcCOGNEDQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMc8jx5Gdt11103dd9OmTXOsZHlmqTuZ/Xn2F1100dR9v/71r8+0bVhL5jKir6ozqurSqrq5qh6vqlZVVyyx7KHD/KVeV86jJgBgfiP6C5NsSvJkkq8lOXIX+vxDko8
vMv3OOdUEAOvevIL+3EwC/stJXpXkhl3o8/nW2pY5bR8AWMRcgr619m/BXlXzWCUAMAdjnox3cFX9epIDkzyS5LOttTtGrAcAujNm0P/88Po3VXVjkrNaa/fvygqq6vYlZu3KOQIA0L0xrqN/KskHkhyX5IDhte13/VOSfLqq9hmhLgDozoqP6FtrDyX5nQWTb6qq1yS5JcnxSd6W5IO7sK7jFps+jPSPnbFUAFjzVs2d8Vprzyb5yPD25DFrAYBerJqgH3xzaH11DwBzsNqC/hVDe8+oVQBAJ1Y86Kvq+Kr6oUWmb87kxjtJsujtcwGA5ZnLyXhVdXqS04e3Bw3tCVW1dfj74dba+cPfv5fkqOFSuq8N045Osnn4+72ttdvmURcArHfzOuv+ZUnOWjDtsOGVJP+UZFvQfyzJG5L8TJLXJtkjyTeS/HmSD7fWbp5TTQCw7s3rFrhbkmzZxWX/KMkfzWO7AMCOVWtt7BrmznX0rCV77bXX1H2vuGK201mOO27RW1Hskhe96EUzbXtWDz744NR93/rWt8607WuvvXam/rAMn1vqnjG7arWddQ8AzJGgB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COeUwtrGF77rnnTP03bNgwdd/HH398pm2P6Zlnnpmp/7vf/e6p+15++eUzbZt1x2NqAYClCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COeR49MJWjjz56pv4XX3zxTP1PPfXUmfrP4v7775+676GHHjq/QlgPPI8eAFiaoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjnlMLSTZe++9p+771FNPzbGS9eOAAw6Yqf9HP/rRqfu+/vWvn2nbs/jxH//xmfo/8MADc6qENcJjagGApQl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjm0YuwCYh40bN87U/5Zbbpm67yc+8YmZtn3nnXdO3XfWZ5OfffbZU/fdY489Ztr2rM9lP/zww2fqP4uvfOUrU/f1PHlW2swj+qo6sKreVlV/WVVfrqqnq+qxqrqlqs6uqkW3UVUnVtU1VfVoVT1VVXdU1TlV9dxZawIAJuYxoj8zyWVJHkhyQ5L7k/xYkjcm+UiS11bVma21tq1DVb0+ydVJnknyZ0keTfJLSS5OctKwTgBgRvMI+ruTvC7JJ1pr39s2saouSPK3Sd6USehfPUzfN8kfJvluklNaa383TH9vkuuTnFFVb26tXTmH2gBgXZv5q/vW2vWttb/ePuSH6Q8muXx4e8p2s85I8iNJrtwW8sPyzyS5cHj7jlnrAgB2/1n3/zq0z243bfPQfmqR5W9K8lSSE6vqebuzMABYD3bbWfdVtSHJrwxvtw/1lw7t3Qv7tNaerap7kxyV5LAkX9jJNm5fYtaRy6sWAPq0O0f0FyX5qSTXtNau3W76fkP72BL9tk3ff3cVBgDrxW4Z0VfVu5Kcl+SuJL+83O5D23a4VJLW2nFLbP/2JMcuc7sA0J25j+ir6p1JPpjkH5Oc2lp7dMEi20bs+2Vx+y5YDgCY0lyDvqrOSfLhJHdmEvIPLrLYF4f2iEX6b0jykkxO3rtnnrUBwHo0t6Cvqt/M5IY3n88k5B9aYtHrh/a0ReadnGTvJLe11r4zr9oAYL2aS9APN7u5KMntSV7dWnt4B4tfleThJG+uqpdvt449k/zu8PayedQFAOvdzCfjVdVZSd6fyZ3ubk7yrqpauNh9rbWtSdJae7yqfi2TwL+xqq7M5Ba4r8vk0rurMrktLgAwo3mcdf+SoX1uknOWWOYzSbZue9Na+3hVvSrJb2d
yi9w9k3w5ybuTfGj7++IDANObOehba1uSbJmi361JfmHW7UOSnHnmbM9BOuigg6bu+6u/+qszbXutWuSbu2UZ8/P8k08+OVP/t7/97XOqBHa/3X0LXABgRIIeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgYzM/jx5WgwMPPHDsElimq6++eqb+H/jAB6bu+9BDD8207QcffHCm/rCSjOgBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6Vq21sWuYu6q6PcmxY9fBytljjz1m6r958+ap+77lLW+ZadsHH3zw1H0fe+yxmbY9i0svvXSm/jfffPNM/Z999tmZ+sMa8bnW2nGzrMCIHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA65nn0ALB6eR49ALA0QQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANCxmYO+qg6sqrdV1V9W1Zer6umqeqyqbqmqs6vqOQuWP7Sq2g5eV85aEwAwsWEO6zgzyWVJHkhyQ5L7k/xYkjcm+UiS11bVma21tqDfPyT5+CLru3MONQEAmU/Q353kdUk+0Vr73raJVXVBkr9N8qZMQv/qBf0+31rbMoftAwBLmPmr+9ba9a21v94+5IfpDya5fHh7yqzbAQCWbx4j+h3516F9dpF5B1fVryc5MMkjST7bWrtjN9cDAOvKbgv6qtqQ5FeGt59aZJGfH17b97kxyVmttft3V10AsJ7szhH9RUl+Ksk1rbVrt5v+VJIPZHIi3j3DtKOTbElyapJPV9XLWmvf3tkGqur2JWYdOW3RANCT+sGT4eew0qp3JflgkruSnNRae3QX+mxIckuS45Oc01r74C702VHQ773rFQPAqvS51tpxs6xg7iP6qnpnJiH/j0levSshnySttWer6iOZBP3Jwzp21mfRf/zwAeDYXS4aADo11zvjVdU5ST6cybXwpw5n3i/HN4d2n3nWBQDr1dyCvqp+M8nFST6fScg/NMVqXjG09+xwKQBgl8wl6KvqvZmcfHd7Jl/XP7yDZY+vqh9aZPrmJOcOb6+YR10AsN7N/Bt9VZ2V5P1Jvpvk5iTvqqqFi93XWts6/P17SY4aLqX72jDt6CSbh7/f21q7bda6AID5nIz3kqF9bpJzlljmM0m2Dn9/LMkbkvxMktcm2SPJN5L8eZIPt9ZunkNNAEB20+V1Y3PWPQCdmPnyOs+jB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6FivQX/o2AUAwBwcOusKNsyhiNXo8aG9b4n5Rw7tXbu/lG7YZ9Ox36Zjvy2ffTad1bzfDs3382xq1VqbvZQ1pqpuT5LW2nFj17JW2GfTsd+mY78tn302nfWw33r96h4AiKAHgK4JegDomKAHgI4JegDo2Lo86x4A1gsjegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDo2LoK+qr6iar6aFV9vaq+U1X3VdUlVXXA2LWtVsM+aku8Hhy7vrFU1RlVdWlV3VxVjw/744qd9Dmxqq6pqker6qmquqOqzqmq565U3WNbzn6rqkN3cOy1qrpypesfQ1UdWFVvq6q/rKovV9XTVfVYVd1SVWdX1aL/H1/vx9ty91vPx1uvz6P/AVW1McltSX40yV9l8uzhn03yG0lOq6qTWmuPjFj
iavZYkksWmf7kSheyilyYZFMm++Br+f4zrRdVVa9PcnWSZ5L8WZJHk/xSkouTnJTkzN1Z7CqyrP02+IckH19k+p1zrGs1OzPJZUkeSHJDkvuT/FiSNyb5SJLXVtWZbbu7nznekkyx3wb9HW+ttXXxSnJtkpbkPy+Y/t+H6ZePXeNqfCW5L8l9Y9ex2l5JTk3yk0kqySnDMXTFEsvum+ShJN9J8vLtpu+ZyYfPluTNY/+bVuF+O3SYv3XsukfeZ5szCennLJh+UCbh1ZK8abvpjrfp9lu3x9u6+Oq+qg5L8ppMQuv3F8z+L0m+neSXq2qfFS6NNaq1dkNr7Utt+D/ETpyR5EeSXNla+7vt1vFMJiPcJHnHbihz1VnmfiNJa+361tpft9a+t2D6g0kuH96est0sx1um2m/dWi9f3W8e2usW+Y/+RFXdmskHgVck+fRKF7cGPK+q3pLkRZl8KLojyU2tte+OW9aase34+9Qi825K8lSSE6vqea2176xcWWvGwVX160kOTPJIks+21u4YuabV4l+H9tntpjnedm6x/bZNd8fbegn6lw7t3UvM/1ImQX9EBP1iDkrysQXT7q2qt7bWPjNGQWvMksdfa+3Zqro3yVFJDkvyhZUsbI34+eH1b6rqxiRntdbuH6WiVaCqNiT5leHt9qHueNuBHey3bbo73tbFV/dJ9hvax5aYv236/itQy1rzx0lenUnY75Pkp5P8QSa/Z32yqjaNV9qa4fibzlNJPpDkuCQHDK9XZXJi1SlJPr3Of267KMlPJbmmtXbtdtMdbzu21H7r9nhbL0G/MzW0fjdcoLX2vuG3rm+01p5qrd3ZWnt7Jicx7pVky7gVdsHxt4jW2kOttd9prX2utfat4XVTJt++/Z8khyd527hVjqOq3pXkvEyuHvrl5XYf2nV3vO1ov/V8vK2XoN/2CXa/Jebvu2A5dm7bySwnj1rF2uD4m6PW2rOZXB6VrMPjr6remeSDSf4xyamttUcXLOJ4W8Qu7LdF9XC8rZeg/+LQHrHE/J8c2qV+w+cHPTS0a/KrrBW25PE3/F74kkxOCrpnJYta4745tOvq+Kuqc5J8OJNruk8dziBfyPG2wC7utx1Z08fbegn6G4b2NYvcDekFmdxA4ukkf7PSha1hJwztuvmfxQyuH9rTFpl3cpK9k9y2js+AnsYrhnbdHH9V9ZuZ3PDm85mE1UNLLOp4284y9tuOrOnjbV0EfWvtK0muy+QEsncumP2+TD6l/Wlr7dsrXNqqVlVHVdULF5n+4kw+HSfJDm/7SpLkqiQPJ3lzVb1828Sq2jPJ7w5vLxujsNWsqo6vqh9aZPrmJOcOb9fF8VdV783kJLLbk7y6tfbwDhZ3vA2Ws996Pt5qvdy3YpFb4H4hyfGZ3Knr7iQnNrfA/XeqakuS38rkG5F7kzyRZGOSX8zkLlvXJHlDa+1fxqpxLFV1epLTh7cHJfkPmXzav3mY9nBr7fwFy1+VyS1Jr8zklqSvy+RSqKuS/Mf1cBOZ5ey34ZKmo5LcmMntcpPk6Hz/OvH3tta2BVe3quqsJFuTfDfJpVn8t/X7Wmtbt+uz7o+35e63ro+3sW/Nt5KvJIdkcrnYA0n+Jck/ZXJyxgvHrm01vjK5tOR/ZnKG6rcyucnEN5P870yuQ62xaxxx32zJ5KzlpV73LdLnpEw+HP2/TH4q+r+ZjBSeO/a/ZzXutyRnJ/lfmdzR8slMbul6fyb3bn/l2P+WVbTPWpIbHW+z7beej7d1M6IHgPVoXfxGDwDrlaAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDo2P8HOhfD/hZD5eMAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 250, "width": 253 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "xb, yb = next(iter(valid_dl))\n", "plt.imshow(xb[0].view(28,28))\n", "yb[0]" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(5)" ] }, "execution_count": 85, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfoAAAH0CAYAAADVH+85AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAHApJREFUeJzt3X2sZWV9L/Dvr9LKdVrAl7Gm6S0jWqTQqheoCOQOL0aqbaqocMMftcRgI0ouL+JNm6q909qb2KS9gK9DaltSbe7YQGrTW1Bv5WVQ9BrHKNeCAoURTbXDiICC2qLP/WOvqdPTc+Zlrz1nn3n255PsPGevtZ71/GbNyvmetfd6qdZaAIA+/ci8CwAADhxBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdO2TeBRwIVXVfksOSbJ9zKQAwrQ1JHmmtPXPMSroM+kxC/inDCwAW1lw/uq+qn66qP62qf6yq71XV9qq6sqqePHLV22dRHwDM2faxK5jbEX1VPSvJbUmenuSvk3wxyQuSXJLkJVV1amvtG/OqDwB6MM8j+vdkEvIXt9bObq39VmvtzCRXJHlOkv8xx9oAoAvVWlv9QauOSvIPmXwk8azW2g92m/cTSb6WpJI8vbX26BTr35bk+NlUCwBz89nW2gljVjCvI/ozh/aju4d8krTWvpXkE0melOSFq10YAPRkXt/RP2do71ph/t1JzkpydJKPrbSS4ch9OcdMXxoA9GNeR/SHD+3DK8zfNf2IVagFALq1Vq+jr6Hd4wkEK31v4Tt6AJiY1xH9riP2w1eYf9iS5QCAKcwr6L80tEevMP9nh3al7/ABgH0wr6C/aWjPqqp/U8Nwed2pSb6T5FOrXRgA9GQuQd9a+4ckH83khv0XLZn9u0nWJfnzaa6hBwB+aJ4n470hk1vgvqOqXpTkziQnJTkjk4/s3zzH2gCgC3O7Be5wVH9ikmsyCfjLkzwryTuSnOw+9wAw3lwvr2utfSXJa+ZZAwD0bK6PqQUADixBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0LG5BX1Vba+qtsLr6/OqCwB6csicx384yZXLTP/2ahcCAD2ad9A/1FrbNOcaAKBbvqMHgI7N+4j+iVX1a0l+JsmjSW5P
srW19v35lgUAfZh30D8jyfuXTLuvql7TWrtlb52ratsKs44ZXRkAdGCeH93/WZIXZRL265L8QpKrk2xIckNVPW9+pQFAH6q1Nu8a/o2q+sMklyf5UGvtFVOuY1uS42daGACsvs+21k4Ys4K1eDLe5qHdONcqAKADazHodwzturlWAQAdWItBf/LQ3jvXKgCgA3MJ+qo6rqqessz0I5O8a3j7gdWtCgD6M6/L685N8ltVdVOS+5J8K8mzkvxKkkOTXJ/kD+dUGwB0Y15Bf1OS5yT5T5l8VL8uyUNJPp7JdfXvb2vtcgAAOAjNJeiHm+Hs9YY4AMA4a/FkPABgRgQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAx+byPHqA9evXj+p/+OGHz6iS/bdu3bqp+x577LGjxh6z3Vpro8Y+8cQTp+67cePGUWMfeeSRU/f9zGc+M2rsF7zgBaP6z5sjegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI55TC1k3KM/xzy6c6yxY7/yla+cuu+YR7UmydOe9rRR/Y844oip+459XGtVGXs//f3f//2osf/oj/5o6r5XXHHFqLEPdo7oAaBjgh4AOiboAaBjgh4AOiboAaBjgh4AOiboAaBjgh4AOiboAaBjgh4AOiboAaBjgh4AOiboAaBjgh4AOiboAaBjnkcPSS688MKp+27atGl2heynMc8HT8Y9n3zs2Dt27BjV/+677x7Vf17G/rtvvfXWqftu3bp11Nj333//1H3vuOOOUWMzPUf0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHRP0ANAxQQ8AHfOYWsi4x3du3rx51NhXX3311H2PPfbYUWOvX79+6r433HDDqLF37tw5qv9DDz00qj8sipkc0VfVOVX1zqq6taoeqapWVR/YS59Tqur6qnqwqh6rqtur6tKqesIsagIAZndE/5Ykz0vy7SRfTXLMnhauqpcnuS7Jd5N8MMmDSX41yRVJTk1y7ozqAoCFNqvv6C9LcnSSw5K8fk8LVtVhSf44yfeTnN5au6C19t+SPD/JJ5OcU1XnzaguAFhoMwn61tpNrbW7W2ttHxY/J8n6JFtaa5/ZbR3fzeSTgWQvfywAAPtmHmfdnzm0H15m3tYkjyU5paqeuHolAUCf5hH0zxnau5bOaK09nuS+TM4dOGo1iwKAHs3j8rrDh/bhFebvmn7E3lZUVdtWmLXHkwEBYFGsxRvm1NDuy/f9AMAezOOIftcR++ErzD9syXIraq2dsNz04Uj/+P0vDQD6Mo8j+i8N7dFLZ1TVIUmemeTxJPeuZlEA0KN5BP2NQ/uSZeZtTPKkJLe11r63eiUBQJ/mEfTXJtmZ5LyqOnHXxKo6NMnvD2/fO4e6AKA7M/mOvqrOTnL28PYZQ3tyVV0z/LyztfamJGmtPVJVv5FJ4N9cVVsyuQXuyzK59O7aTG6LCwCMNKuT8Z6f5Pwl047KD6+F/3KSN+2a0Vr7UFWdluTNSV6V5NAk9yR5Y5J37OMd9gCAvZhJ0LfWNiXZtJ99PpHkl2cxPgCwPM+jpwvvec97RvW/8MILp+77+c9/ftTY7373u6fuu2XLllFjA/1bizfMAQBmRNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMc8ppY146qrrpq675jHzCbJjh07pu571llnjRr7gQceGNUfYE8c0QNAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxzyPnjVj3bp1U/etqlFjP/3pT5+676c//elR
Y1977bVT9/3KV74yauw77rhj6r5/93d/N2psYHU4ogeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYx9SyZlxyySVT9/3mN785auxjjjlm6r4bNmwYNfYb3/jGqfuOfTxva23qvp/61KdGjT3WFVdcMXXfMY8GhoONI3oA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6FiNeR71WlVV25IcP+86YF+cdtppU/fduHHjqLF/7ud+buq+J5988qixjzzyyFH9x7juuutG9X/DG94wdd8HHnhg1NgsnM+21k4Ys4KZHNFX1TlV9c6qurWqHqmqVlUfWGHZDcP8lV5bZlETAJAcMqP1vCXJ85J8O8lXkxyzD30+n+RDy0z/woxqAoCFN6ugvyyTgL8nyWlJbtqHPp9rrW2a0fgAwDJmEvSttX8N9qqaxSoBgBmY1RH9NH6qql6X5KlJvpHkk6212+dYDwB0Z55B/+Lh9a+q6uYk57fW7t+XFQxn1y9nX84RAIDuzeM6+seSvC3JCUmePLx2fa9/epKPVdW6OdQFAN1Z9SP61tqOJL+zZPLWqjoryceTnJTktUmu2od1LXttoevoAWBizdwZr7X2eJL3DW/H3QUEAEiyhoJ+sOuWUT66B4AZWGtB/8KhvXeuVQBAJ1Y96KvqpKr6sWWmn5nJjXeSZNnb5wIA+2cmJ+NV1dlJzh7ePmNoT66qa4afd7bW3jT8/AdJjhsupfvqMO25Sc4cfn5ra+22WdQFAItuVmfdPz/J+UumHTW8kuTLSXYF/fuTvCLJLyZ5aZIfTfJPSf4yybtaa7fOqCYAWHizugXupiSb9nHZP0nyJ7MYFwDYM8+jB6by7Gc/e1T/yy67bO8L7cHrX//6qfuO/b130UUXTd138+bNo8Zm4ayN59EDAGuToAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjs3kefTA4rnnnntG9R/zqNck+fEf//Gp+7761a8eNfb69etH9YfV5IgeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADom6AGgY4IeADrmefTAXBx77LGj+m/cuHHqvq21UWM/8MADo/rDanJEDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DFBDwAdE/QA0DGPqQWmcvHFF4/qf+WVV86okv33tre9bVT/zZs3z6gSOPAc0QNAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxzyPnpm5/PLLR/W/4447pu57ww03jBr7YLVhw4ZR/d/85jdP3feCCy4YNXZrbVT/rVu3Tt336quvHjU2HExGH9FX1VOr6rVV9VdVdU9VfaeqHq6qj1fVBVW17BhVdUpVXV9VD1bVY1V1e1VdWlVPGFsTADAxiyP6c5O8N8nXktyU5P4kP5nklUnel+SlVXVu2+3P96p6eZLrknw3yQeTPJjkV5NckeTUYZ0AwEizCPq7krwsyd+21n6wa2JV/XaSTyd5VSahf90w/bAkf5zk+0lOb619Zpj+1iQ3Jjmnqs5rrW2ZQW0AsNBGf3TfWruxtfY3u4f8MP3rSTYPb0/fbdY5SdYn2bIr5Iflv5vkLcPb14+tCwA48Gfd/8vQPr7btDOH9sPLLL81yWNJTqmqJx7IwgBgERyws+6r6pAkvz683T3UnzO0dy3t01p7vKruS3JckqOS3LmXMbatMOuY/asWAPp0II/o357k55Nc31r7yG7TDx/ah1fot2v6
EQeqMABYFAfkiL6qLk5yeZIvJnn1/nYf2r1eZNtaO2GF8bclOX4/xwWA7sz8iL6qLkpyVZI7kpzRWntwySK7jtgPz/IOW7IcADClmQZ9VV2a5F1JvpBJyH99mcW+NLRHL9P/kCTPzOTkvXtnWRsALKKZBX1V/WYmN7z5XCYhv2OFRW8c2pcsM29jkiclua219r1Z1QYAi2omQT/c7ObtSbYleVFrbeceFr82yc4k51XVibut49Akvz+8fe8s6gKARTf6ZLyqOj/J72Vyp7tbk1xcVUsX295auyZJWmuPVNVvZBL4N1fVlkxugfuyTC69uzaT2+ICACPN4qz7Zw7tE5JcusIytyS5Zteb1tqHquq0JG/O5Ba5hya5J8kbk7yjjX2sFQCQZAZB31rblGTTFP0+keSXx47P2vG0pz1tVP+/+Iu/mLrvEUfM77YLjz766Kj+O3fu6ZuuPTvyyCNHjT3G7bffPqr/JZdcMqr/LbfcMqo/LIoDfQtcAGCOBD0AdEzQA0DHBD0AdEzQA0DHBD0AdEzQA0DHBD0AdEzQA0DHBD0AdEzQA0DHBD0AdEzQA0DHBD0AdEzQA0DHqrU27xpmrqq2JTl+3nWwf9avXz913/POO2+GleyfE044YW5jf/nLXx7V/84775y675YtW0aNDeyTz7bWRv2ScUQPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMY+pBYC1y2NqAYCVCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COCXoA6JigB4COjQ76qnpqVb22qv6qqu6pqu9U1cNV9fGquqCqfmTJ8huqqu3htWVsTQDAxCEzWMe5Sd6b5GtJbkpyf5KfTPLKJO9L8tKqOre11pb0+3ySDy2zvi/MoCYAILMJ+ruSvCzJ37bWfrBrYlX9dpJPJ3lVJqF/3ZJ+n2utbZrB+ADACkZ/dN9au7G19je7h/ww/etJNg9vTx87DgCw/2ZxRL8n/zK0jy8z76eq6nVJnprkG0k+2Vq7/QDXAwAL5YAFfVUdkuTXh7cfXmaRFw+v3fvcnOT81tr9B6ouAFgkB/KI/u1Jfj7J9a21j+w2/bEkb8vkRLx7h2nPTbIpyRlJPlZVz2+tPbq3Aapq2wqzjpm2aADoSf37k+FnsNKqi5NcleSLSU5trT24D30OSfLxJCclubS1dtU+9NlT0D9p3ysGgDXps621E8asYOZH9FV1USYhf0eSF+1LyCdJa+3xqnpfJkG/cVjH3vos+48f/gA4fp+LBoBOzfTOeFV1aZJ3ZXIt/BnDmff744GhXTfLugBgUc0s6KvqN5NckeRzmYT8jilW88KhvXePSwEA+2QmQV9Vb83k5LttmXxcv3MPy55UVT+2zPQzk1w2vP3ALOoCgEU3+jv6qjo/ye8l+X6SW5NcXFVLF9veWrtm+PkPkhw3XEr31WHac5OcOfz81tbabWPrAgBmczLeM4f2CUkuXWGZW5JcM/z8/iSvSPKLSV6a5EeT/FOSv0zyrtbarTOoCQDIAbq8bt6cdQ9AJ0ZfXud59ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB3rNeg3zLsAAJiBDWNXcMgMiliLHhna7SvMP2Zov3jgS+mGbTYd2206ttv+s82ms5a324b8MM+mVq218aUcZKpqW5K01k6Ydy0H
C9tsOrbbdGy3/WebTWcRtluvH90DABH0ANA1QQ8AHRP0ANAxQQ8AHVvIs+4BYFE4ogeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAji1U0FfVT1fVn1bVP1bV96pqe1VdWVVPnndta9WwjdoKr6/Pu755qapzquqdVXVrVT0ybI8P7KXPKVV1fVU9WFWPVdXtVXVpVT1hteqet/3ZblW1YQ/7XquqLatd/zxU1VOr6rVV9VdVdU9VfaeqHq6qj1fVBVW17O/xRd/f9ne79by/9fo8+n+nqp6V5LYkT0/y15k8e/gFSS5J8pKqOrW19o05lriWPZzkymWmf3u1C1lD3pLkeZlsg6/mh8+0XlZVvTzJdUm+m+SDSR5M8qtJrkhyapJzD2Sxa8h+bbfB55N8aJnpX5hhXWvZuUnem+RrSW5Kcn+Sn0zyyiTvS/LSqjq37Xb3M/tbkim226C//a21thCvJB9J0pL81yXT/+cwffO8a1yLryTbk2yfdx1r7ZXkjCQ/m6SSnD7sQx9YYdnDkuxI8r0kJ+42/dBM/vhsSc6b979pDW63DcP8a+Zd95y32ZmZhPSPLJn+jEzCqyV51W7T7W/Tbbdu97eF+Oi+qo5KclYmofXuJbP/e5JHk7y6qtatcmkcpFprN7XW7m7Db4i9OCfJ+iRbWmuf2W0d383kCDdJXn8Aylxz9nO7kaS1dmNr7W9aaz9YMv3rSTYPb0/fbZb9LVNtt24tykf3Zw7tR5f5T/9WVX0ikz8EXpjkY6td3EHgiVX1a0l+JpM/im5PsrW19v35lnXQ2LX/fXiZeVuTPJbklKp6Ymvte6tX1kHjp6rqdUmemuQbST7ZWrt9zjWtFf8ytI/vNs3+tnfLbbddutvfFiXonzO0d60w/+5Mgv7oCPrlPCPJ+5dMu6+qXtNau2UeBR1kVtz/WmuPV9V9SY5LclSSO1ezsIPEi4fXv6qqm5Oc31q7fy4VrQFVdUiSXx/e7h7q9rc92MN226W7/W0hPrpPcvjQPrzC/F3Tj1iFWg42f5bkRZmE/bokv5Dk6ky+z7qhqp43v9IOGva/6TyW5G1JTkjy5OF1WiYnVp2e5GML/nXb25P8fJLrW2sf2W26/W3PVtpu3e5vixL0e1ND63vDJVprvzt81/VPrbXHWmtfaK1dmMlJjP8hyab5VtgF+98yWms7Wmu/01r7bGvtoeG1NZNP3/5vkmcnee18q5yPqro4yeWZXD306v3tPrQLt7/tabv1vL8tStDv+gv28BXmH7ZkOfZu18ksG+daxcHB/jdDrbXHM7k8KlnA/a+qLkpyVZI7kpzRWntwySL2t2Xsw3ZbVg/726IE/ZeG9ugV5v/s0K70HT7/3o6hPSg/ylplK+5/w/eFz8zkpKB7V7Oog9wDQ7tQ+19VXZrkXZlc033GcAb5Uva3JfZxu+3JQb2/LUrQ3zS0Zy1zN6SfyOQGEt9J8qnVLuwgdvLQLswvixFuHNqXLDNvY5InJbltgc+AnsYLh3Zh9r+q+s1MbnjzuUzCascKi9rfdrMf221PDur9bSGCvrX2D0k+mskJZBctmf27mfyV9uettUdXubQ1raqOq6qnLDP9yEz+Ok6SPd72lSTJtUl2Jjmvqk7cNbGqDk3y+8Pb986jsLWsqk6qqh9bZvqZSS4b3i7E/ldVb83kJLJtSV7UWtu5h8Xtb4P92W4972+1KPetWOYWuHcmOSmTO3XdleSU5ha4/0ZVbUryW5l8InJfkm8leVaSX8nkLlvXJ3lFa+2f51XjvFTV2UnOHt4+I8kvZfLX/q3DtJ2ttTctWf7aTG5JuiWTW5K+LJNLoa5N8l8W4SYy+7Pdhkuajktycya3y02S5+aH14m/tbW2K7i6VVXnJ7kmyfeTvDPLf7e+vbV2zW59Fn5/29/t1vX+Nu9b863mK8l/zORysa8l+eckX87k5IynzLu2tfjK5NKS/5XJGaoPZXKTiQeS/J9MrkOtedc4x22z
KZOzlld6bV+mz6mZ/HH0zUy+Kvp/mRwpPGHe/561uN2SXJDkf2dyR8tvZ3JL1/szuXf7f573v2UNbbOW5Gb727jt1vP+tjBH9ACwiBbiO3oAWFSCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGP/H0cPrH483oxdAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 250, "width": 253 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "xb, yb = next(iter(train_dl))\n", "plt.imshow(xb[0].view(28,28))\n", "yb[0]" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(9)" ] }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfoAAAH0CAYAAADVH+85AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAHGBJREFUeJzt3X2sbWV9J/Dvr2JhoF5AoyXEaa9YkIQWFWwRiAjYotgUoQKhaS1ptHE6RovVSRvUzm3ttDZpBt862lQqqSRiA9G2U1SmvAiCLfESRVoRKCBDxCLckRevUl6e+WOvW2+P59yXvfc965xnfz7JznP2WutZ68di5X732nutZ1VrLQBAn35o7AIAgD1H0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAx/Yau4A9oaruSrIhyd0jlwIA09qY5OHW2vNmWUmXQZ9JyD9zeAHAwhr1q/uqem5V/UVVfaOqHququ6vqvVV14Iyrvnse9QHAyO6edQWjndFX1fOT3JDkOUn+OsmtSX4myW8meVVVHd9ae3Cs+gCgB2Oe0f+vTEL+La2101trv9NaOznJBUlekOR/jFgbAHShWmurv9GqQ5L8SyZfSTy/tfbUdvOekeS+JJXkOa2170yx/s1JjppPtQAwmptaa0fPsoKxzuhPHtortg/5JGmtPZLk+iT7JnnpahcGAD0Z6zf6FwztbSvMvz3JKUkOS3LlSisZztyXc/j0pQFAP8Y6o99/aB9aYf626QesQi0A0K21eh99De0OLyBY6XcLv9EDwMRYZ/Tbztj3X2H+hiXLAQBTGCvovza0h60w/9ChXek3fABgF4wV9FcP7SlV9R9qGG6vOz7Jd5P8w2oXBgA9GSXoW2v/kuSKTAbsf9OS2b+XZL8kfznNPfQAwPeNeTHef81kCNz3V9Urknw1yTFJTsrkK/t3jFgbAHRhtCFwh7P6lyS5KJOAf1uS5yd5f5JjjXMPALMb9fa61tr/TfJrY9YAAD0b9TG1AMCeJegBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6NlrQV9XdVdVWeH1zrLoAoCd7jbz9h5K8d5npj652IQDQo7GD/tuttU0j1wAA3fIbPQB0bOwz+r2r6leS/FiS7yS5Ocm1rbUn
xy0LAPowdtAflORjS6bdVVW/1lr73M46V9XmFWYdPnNlANCBMb+6/2iSV2QS9vsl+akkf5ZkY5JPV9ULxysNAPpQrbWxa/gPqupPkrwtyadaa2dMuY7NSY6aa2EAsPpuaq0dPcsK1uLFeB8e2hNGrQIAOrAWg/7+od1v1CoAoANrMeiPHdo7R60CADowStBX1RFV9cxlpv94kg8Oby9e3aoAoD9j3V53VpLfqaqrk9yV5JEkz0/y80n2SXJ5kj8ZqTYA6MZYQX91khckeXEmX9Xvl+TbST6fyX31H2tr7XYAAFiHRgn6YTCcnQ6IA+vBu9/97pn6v+Md75hTJbuvqqbuO+tn8Y9//OMz9f/iF784dd8LLrhgpm3DerIWL8YDAOZE0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHSsZn2m9FpUVZuTHDV2HayeDRs2zNT/4osvnrrvSSedNNO2991335n6s/s+8YlPzNT/3HPPnbrv448/PtO2WTg3tdaOnmUFzugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA65jG1rBkHHnjg1H0vvPDCmbb9mte8Zqb+s7jrrrum7vtP//RPM2379ttvn7pvVc207bPPPnum/gcffPBM/Wdx7LHHTt33xhtvnGMlLACPqQUAViboAaBjgh4AOiboAaBjgh4AOiboAaBjgh4AOiboAaBjgh4AOiboAaBjgh4AOiboAaBjgh4AOiboAaBjgh4AOrbX2AXANhdccMHUfcd8nvy99947U/9Xv/rVU/e97bbbZtr2mA466KCZ+p9zzjlzqmT3XXbZZVP3PeaYY2ba9je+8Y2Z+rN4nNEDQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0zGNqWTM2btw42rZnedTsK1/5ypm2vZ4fNbuoDj744Kn77rPPPnOsBHZuLmf0VXVmVX2gqq6rqoerqlXVxTvpc1xVXV5VW6pqa1XdXFXnVdXT5lETADC/M/p3JnlhkkeT3Jvk8B0tXFWvSXJZku8l+USSLUl+IckFSY5Pctac6gKAhTav3+jfmuSwJBuS/MaOFqyqDUn+PMmTSU5srb2+tfbfkrwoyReSnFlV58ypLgBYaHMJ+tba1a2121trbRcWPzPJs5Nc0lr74nbr+F4m3wwkO/mwAADsmjGuuj95aD+zzLxrk2xNclxV7b16JQFAn8YI+hcM7Q9catxaeyLJXZlcO3DIahYFAD0a4/a6/Yf2oRXmb5t+wM5WVFWbV5i1w4sBAWBRrMUBc2pod+X3fgBgB8Y4o992xr7/CvM3LFluRa21o5ebPpzpH7X7pQFAX8Y4o//a0B62dEZV7ZXkeUmeSHLnahYFAD0aI+ivGtpXLTPvhCT7JrmhtfbY6pUEAH0aI+gvTfJAknOq6iXbJlbVPkn+YHj7oRHqAoDuzOU3+qo6Pcnpw9uDhvbYqrpo+PuB1trbk6S19nBV/XomgX9NVV2SyRC4p2Vy692lmQyLCwDMaF4X470oyblLph2S798L//Ukb982o7X2qap6eZJ3JHltkn2S3JHkt5K8fxdH2AMAdmIuQd9a25Rk0272uT7Jq+exfQBgeZ5Hz9wceeSRM/U/9NBDp+77wAMPzLTtU089deq+t95660zbZv15/PHHp+771FNPzbES2Lm1OGAOADAngh4AOiboAaBjgh4AOiboAaBjgh4AOiboAaBjgh4AOiboAaBjgh4AOiboAaBjgh4AOiboAaBjgh4AOuYxtczN1q1bZ+r/2GOPTd137733nmnb++6770z9WSxf+cpXpu778MMPz7ES2Dln9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEP
AB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMc+jZ27uuOOOmfrfc889U/d92cteNtO2r7rqqqn7/uzP/uxM277xxhtn6j+LX/qlX5q67/nnnz/Ttg8//PCZ+o/pqKOOmrrvAQccMNO2t2zZMlN/Fo8zegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI55TC1rxh/+4R9O3ffTn/70TNveb7/9pu57xRVXzLTtRx55ZKb+s3j2s589dd+nP/3pc6wE2FOc0QNAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxwQ9AHRM0ANAxzyPnjXj+uuvn7rvaaedNtO2/+Zv/mbqvs94xjNm2vas/Wfx4IMPTt33pptummnbV1111Uz9/+iP/mim/rAo5nJGX1VnVtUHquq6qnq4qlpVXbzCshuH+Su9LplHTQDA/M7o35nkhUkeTXJvksN3oc+Xk3xqmem3zKkmAFh48wr6t2YS8HckeXmSq3ehz5daa5vmtH0AYBlzCfrW2r8He1XNY5UAwByMeTHewVX1xiTPSvJgki+01m4esR4A6M6YQf9zw+vfVdU1Sc5trd2zKyuoqs0rzNqVawQAoHtj3Ee/Ncm7kxyd5MDhte13/ROTXFlV+41QFwB0Z9XP6Ftr9yf53SWTr62qU5J8PskxSd6Q5H27sK6jl5s+nOkfNWOpALDurZmR8VprTyT5yPD2hDFrAYBerJmgH3xraH11DwBzsNaC/qVDe+eoVQBAJ1Y96KvqmKr64WWmn5zJwDtJsuzwuQDA7pnLxXhVdXqS04e3Bw3tsVV10fD3A621tw9//3GSI4Zb6e4dph2Z5OTh73e11m6YR10AsOjmddX9i5Kcu2TaIcMrSb6eZFvQfyzJGUl+OsmpSZ6e5F+T/FWSD7bWrptTTQCw8OY1BO6mJJt2cdkLk1w4j+0CADtWrbWxa5g799Evnn333Xem/scdd9zUfU855ZSZtn3llVdO3ffJJ5+cadtbtmyZuu+sz6N/7nOfO1P/r3/96zP1H8uhhx46U/8773St8oK5aaUxY3bVWrvqHgCYI0EPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB2by/PoYWxbt26dqf/f//3fj9J3kc36iN1HH3106r4/8iM/MtO2Z7Fhw4bRts1ickYPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB3zPHpgFPfdd99M/T/60Y9O3ffNb37zTNuexfnnnz9T/7PPPntOlbAonNEDQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0TNADQMcEPQB0bK+xCwBYJIcccshM/Z/znOdM3ff++++fadusTzOf0VfVs6rqDVX1yaq6o6q+W1UPVdXnq+r1VbXsNqrquKq6vKq2VNXWqrq5qs6rqqfNWhMAMDGPM/qzknwoyX1Jrk5yT5IfTfKLST6S5NSqOqu11rZ1qKrXJLksyfeSfCLJliS/kOSCJMcP6wQAZjSPoL8tyWlJ/q619tS2iVV1fpIbk7w2k9C/bJi+IcmfJ3kyyYmttS8O09+V5KokZ1bVOa21S+ZQGwAstJm/um+tXdVa+9vtQ36Y/s0kHx7enrjdrDOTPDvJJdtCflj+e0neObz9jVnrAgD2/FX3jw/tE9tNO3loP7PM8tcm2ZrkuKrae08WBgCLYI9ddV9V
eyX51eHt9qH+gqG9bWmf1toTVXVXkiOSHJLkqzvZxuYVZh2+e9UCQJ/25Bn9e5L8ZJLLW2uf3W76/kP70Ar9tk0/YE8VBgCLYo+c0VfVW5K8LcmtSV63u92Htu1wqSSttaNX2P7mJEft5nYBoDtzP6OvqjcleV+Sf05yUmtty5JFtp2x75/lbViyHAAwpbkGfVWdl+SDSW7JJOS/ucxiXxvaw5bpv1eS52Vy8d6d86wNABbR3IK+qn47kwFvvpRJyK801uJVQ/uqZeadkGTfJDe01h6bV20AsKjmEvTDYDfvSbI5yStaaw/sYPFLkzyQ5Jyqesl269gnyR8Mbz80j7oAYNHNfDFeVZ2b5PczGenuuiRvqaqli93dWrsoSVprD1fVr2cS+NdU1SWZDIF7Wia33l2aybC4AMCM5nHV/fOG9mlJzlthmc8luWjbm9bap6rq5UnekckQufskuSPJbyV5//bj4gMA05s56Ftrm5JsmqLf9UlePev2gcV0yy23jF3CVF784hfP1H/jxo1T9/WY2sW0p4fABQBGJOgBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6Vq21sWuYu6ranOSosesA9py999576r5bt26dYyWr65d/+Zen7nvJJZfMsRJWyU2ttaNnWYEzegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI7tNXYBAOy6N77xjVP39ZjaxeSMHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA6JugBoGOCHgA65nn0wLr0xBNPTN33k5/85EzbPuOMM2bqD6vJGT0AdEzQA0DHBD0AdEzQA0DHBD0AdEzQA0DHBD0AdEzQA0DHBD0AdEzQA0DHBD0AdEzQA0DHBD0AdEzQA0DHPKYWWJeefPLJqft++ctfnmnbYz6m9vLLLx9t26xPM5/RV9WzquoNVfXJqrqjqr5bVQ9V1eer6vVV9UNLlt9YVW0Hr0tmrQkAmJjHGf1ZST6U5L4kVye5J8mPJvnFJB9JcmpVndVaa0v6fTnJp5ZZ3y1zqAkAyHyC/rYkpyX5u9baU9smVtX5SW5M8tpMQv+yJf2+1FrbNIftAwArmPmr+9baVa21v90+5Ifp30zy4eHtibNuBwDYfXv6YrzHh/aJZeYdXFVvTPKsJA8m+UJr7eY9XA8ALJQ9FvRVtVeSXx3efmaZRX5ueG3f55ok57bW7tlTdQHAItmTZ/TvSfKTSS5vrX12u+lbk7w7kwvx7hymHZlkU5KTklxZVS9qrX1nZxuoqs0rzDp82qIBoCd7ZMCcqnpLkrcluTXJ67af11q7v7X2u621m1pr3x5e1yY5Jck/JvmJJG/YE3UBwKKZ+xl9Vb0pyfuS/HOSV7TWtuxKv9baE1X1kSTHJDlhWMfO+hy9Qg2bkxy1y0UDQKfmekZfVecl+WAm98KfNFx5vzu+NbT7zbMuAFhUcwv6qvrtJBck+VImIX//FKt56dDeucOlAIBdMpegr6p3ZXLx3eZMvq5/YAfLHlNVP7zM9JOTvHV4e/E86gKARTfzb/RVdW6S30/yZJLrkrylqpYudndr7aLh7z9OcsRwK929w7Qjk5w8/P2u1toNs9YFAMznYrznDe3Tkpy3wjKfS3LR8PfHkpyR5KeTnJrk6Un+NclfJflga+26OdQEAGQOQT+MV79pN5a/MMmFs24XANi5+sGHyq1/bq8DoBM3rXQr+a7aIwPmAABrg6AHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4JegDomKAHgI4J
egDomKAHgI4JegDomKAHgI4JegDoWK9Bv3HsAgBgDjbOuoK95lDEWvTw0N69wvzDh/bWPV9KN+yz6dhv07Hfdp99Np21vN825vt5NrVqrc1eyjpTVZuTpLV29Ni1rBf22XTst+nYb7vPPpvOIuy3Xr+6BwAi6AGga4IeADom6AGgY4IeADq2kFfdA8CicEYPAB0T9ADQMUEPAB0T9ADQMUEPAB0T9ADQMUEPAB1bqKCvqudW1V9U1Teq6rGquruq3ltVB45d21o17KO2wuubY9c3lqo6s6o+UFXXVdXDw/64eCd9jquqy6tqS1Vtraqbq+q8qnraatU9tt3Zb1W1cQfHXquqS1a7/jFU1bOq6g1V9cmquqOqvltVD1XV56vq9VW17L/ji3687e5+6/l46/V59D+gqp6f5IYkz0ny15k8e/hnkvxmkldV1fGttQdHLHEteyjJe5eZ/uhqF7KGvDPJCzPZB/fm+8+0XlZVvSbJZUm+l+QTSbYk+YUkFyQ5PslZe7LYNWS39tvgy0k+tcz0W+ZY11p2VpIPJbkvydVJ7knyo0l+MclHkpxaVWe17UY/c7wlmWK/Dfo73lprC/FK8tkkLcmbl0z/n8P0D49d41p8Jbk7yd1j17HWXklOSnJokkpy4nAMXbzCshuS3J/ksSQv2W76Ppl8+GxJzhn7v2kN7reNw/yLxq575H12ciYh/UNLph+USXi1JK/dbrrjbbr91u3xthBf3VfVIUlOySS0/nTJ7P+e5DtJXldV+61yaaxTrbWrW2u3t+FfiJ04M8mzk1zSWvviduv4XiZnuEnyG3ugzDVnN/cbSVprV7XW/ra19tSS6d9M8uHh7YnbzXK8Zar91q1F+er+5KG9Ypn/6Y9U1fWZfBB4aZIrV7u4dWDvqvqVJD+WyYeim5Nc21p7ctyy1o1tx99nlpl3bZKtSY6rqr1ba4+tXlnrxsFV9cYkz0ryYJIvtNZuHrmmteLxoX1iu2mOt51bbr9t093xtihB/4KhvW2F+bdnEvSHRdAv56AkH1sy7a6q+rXW2ufGKGidWfH4a609UVV3JTkiySFJvrqaha0TPze8/l1VXZPk3NbaPaNUtAZU1V5JfnV4u32oO952YAf7bZvujreF+Oo+yf5D+9AK87dNP2AVallvPprkFZmE/X5JfirJn2Xye9anq+qF45W2bjj+prM1ybuTHJ3kwOH18kwurDoxyZUL/nPbe5L8ZJLLW2uf3W66423HVtpv3R5vixL0O1ND63fDJVprvzf81vWvrbWtrbVbWmv/JZOLGP9Tkk3jVtgFx98yWmv3t9Z+t7V2U2vt28Pr2ky+ffvHJD+R5A3jVjmOqnpLkrdlcvfQ63a3+9Au3PG2o/3W8/G2KEG/7RPs/ivM37BkOXZu28UsJ4xaxfrg+Juj1toTmdwelSzg8VdVb0ryviT/nOSk1tqWJYs43paxC/ttWT0cb4sS9F8b2sNWmH/o0K70Gz4/6P6hXZdfZa2yFY+/4ffC52VyUdCdq1nUOvetoV2o46+qzkvywUzu6T5puIJ8KcfbEru433ZkXR9vixL0Vw/tKcuMhvSMTAaQ+G6Sf1jtwtaxY4d2Yf6xmMFVQ/uqZeadkGTfJDcs8BXQ03jp0C7M8VdVv53JgDdfyiSs7l9hUcfbdnZjv+3Iuj7eFiLoW2v/kuSKTC4ge9OS2b+Xyae0v2ytfWeVS1vTquqIqnrmMtN/PJNPx0myw2FfSZJcmuSBJOdU1Uu2TayqfZL8wfD2Q2MUtpZV1TFV9cPLTD85yVuHtwtx/FXVuzK5iGxzkle01h7YweKOt8Hu7Leej7dalHErlhkC96tJjslkpK7bkhzXDIH7H1TVpiS/k8k3IncleSTJ85P8fCajbF2e5IzW2r+NVeNYqur0JKcPbw9K8spMPu1fN0x7oLX29iXLX5rJkKSXZDIk6WmZ3Ap1aZKzF2EQmd3Zb8MtTUckuSaT4XKT5Mh8/z7xd7XWtgVX
t6rq3CQXJXkyyQey/G/rd7fWLtquz8Ifb7u737o+3sYemm81X0n+cya3i92X5N+SfD2TizOeOXZta/GVya0lH8/kCtVvZzLIxLeS/J9M7kOtsWsccd9syuSq5ZVedy/T5/hMPhz9v0x+KvpKJmcKTxv7v2ct7rckr0/yvzMZ0fLRTIZ0vSeTsdtfNvZ/yxraZy3JNY632fZbz8fbwpzRA8AiWojf6AFgUQl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjgl6AOiYoAeAjv1/GTLG2l0KsBoAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 250, "width": 253 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "xb, yb = next(iter(train_dl))\n", "plt.imshow(xb[0].view(28,28))\n", "yb[0]" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.1490, grad_fn=<NllLossBackward>), tensor(0.9375))" ] }, "execution_count": 87, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model, opt = get_model()\n", "fit()\n", "loss, acc = loss_func(model(xb), yb), accuracy(model(xb), yb)\n", "assert acc>0.7\n", "loss, acc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PyTorch's DataLoader\n", "\n", "PyTorch has its own `DataLoader`, `RandomSampler` (for training), and `SequentialSampler` (for validation) classes, which we can use to create our train/valid dataloaders like so:" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [], "source": [ "#export\n", "from torch.utils.data import DataLoader, SequentialSampler, RandomSampler" ] }, { "cell_type": "code", "execution_count": 89, "metadata": {}, "outputs": [], "source": [ "train_dl = DataLoader(train_ds, bs, sampler=RandomSampler(train_ds), collate_fn=collate)\n", "valid_dl = DataLoader(valid_ds, bs, sampler=SequentialSampler(valid_ds), collate_fn=collate)" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.1246, grad_fn=<NllLossBackward>), tensor(0.9531))" ] }, "execution_count": 90, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model,opt = get_model()\n", "fit()\n", "loss_func(model(xb), yb), accuracy(model(xb), yb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch's defaults (its built-in collate function, and `shuffle=True`/`shuffle=False` in place of an explicit sampler) work fine in most cases. Note that if we pass `num_workers` to PyTorch's `DataLoader`, it will spawn that many worker *processes* (not threads), which call the `Dataset` in parallel to load batches." 
] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [], "source": [ "train_dl = DataLoader(train_ds, bs, shuffle=True, drop_last=True)\n", "valid_dl = DataLoader(valid_ds, bs, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.1928, grad_fn=), tensor(0.9531))" ] }, "execution_count": 92, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model, opt = get_model()\n", "fit()\n", "\n", "loss, acc = loss_func(model(xb), yb), accuracy(model(xb), yb)\n", "assert acc>0.7\n", "loss, acc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setting aside a Validation Set\n", "\n", "We should **always** have a [validation set](http://www.fast.ai/2017/11/13/validation-sets/) so that we can tell whether, at some point during training, our model begins to overfit.\n", "\n", "We'll write a training loop once more below, and print out the validation loss and accuracy at the end of each epoch.\n", "\n", "Note that with PyTorch, you should be sure to call `model.train()` *before* training and `model.eval()` *before* inference, because layers such as `nn.BatchNorm2d` and `nn.Dropout` behave differently during training than during inference!"
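, "\n", "\n", "A minimal sketch of why the mode switch matters, on made-up inputs and using `nn.BatchNorm1d` instead of `nn.BatchNorm2d` purely for simplicity: in training mode, dropout zeroes random activations and rescales the survivors by 1/(1-p), and batchnorm updates its running statistics. In eval mode, dropout passes its input through unchanged, and batchnorm normalizes with the stored running statistics without updating them.

```python
import torch
from torch import nn

torch.manual_seed(0)  # make the random dropout mask reproducible

# Dropout with p=0.5: in training mode, survivors are scaled by 1/(1-p) = 2.
drop = nn.Dropout(p=0.5)
x = torch.ones(8)
drop.train()
out_train = drop(x)   # each entry is either 0.0 or 2.0
drop.eval()
out_eval = drop(x)    # eval mode: dropout returns its input unchanged

# BatchNorm: running statistics are updated only in training mode.
bn = nn.BatchNorm1d(3)
xb = torch.randn(16, 3) * 5 + 10   # made-up batch whose features have mean near 10
bn.train()
bn(xb)                             # uses batch statistics and updates running_mean
mean_after_train = bn.running_mean.clone()
bn.eval()
bn(xb)                             # uses running statistics, no update
mean_after_eval = bn.running_mean.clone()

print(out_train)
print(out_eval)
print(mean_after_train, mean_after_eval)
```

Forgetting `model.eval()` before validation would therefore both add dropout noise to the predictions and keep shifting batchnorm's running statistics toward the validation data."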
] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [], "source": [ "def fit(epochs, model, loss_func, opt, train_dl, valid_dl):\n", "    for epoch in range(epochs):\n", "        model.train() # Handle proper execution of bn and dropout at training.\n", "        for xb, yb in train_dl:\n", "            loss = loss_func(model(xb), yb)\n", "            loss.backward()\n", "            opt.step()\n", "            opt.zero_grad()\n", "\n", "        model.eval() # Handle proper execution of bn and dropout at inference.\n", "        with torch.no_grad():\n", "            tot_loss, tot_acc = 0., 0.\n", "            for xb,yb in valid_dl:\n", "                pred = model(xb)\n", "                tot_loss += loss_func(pred, yb)\n", "                tot_acc += accuracy(pred, yb)\n", "        nv = len(valid_dl)\n", "        print(epoch, tot_loss/nv, tot_acc/nv)\n", "    return tot_loss/nv, tot_acc/nv" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A question to think about: Will the validation metrics printed out here still be correct if the batch size varies?\n", "\n", "The answer is no. Because of the way the validation loss/accuracy are calculated, the metrics will be incorrect if the batch size varies. `tot_loss` and `tot_acc` accumulate one per-batch average for each batch. After all batches, they are divided by the number of batches to get the average validation loss/accuracy for the entire epoch.\n", "```\n", "for xb,yb in valid_dl:\n", "    pred = model(xb)\n", "    tot_loss += loss_func(pred, yb)\n", "    tot_acc += accuracy(pred, yb)\n", "    .\n", "    .\n", "    .\n", "return tot_loss/nv, tot_acc/nv\n", "```\n", "If the batch size varies, say because the final batch is smaller than all the others, the loss/acc of the final batch will be *over*-weighted in the epoch's loss/acc metrics.\n", "\n", "Why? Say all batches have size 64 except for the final batch, which has 16. Each batch up to the penultimate one has an average loss/acc (for that batch) that's divided by 64. The final batch's average loss/acc is only divided by 16. 
\n", "\n", "Now, by averaging the per-batch average loss/acc over the total number of batches, our epoch loss/acc implicitly assumes that every batch's average was computed with the same denominator (batch size). In other words, dividing `tot_loss`/`tot_acc` by the number of batches gives each pred/label pair an equal say in the overall epoch average only when all batches are the same size. We know that this isn't true for the final batch: its denominator (batch size) is only 16, and because this isn't compensated for, each pred/label pair in the final batch sways the overall average epoch loss/acc four times as much as a pair in any other batch.\n", "\n", "Here's a simple example to illustrate what's going on.\n", "\n", "Say we have three batches: the first two contain 10 samples each and the last contains 5, with the following accuracies:\n", "$$\\frac{5}{10}, \\frac{5}{10}, \\frac{2}{5}$$ We would expect the combined accuracy over all samples to be: $$\\frac{5+5+2}{10+10+5} = \\frac{12}{25}$$ However, if we combine them the way we calculated `tot_acc`, we get an average accuracy of $$\\frac{\\frac{5}{10} + \\frac{5}{10} + \\frac{2}{5}}{3} = \\frac{\\frac{14}{10}}{3} = \\frac{14}{30}$$ Immediately we notice that $\\frac{14}{30} \\approx 0.467$ is *less than* $\\frac{12}{25} = 0.48$. \n", "\n", "In other words, the 3rd batch's lower accuracy exerts *too large* an effect on the overall average accuracy, which comes out slightly lower than it ought to be. 
The batch's accuracy of $\\frac{2}{5}$ should carry only $\\frac{5}{25}$ of the weight in the epoch's average accuracy (since the batch holds only 5 of the 25 total samples), yet our misguided calculation gives it $\\frac{1}{3}$ of the weight.\n", "\n", "The proper way to calculate the epoch's average accuracy, taking into account the 3rd batch's smaller size relative to the first two batches, is a weighted average in which each batch's accuracy is weighted by its share of the samples: $$\\frac{5}{10}\\cdot\\frac{10}{25} + \\frac{5}{10}\\cdot\\frac{10}{25} + \\frac{2}{5}\\cdot\\frac{5}{25} = \\frac{5}{25} + \\frac{5}{25} + \\frac{2}{25} = \\frac{12}{25}$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`get_dls` will return dataloaders for the training and validation sets:" ] }, { "cell_type": "code", "execution_count": 94, "metadata": {}, "outputs": [], "source": [ "#export \n", "def get_dls(train_ds, valid_ds, bs, **kwargs):\n", "    return (DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs),\n", "            DataLoader(valid_ds, batch_size=bs*2, **kwargs))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Because the validation set has `10,000` items and we're using a batch size of 128 for running inference on it, if we don't explicitly set `drop_last=True` in our validation data loader, our validation loss/accuracy will be slightly skewed.\n", "\n", "Since $10{,}000 = 78 \\cdot 128 + 16$, the last batch will have a size of only `16`. As explained above, the last batch's loss/acc metrics will thus sway the overall totals more than they should. For the instructional purposes of this notebook, we won't worry about accounting for this in our validation loss/acc calculations at this point." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With `get_dls()`, we can now create our dataloaders and fit the model in only *three lines of code*."
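, "\n", "\n", "Before running those three lines, here is a minimal, torch-free sketch of the batch-size weighting discussed above. It reproduces the three-batch example with Python's `fractions` module (the counts come from that worked example, not from MNIST) and shows that weighting each batch by its share of the samples is the same as pooling all correct predictions and dividing by the total number of samples.

```python
from fractions import Fraction

# (correct predictions, batch size) for the three batches in the worked example
batches = [(5, 10), (5, 10), (2, 5)]
per_batch_acc = [Fraction(correct, size) for correct, size in batches]

# What fit() does: sum the per-batch averages, divide by the number of batches.
unweighted = sum(per_batch_acc) / len(batches)

# Weighted average: each batch counts in proportion to its number of samples.
n_total = sum(size for _, size in batches)
weighted = sum(acc * Fraction(size, n_total)
               for acc, (_, size) in zip(per_batch_acc, batches))

# Equivalent: total correct predictions over total samples.
pooled = Fraction(sum(correct for correct, _ in batches), n_total)

print(unweighted, weighted, pooled)  # 7/15 12/25 12/25
```

Note that 7/15 is just the reduced form of 14/30. Applied to `fit`, the same idea would mean accumulating `loss_func(pred, yb) * len(yb)` and `accuracy(pred, yb) * len(yb)` and dividing by the total number of validation samples instead of by `nv`, assuming both functions return per-batch means."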
] }, { "cell_type": "code", "execution_count": 95, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 tensor(0.2808) tensor(0.9006)\n", "1 tensor(0.1229) tensor(0.9634)\n", "2 tensor(0.1326) tensor(0.9616)\n", "3 tensor(0.1114) tensor(0.9670)\n", "4 tensor(0.1046) tensor(0.9691)\n" ] } ], "source": [ "train_dl, valid_dl = get_dls(train_ds, valid_ds, bs)\n", "model, opt = get_model()\n", "loss, acc = fit(5, model, loss_func, opt, train_dl, valid_dl)" ] }, { "cell_type": "code", "execution_count": 96, "metadata": {}, "outputs": [], "source": [ "assert acc>0.9" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Export" ] }, { "cell_type": "code", "execution_count": 97, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 03_minibatch_training_my_reimplementation.ipynb to nb_03.py\r\n" ] } ], "source": [ "!python notebook2script_my_reimplementation.py 03_minibatch_training_my_reimplementation.ipynb" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }