{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| default_exp training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt\n", "from pathlib import Path\n", "from torch import tensor,nn\n", "import torch.nn.functional as F" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastcore.test import test_close\n", "\n", "torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)\n", "torch.manual_seed(1)\n", "mpl.rcParams['image.cmap'] = 'gray'\n", "\n", "path_data = Path('data')\n", "path_gz = path_data/'mnist.pkl.gz'\n", "with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')\n", "x_train, y_train, x_valid, y_valid = map(tensor, [x_train, y_train, x_valid, y_valid])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initial setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n,m = x_train.shape\n", "c = y_train.max()+1\n", "nh = 50" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Model(nn.Module):\n", "    def __init__(self, n_in, nh, n_out):\n", "        super().__init__()\n", "        self.layers = [nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out)]\n", "\n", "    def __call__(self, x):\n", "        for l in self.layers: x = l(x)\n", "        return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([50000, 10])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = Model(m, nh, 10)\n", "pred = model(x_train)\n", "pred.shape" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "### Cross entropy loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we will need to compute the softmax of our activations. This is defined by:\n", "\n", "$$\\hbox{softmax(x)}_{i} = \\frac{e^{x_{i}}}{e^{x_{0}} + e^{x_{1}} + \\cdots + e^{x_{n-1}}}$$\n", "\n", "or more concisely:\n", "\n", "$$\\hbox{softmax(x)}_{i} = \\frac{e^{x_{i}}}{\\sum\\limits_{0 \\leq j \\lt n} e^{x_{j}}}$$ \n", "\n", "In practice, we will need the log of the softmax when we calculate the loss." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def log_softmax(x): return (x.exp()/(x.exp().sum(-1,keepdim=True))).log()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-2.37, -2.49, -2.36, ..., -2.31, -2.28, -2.22],\n", " [-2.37, -2.44, -2.44, ..., -2.27, -2.26, -2.16],\n", " [-2.48, -2.33, -2.28, ..., -2.30, -2.30, -2.27],\n", " ...,\n", " [-2.33, -2.52, -2.34, ..., -2.31, -2.21, -2.16],\n", " [-2.38, -2.38, -2.33, ..., -2.29, -2.26, -2.17],\n", " [-2.33, -2.55, -2.36, ..., -2.29, -2.27, -2.16]], grad_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_softmax(pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the formula \n", "\n", "$$\\log \\left ( \\frac{a}{b} \\right ) = \\log(a) - \\log(b)$$ \n", "\n", "gives a simplification when we compute the log softmax:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def log_softmax(x): return x - x.exp().sum(-1,keepdim=True).log()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, there is a way to compute the log of the sum of exponentials in a more stable way, called the [LogSumExp trick](https://en.wikipedia.org/wiki/LogSumExp). 
The idea is to use the following formula:\n", "\n", "$$\\log \\left ( \\sum_{j=0}^{n-1} e^{x_{j}} \\right ) = \\log \\left ( e^{a} \\sum_{j=0}^{n-1} e^{x_{j}-a} \\right ) = a + \\log \\left ( \\sum_{j=0}^{n-1} e^{x_{j}-a} \\right )$$\n", "\n", "where $a$ is the maximum of the $x_{j}$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def logsumexp(x):\n", "    m = x.max(-1)[0]\n", "    return m + (x-m[:,None]).exp().sum(-1).log()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This avoids overflow when exponentiating a large activation, since every exponent $x_{j}-a$ is at most 0. PyTorch already implements this for us." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def log_softmax(x): return x - x.logsumexp(-1,keepdim=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-2.37, -2.49, -2.36, ..., -2.31, -2.28, -2.22],\n", " [-2.37, -2.44, -2.44, ..., -2.27, -2.26, -2.16],\n", " [-2.48, -2.33, -2.28, ..., -2.30, -2.30, -2.27],\n", " ...,\n", " [-2.33, -2.52, -2.34, ..., -2.31, -2.21, -2.16],\n", " [-2.38, -2.38, -2.33, ..., -2.29, -2.26, -2.17],\n", " [-2.33, -2.55, -2.36, ..., -2.29, -2.27, -2.16]], grad_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_close(logsumexp(pred), pred.logsumexp(-1))\n", "sm_pred = log_softmax(pred)\n", "sm_pred" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cross entropy loss for some target $x$ and some prediction $p(x)$ is given by:\n", "\n", "$$ -\\sum x\\, \\log p(x) $$\n", "\n", "But since our $x$s are one-hot encoded (actually, they're just the integer indices), this can be rewritten as $-\\log(p_{i})$ where $i$ is the index of the desired target.\n", "\n", "This can be done using numpy-style [integer array 
indexing](https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#integer-array-indexing). Note that PyTorch supports all the tricks in the advanced indexing methods discussed in that link." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([5, 0, 4])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train[:3]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(-2.20, grad_fn=),\n", " tensor(-2.37, grad_fn=),\n", " tensor(-2.36, grad_fn=))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sm_pred[0,5],sm_pred[1,0],sm_pred[2,4]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-2.20, -2.37, -2.36], grad_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sm_pred[[0,1,2], y_train[:3]]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def nll(input, target): return -input[range(target.shape[0]), target].mean()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(2.30, grad_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss = nll(sm_pred, y_train)\n", "loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then use PyTorch's implementation." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_close(F.nll_loss(F.log_softmax(pred, -1), y_train), loss, 1e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In PyTorch, `F.log_softmax` and `F.nll_loss` are combined in one optimized function, `F.cross_entropy`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_close(F.cross_entropy(pred, y_train), loss, 1e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic training loop" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The training loop repeats the following steps:\n", "- get the output of the model on a batch of inputs\n", "- compare the output to the labels we have and compute a loss\n", "- calculate the gradients of the loss with respect to every parameter of the model\n", "- update those parameters using the gradients, scaled by the learning rate, to reduce the loss a little" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "loss_func = F.cross_entropy" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([-0.09, -0.21, -0.08, 0.10, -0.04, 0.08, -0.04, -0.03, 0.01, 0.06], grad_fn=),\n", " torch.Size([50, 10]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bs=50 # batch size\n", "\n", "xb = x_train[0:bs] # a mini-batch from x\n", "preds = model(xb) # predictions\n", "preds[0], preds.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1, 1, 2, 4, 3, 2, 7, 3, 8, 6, 9, 0, 5, 6, 0, 7, 6, 1, 8, 7, 9,\n", " 3, 9, 8, 5, 9, 3])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "yb = y_train[0:bs]\n", "yb" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(2.30, grad_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss_func(preds, yb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { 
"text/plain": [ "tensor([3, 9, 3, 8, 5, 9, 3, 9, 3, 9, 5, 3, 9, 9, 3, 9, 9, 5, 8, 7, 9, 5, 3, 8, 9, 5, 9, 5, 5, 9, 3, 5, 9, 7, 5, 7, 9, 9, 3, 9, 3, 5, 3, 8,\n", " 3, 5, 9, 5, 9, 5])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds.argmax(dim=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def accuracy(out, yb): return (out.argmax(dim=1)==yb).float().mean()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.08)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy(preds, yb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr = 0.5 # learning rate\n", "epochs = 3 # how many epochs to train for" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def report(loss, preds, yb): print(f'{loss:.2f}, {accuracy(preds, yb):.2f}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.30, 0.08\n" ] } ], "source": [ "xb,yb = x_train[:bs],y_train[:bs]\n", "preds = model(xb)\n", "report(loss_func(preds, yb), preds, yb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.12, 0.98\n", "0.12, 0.94\n", "0.08, 0.96\n" ] } ], "source": [ "for epoch in range(epochs):\n", "    for i in range(0, n, bs):\n", "        s = slice(i, min(n,i+bs))\n", "        xb,yb = x_train[s],y_train[s]\n", "        preds = model(xb)\n", "        loss = loss_func(preds, yb)\n", "        loss.backward()\n", "        with torch.no_grad():\n", "            for l in model.layers:\n", "                if hasattr(l, 'weight'):\n", "                    l.weight -= l.weight.grad * lr\n", "                    l.bias -= l.bias.grad * lr\n", "                    l.weight.grad.zero_()\n", "                    l.bias.grad.zero_()\n", 
"    report(loss, preds, yb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using parameters and optim" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Module(\n", " (foo): Linear(in_features=3, out_features=4, bias=True)\n", ")" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m1 = nn.Module()\n", "m1.foo = nn.Linear(3,4)\n", "m1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('foo', Linear(in_features=3, out_features=4, bias=True))]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(m1.named_children())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m1.named_children()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Parameter containing:\n", " tensor([[ 0.57, 0.43, -0.30],\n", " [ 0.13, -0.32, -0.24],\n", " [ 0.51, 0.04, 0.22],\n", " [ 0.13, -0.17, -0.24]], requires_grad=True),\n", " Parameter containing:\n", " tensor([-0.01, -0.51, -0.39, 0.56], requires_grad=True)]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(m1.parameters())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class MLP(nn.Module):\n", "    def __init__(self, n_in, nh, n_out):\n", "        super().__init__()\n", "        self.l1 = nn.Linear(n_in,nh)\n", "        self.l2 = nn.Linear(nh,n_out)\n", "        self.relu = nn.ReLU()\n", "\n", "    def forward(self, x): return self.l2(self.relu(self.l1(x)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": 
{ "text/plain": [ "Linear(in_features=784, out_features=50, bias=True)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = MLP(m, nh, 10)\n", "model.l1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MLP(\n", " (l1): Linear(in_features=784, out_features=50, bias=True)\n", " (l2): Linear(in_features=50, out_features=10, bias=True)\n", " (relu): ReLU()\n", ")" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "l1: Linear(in_features=784, out_features=50, bias=True)\n", "l2: Linear(in_features=50, out_features=10, bias=True)\n", "relu: ReLU()\n" ] } ], "source": [ "for name,l in model.named_children(): print(f\"{name}: {l}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([50, 784])\n", "torch.Size([50])\n", "torch.Size([10, 50])\n", "torch.Size([10])\n" ] } ], "source": [ "for p in model.parameters(): print(p.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def fit():\n", "    for epoch in range(epochs):\n", "        for i in range(0, n, bs):\n", "            s = slice(i, min(n,i+bs))\n", "            xb,yb = x_train[s],y_train[s]\n", "            preds = model(xb)\n", "            loss = loss_func(preds, yb)\n", "            loss.backward()\n", "            with torch.no_grad():\n", "                for p in model.parameters(): p -= p.grad * lr\n", "                model.zero_grad()\n", "        report(loss, preds, yb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.19, 0.96\n", "0.11, 0.96\n", "0.04, 1.00\n" ] } ], "source": [ "fit()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"Behind the scenes, PyTorch overrides the `__setattr__` function in `nn.Module` so that the submodules you assign as attributes are registered as children of the model, and their parameters are then found by `parameters()`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class MyModule:\n", "    def __init__(self, n_in, nh, n_out):\n", "        self._modules = {}\n", "        self.l1 = nn.Linear(n_in,nh)\n", "        self.l2 = nn.Linear(nh,n_out)\n", "\n", "    def __setattr__(self,k,v):\n", "        if not k.startswith(\"_\"): self._modules[k] = v\n", "        super().__setattr__(k,v)\n", "\n", "    def __repr__(self): return f'{self._modules}'\n", "\n", "    def parameters(self):\n", "        for l in self._modules.values(): yield from l.parameters()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'l1': Linear(in_features=784, out_features=50, bias=True), 'l2': Linear(in_features=50, out_features=10, bias=True)}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mdl = MyModule(m,nh,10)\n", "mdl" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([50, 784])\n", "torch.Size([50])\n", "torch.Size([10, 50])\n", "torch.Size([10])\n" ] } ], "source": [ "for p in mdl.parameters(): print(p.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Registering modules" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from functools import reduce" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use the original `layers` approach, but we have to register the modules." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "layers = [nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Model(nn.Module):\n", "    def __init__(self, layers):\n", "        super().__init__()\n", "        self.layers = layers\n", "        for i,l in enumerate(self.layers): self.add_module(f'layer_{i}', l)\n", "\n", "    def forward(self, x): return reduce(lambda val,layer: layer(val), self.layers, x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Model(\n", " (layer_0): Linear(in_features=784, out_features=50, bias=True)\n", " (layer_1): ReLU()\n", " (layer_2): Linear(in_features=50, out_features=10, bias=True)\n", ")" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = Model(layers)\n", "model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([50, 10])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model(xb).shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### nn.ModuleList" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`nn.ModuleList` does this for us." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class SequentialModel(nn.Module):\n", "    def __init__(self, layers):\n", "        super().__init__()\n", "        self.layers = nn.ModuleList(layers)\n", "\n", "    def forward(self, x):\n", "        for l in self.layers: x = l(x)\n", "        return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SequentialModel(\n", " (layers): ModuleList(\n", " (0): Linear(in_features=784, out_features=50, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=50, out_features=10, bias=True)\n", " )\n", ")" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = SequentialModel(layers)\n", "model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.12, 0.96\n", "0.11, 0.96\n", "0.07, 0.98\n" ] } ], "source": [ "fit()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### nn.Sequential" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`nn.Sequential` is a convenient class which does the same as the above:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.16, 0.94\n", "0.13, 0.96\n", "0.08, 0.96\n" ] }, { "data": { "text/plain": [ "(tensor(0.03, grad_fn=), tensor(1.))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fit()\n", "loss_func(model(xb), yb), accuracy(model(xb), yb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Sequential(\n", " (0): Linear(in_features=784, out_features=50, bias=True)\n", " (1): ReLU()\n", " (2): 
Linear(in_features=50, out_features=10, bias=True)\n", ")" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### optim" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Optimizer():\n", "    def __init__(self, params, lr=0.5): self.params,self.lr=list(params),lr\n", "\n", "    def step(self):\n", "        with torch.no_grad():\n", "            for p in self.params: p -= p.grad * self.lr\n", "\n", "    def zero_grad(self):\n", "        for p in self.params: p.grad.data.zero_()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "opt = Optimizer(model.parameters())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.18, 0.94\n", "0.13, 0.96\n", "0.11, 0.94\n" ] } ], "source": [ "for epoch in range(epochs):\n", "    for i in range(0, n, bs):\n", "        s = slice(i, min(n,i+bs))\n", "        xb,yb = x_train[s],y_train[s]\n", "        preds = model(xb)\n", "        loss = loss_func(preds, yb)\n", "        loss.backward()\n", "        opt.step()\n", "        opt.zero_grad()\n", "    report(loss, preds, yb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch already provides this functionality in `optim.SGD`, which also handles things like momentum (we'll look at that later)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch import optim" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_model():\n", "    model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))\n", "    return model, optim.SGD(model.parameters(), lr=lr)" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(2.33, grad_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model,opt = get_model()\n", "loss_func(model(xb), yb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.12, 0.98\n", "0.09, 0.98\n", "0.07, 0.98\n" ] } ], "source": [ "for epoch in range(epochs):\n", "    for i in range(0, n, bs):\n", "        s = slice(i, min(n,i+bs))\n", "        xb,yb = x_train[s],y_train[s]\n", "        preds = model(xb)\n", "        loss = loss_func(preds, yb)\n", "        loss.backward()\n", "        opt.step()\n", "        opt.zero_grad()\n", "    report(loss, preds, yb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset and DataLoader" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's clunky to iterate through minibatches of x and y values separately:\n", "\n", "```python\n", "xb = x_train[s]\n", "yb = y_train[s]\n", "```\n", "\n", "Instead, let's do these two steps together, by introducing a `Dataset` class:\n", "\n", "```python\n", "xb,yb = train_ds[s]\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class Dataset():\n", "    def __init__(self, x, y): self.x,self.y = x,y\n", "    def __len__(self): return len(self.x)\n", "    def __getitem__(self, i): return self.x[i],self.y[i]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)\n", "assert len(train_ds)==len(x_train)\n", "assert len(valid_ds)==len(x_valid)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., 
..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]]),\n", " tensor([5, 0, 4, 1, 9]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xb,yb = train_ds[0:5]\n", "assert xb.shape==(5,28*28)\n", "assert yb.shape==(5,)\n", "xb,yb" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model,opt = get_model()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.17, 0.96\n", "0.11, 0.94\n", "0.09, 0.96\n" ] } ], "source": [ "for epoch in range(epochs):\n", "    for i in range(0, n, bs):\n", "        xb,yb = train_ds[i:min(n,i+bs)]\n", "        preds = model(xb)\n", "        loss = loss_func(preds, yb)\n", "        loss.backward()\n", "        opt.step()\n", "        opt.zero_grad()\n", "    report(loss, preds, yb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### DataLoader" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Previously, our loop iterated over batches (xb, yb) like this:\n", "\n", "```python\n", "for i in range(0, n, bs):\n", "    xb,yb = train_ds[i:min(n,i+bs)]\n", "    ...\n", "```\n", "\n", "Let's make our loop much cleaner, using a data loader:\n", "\n", "```python\n", "for xb,yb in train_dl:\n", "    ...\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class DataLoader():\n", "    def __init__(self, ds, bs): self.ds,self.bs = ds,bs\n", "    def __iter__(self):\n", "        for i in range(0, len(self.ds), self.bs): yield self.ds[i:i+self.bs]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dl = DataLoader(train_ds, bs)\n", "valid_dl = DataLoader(valid_ds, bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([50, 784])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } 
], "source": [ "xb,yb = next(iter(valid_dl))\n", "xb.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([3, 8, 6, 9, 6, 4, 5, 3, 8, 4, 5, 2, 3, 8, 4, 8, 1, 5, 0, 5, 9, 7, 4, 1, 0, 3, 0, 6, 2, 9, 9, 4, 1, 3, 6, 8, 0, 7, 7, 6, 8, 9, 0, 3,\n", " 8, 3, 7, 7, 8, 4])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "yb" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(3)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZ1UlEQVR4nO3df2zUdx3H8dcB5ShwPVehvbsBtWEQzcBmAwQafkcamozwYyZsi6YYQzb5kWDBZYCGOhNKSIYk1rG4KINsKHFjSAIOqtDCRAwjXYY4sZNiq1AbKt6VXyWMj38QLru1FL7HHe9e+3wkn4T7fr9vvm++fMOLT+97n/M555wAADDQx7oBAEDvRQgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADATD/rBj7v1q1bOn/+vAKBgHw+n3U7AACPnHNqa2tTJBJRnz5dz3W6XQidP39ew4cPt24DAPCAmpqaNGzYsC6P6XY/jgsEAtYtAABS4H7+PU9bCL366qsqLCzUgAEDNG7cOB09evS+6vgRHAD0DPfz73laQmjXrl1auXKl1q1bp7q6Ok2dOlWlpaVqbGxMx+kAABnKl45VtCdOnKgnn3xSW7dujW/7yle+ovnz56uysrLL2lgspmAwmOqWAAAPWTQaVU5OTpfHpHwmdOPGDZ08eVIlJSUJ20tKSnTs2LEOx7e3tysWiyUMAEDvkPIQunjxoj799FPl5+cnbM/Pz1dzc3OH4ysrKxUMBuODJ+MAoPdI24MJn39DyjnX6ZtUa9asUTQajY+mpqZ0tQQA6GZS/jmhIUOGqG/fvh1mPS0tLR1mR5Lk9/vl9/tT3QYAIAOkfCbUv39/jRs3TtXV1Qnbq6urVVxcnOrTAQAyWFpWTCgvL9e3vvUtjR8/XpMnT9bPf/5zNTY26oUXXkjH6QAAGSotIbRo0SK1trbq5Zdf1oULFzRmzBjt379fBQUF6TgdACBDpeVzQg+CzwkBQM9g8jkhAADuFyEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzPSzbgC4l6KiIs813/ve95I618iRIz3XDBw40HPN2rVrPdcEg0HPNb/73e8810hSW1tbUnWAV8yEAABmCCEAgBlCCA
BghhACAJghhAAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmPE555x1E58Vi8WSWqgRmWHw4MGeaxobGz3XfOELX/Bc0xP9+9//TqoumQVg33777aTOhZ4rGo0qJyeny2OYCQEAzBBCAAAzKQ+hiooK+Xy+hBEKhVJ9GgBAD5CWL7V7/PHH9fvf/z7+um/fvuk4DQAgw6UlhPr168fsBwBwT2l5T6i+vl6RSESFhYV65plndPbs2bse297erlgsljAAAL1DykNo4sSJ2rFjhw4cOKDXX39dzc3NKi4uVmtra6fHV1ZWKhgMxsfw4cNT3RIAoJtKeQiVlpbq6aef1tixY/X1r39d+/btkyRt37690+PXrFmjaDQaH01NTaluCQDQTaXlPaHPGjRokMaOHav6+vpO9/v9fvn9/nS3AQDohtL+OaH29nZ9/PHHCofD6T4VACDDpDyEVq9erdraWjU0NOjPf/6zvvGNbygWi6msrCzVpwIAZLiU/zjuX//6l5599lldvHhRQ4cO1aRJk3T8+HEVFBSk+lQAgAzHAqZ4qAKBgOea/fv3e66529OY91JXV+e55oknnvBck8x/ypJ5cjQ7O9tzjST95z//8VwzefLkh3IeZA4WMAUAdGuEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMpP1L7YDPamtr81wzderUNHSSeYYMGeK55vvf/35S50qmbs6cOZ5r7vaNy+g9mAkBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMywijaQIS5evOi55o9//GNS50pmFe0nnnjCcw2raIOZEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADMsYApkiEceecRzzdq1a9PQSecikchDOxd6DmZCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzLCAKWCgqKjIc81vfvMbzzWPPfaY5xpJ+vvf/+65ZtWqVUmdC70bMyEAgBlCCABgxnMIHTlyRHPnzlUkEpHP59OePXsS9jvnVFFRoUgkouzsbM2YMUOnT59OVb8AgB7EcwhduXJFRUVFqqqq6nT/pk2btHnzZlVVVenEiRMKhUKaPXu22traHrhZAEDP4vnBhNLSUpWWlna6zzmnLVu2aN26dVq4cKEkafv27crPz9fOnTv1/PPPP1i3AIAeJaXvCTU0NKi5uVklJSXxbX6/X9OnT9exY8c6rWlvb1csFksYAIDeIaUh1NzcLEnKz89P2J6fnx/f93mVlZUKBoPxMXz48FS2BADoxtLydJzP50t47ZzrsO2ONWvWKBqNxkdTU1M6WgIAdEMp/bBqKBSSdHtGFA6H49tbWlo6zI7u8Pv98vv9qWwDAJAhUjoTKiwsVCgUUnV1dXzbjRs3VFtbq+Li4lSeCgDQA3ieCV2+fFmffPJJ/HVDQ4M+/PBD5ebmasSIEVq5cqU2bNigUaNGadSoUdqwYYMGDhyo5557LqWNAwAyn+cQ+uCDDzRz5sz46/LycklSWVmZ3njjDb344ou6du2ali5dqkuXLmnixIk6ePCgAoFA6roGAPQIPuecs27is2KxmILBoHUbwH0rKyvzXPPyyy97rknmydFr1655rpGkp556ynPN4cOHkzoXeq5oNKqcnJwuj2HtOACAGUIIAGCGEAIAmCGEAABmCCEAgBlCCABghhACAJghhAAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAmZR+syrQXQwePDiputWrV3uu+cEPfuC5pk8f7///++9//+u5ZsqUKZ5rJOlvf/tbUnWAV8yEAABmCCEAgBlCCABghhACAJghhAAAZgghAIAZQggAYI
YQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmGEBU/RIb7zxRlJ1CxcuTG0jd/H22297rtmyZYvnGhYiRXfHTAgAYIYQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmCGEAABmCCEAgBlCCABghhACAJghhAAAZgghAIAZFjBFjzRy5EjrFrq0detWzzXHjh1LQyeALWZCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzLCAKXqkgwcPJlVXVFSU4k46l0x/ySx6unHjRs81knT+/Pmk6gCvmAkBAMwQQgAAM55D6MiRI5o7d64ikYh8Pp/27NmTsH/x4sXy+XwJY9KkSanqFwDQg3gOoStXrqioqEhVVVV3PWbOnDm6cOFCfOzfv/+BmgQA9EyeH0woLS1VaWlpl8f4/X6FQqGkmwIA9A5peU+opqZGeXl5Gj16tJYsWaKWlpa7Htve3q5YLJYwAAC9Q8pDqLS0VG+99ZYOHTqkV155RSdOnNCsWbPU3t7e6fGVlZUKBoPxMXz48FS3BADoplL+OaFFixbFfz1mzBiNHz9eBQUF2rdvnxYuXNjh+DVr1qi8vDz+OhaLEUQA0Euk/cOq4XBYBQUFqq+v73S/3++X3+9PdxsAgG4o7Z8Tam1tVVNTk8LhcLpPBQDIMJ5nQpcvX9Ynn3wSf93Q0KAPP/xQubm5ys3NVUVFhZ5++mmFw2GdO3dOa9eu1ZAhQ7RgwYKUNg4AyHyeQ+iDDz7QzJkz46/vvJ9TVlamrVu36tSpU9qxY4f+97//KRwOa+bMmdq1a5cCgUDqugYA9Ag+55yzbuKzYrGYgsGgdRvIcNnZ2UnVvfnmm55rxo0b57lmxIgRnmuS0dzcnFTdt7/9bc81Bw4cSOpc6Lmi0ahycnK6PIa14wAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmCGEAABmCCEAgBlCCABghhACAJghhAAAZlhFG/iMAQMGeK7p18/7FxTHYjHPNQ/T9evXPdfc+VoXL1577TXPNcgcrKINAOjWCCEAgBlCCABghhACAJghhAAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmGEBU8DAV7/6Vc81P/nJTzzXzJw503NNshobGz3XfOlLX0p9I+g2WMAUANCtEUIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMMMCpnioBg4c6Lnm6tWraegk8zzyyCOea375y18mda558+YlVefVo48+6rnmwoULaegE6cACpgCAbo0QAgCYIYQAAGYIIQCAGUIIAGCGEAIAmCGEAABmCCEAgBlCCABghhACAJghhAAAZgghAICZftYNIHONHDnSc83777/vuWbfvn2ea/7yl794rpGSWxzzO9/5juearKwszzXJLPb52GOPea5J1j/+8Q/PNSxGCmZCAAAzhBAAwIynEKqsrNSECRMUCASUl5en+fPn68yZMwnHOOdUUVGhSCSi7OxszZgxQ6dPn05p0wCAnsFTCNXW1mrZsmU6fvy4qqurdfPmTZWUlOjKlSvxYzZt2qTNmzerqqpKJ06cUCgU0uzZs9XW1pby5gEAmc3Tgwnvvfdewutt27YpLy9PJ0+e1LRp0+Sc05YtW7Ru3TotXLhQkrR9+3bl5+dr586dev7551PXOQAg4z3Qe0LRaFSSlJubK0lqaGhQc3OzSkpK4sf4/X5Nnz5dx44d6/T3aG9vVywWSxgAgN4h6RByzqm8vFxTpkzRmDFjJEnNzc2SpPz8/IRj8/Pz4/s+r7KyUsFgMD6GDx+ebEsAgAyTdAgtX75cH330kX71q1912Ofz+RJeO+c6bLtjzZo1ikaj8dHU1JRsSwCADJPUh1VXrFihvXv36siRIxo2bFh8eygUknR7RhQOh+PbW1paOsyO7vD7/fL7/c
m0AQDIcJ5mQs45LV++XLt379ahQ4dUWFiYsL+wsFChUEjV1dXxbTdu3FBtba2Ki4tT0zEAoMfwNBNatmyZdu7cqd/+9rcKBALx93mCwaCys7Pl8/m0cuVKbdiwQaNGjdKoUaO0YcMGDRw4UM8991xa/gAAgMzlKYS2bt0qSZoxY0bC9m3btmnx4sWSpBdffFHXrl3T0qVLdenSJU2cOFEHDx5UIBBIScMAgJ7D55xz1k18ViwWUzAYtG4D9+Gll17yXFNZWem5ppvdoilxtwd1uvIwr8Ply5c91yxYsMBzzR/+8AfPNcgc0WhUOTk5XR7D2nEAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADNJfbMqIElf/OIXrVvoVd555x3PNT/+8Y+TOldLS4vnmjvfLwZ4wUwIAGCGEAIAmCGEAABmCCEAgBlCCABghhACAJghhAAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAGZ9zzlk38VmxWEzBYNC6DdyHrKwszzWzZs3yXPPNb37Tc00kEvFcI0nRaDSpOq9++tOfeq45evSo55qbN296rgFSJRqNKicnp8tjmAkBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwwwKmAIC0YAFTAEC3RggBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM55CqLKyUhMmTFAgEFBeXp7mz5+vM2fOJByzePFi+Xy+hDFp0qSUNg0A6Bk8hVBtba2WLVum48ePq7q6Wjdv3lRJSYmuXLmScNycOXN04cKF+Ni/f39KmwYA9Az9vBz83nvvJbzetm2b8vLydPLkSU2bNi2+3e/3KxQKpaZDAECP9UDvCUWjUUlSbm5uwvaamhrl5eVp9OjRWrJkiVpaWu76e7S3tysWiyUMAEDv4HPOuWQKnXOaN2+eLl26pKNHj8a379q1S4MHD1ZBQYEaGhr0wx/+UDdv3tTJkyfl9/s7/D4VFRX60Y9+lPyfAADQLUWjUeXk5HR9kEvS0qVLXUFBgWtqauryuPPnz7usrCz3zjvvdLr/+vXrLhqNxkdTU5OTxGAwGIwMH9Fo9J5Z4uk9oTtWrFihvXv36siRIxo2bFiXx4bDYRUUFKi+vr7T/X6/v9MZEgCg5/MUQs45rVixQu+++65qampUWFh4z5rW1lY1NTUpHA4n3SQAoGfy9GDCsmXL9Oabb2rnzp0KBAJqbm5Wc3Ozrl27Jkm6fPmyVq9erT/96U86d+6campqNHfuXA0ZMkQLFixIyx8AAJDBvLwPpLv83G/btm3OOeeuXr3qSkpK3NChQ11WVpYbMWKEKysrc42Njfd9jmg0av5zTAaDwWA8+Lif94SSfjouXWKxmILBoHUbAIAHdD9Px7F2HADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADATLcLIeecdQsAgBS4n3/Pu10ItbW1WbcAAEiB+/n33Oe62dTj1q1bOn/+vAKBgHw+X8K+WCym4cOHq6mpSTk5OUYd2uM63MZ1uI3rcBvX4bbucB2cc2pra1MkElGfPl3Pdfo9pJ7uW58+fTRs2LAuj8nJyenVN9kdXIfbuA63cR1u4zrcZn0dgsHgfR3X7X4cBwDoPQghAICZjAohv9+v9evXy+/3W7diiutwG9fhNq7DbVyH2zLtOnS7BxMAAL1HRs2EAAA9CyEEADBDCAEAzBBCAAAzGRVCr776qgoLCzVgwACNGzdOR48etW7poaqoqJDP50
sYoVDIuq20O3LkiObOnatIJCKfz6c9e/Yk7HfOqaKiQpFIRNnZ2ZoxY4ZOnz5t02wa3es6LF68uMP9MWnSJJtm06SyslITJkxQIBBQXl6e5s+frzNnziQc0xvuh/u5DplyP2RMCO3atUsrV67UunXrVFdXp6lTp6q0tFSNjY3WrT1Ujz/+uC5cuBAfp06dsm4p7a5cuaKioiJVVVV1un/Tpk3avHmzqqqqdOLECYVCIc2ePbvHrUN4r+sgSXPmzEm4P/bv3/8QO0y/2tpaLVu2TMePH1d1dbVu3rypkpISXblyJX5Mb7gf7uc6SBlyP7gM8bWvfc298MILCdu+/OUvu5deesmoo4dv/fr1rqioyLoNU5Lcu+++G39969YtFwqF3MaNG+Pbrl+/7oLBoHvttdcMOnw4Pn8dnHOurKzMzZs3z6QfKy0tLU6Sq62tdc713vvh89fBucy5HzJiJnTjxg2dPHlSJSUlCdtLSkp07Ngxo65s1NfXKxKJqLCwUM8884zOnj1r3ZKphoYGNTc3J9wbfr9f06dP73X3hiTV1NQoLy9Po0eP1pIlS9TS0mLdUlpFo1FJUm5urqTeez98/jrckQn3Q0aE0MWLF/Xpp58qPz8/YXt+fr6am5uNunr4Jk6cqB07dujAgQN6/fXX1dzcrOLiYrW2tlq3ZubO339vvzckqbS0VG+99ZYOHTqkV155RSdOnNCsWbPU3t5u3VpaOOdUXl6uKVOmaMyYMZJ65/3Q2XWQMud+6HaraHfl81/t4JzrsK0nKy0tjf967Nixmjx5skaOHKnt27ervLzcsDN7vf3ekKRFixbFfz1mzBiNHz9eBQUF2rdvnxYuXGjYWXosX75cH330kd5///0O+3rT/XC365Ap90NGzISGDBmivn37dvifTEtLS4f/8fQmgwYN0tixY1VfX2/dipk7Twdyb3QUDodVUFDQI++PFStWaO/evTp8+HDCV7/0tvvhbtehM931fsiIEOrfv7/GjRun6urqhO3V1dUqLi426spee3u7Pv74Y4XDYetWzBQWFioUCiXcGzdu3FBtbW2vvjckqbW1VU1NTT3q/nDOafny5dq9e7cOHTqkwsLChP295X6413XoTLe9HwwfivDk17/+tcvKynK/+MUv3F//+le3cuVKN2jQIHfu3Dnr1h6aVatWuZqaGnf27Fl3/Phx99RTT7lAINDjr0FbW5urq6tzdXV1TpLbvHmzq6urc//85z+dc85t3LjRBYNBt3v3bnfq1Cn37LPPunA47GKxmHHnqdXVdWhra3OrVq1yx44dcw0NDe7w4cNu8uTJ7tFHH+1R1+G73/2uCwaDrqamxl24cCE+rl69Gj+mN9wP97oOmXQ/ZEwIOefcz372M1dQUOD69+/vnnzyyYTHEXuDRYsWuXA47LKyslwkEnELFy50p0+ftm4r7Q4fPuwkdRhlZWXOuduP5a5fv96FQiHn9/vdtGnT3KlTp2ybToOursPVq1ddSUmJGzp0qMvKynIjRoxwZWVlrrGx0brtlOrszy/Jbdu2LX5Mb7gf7nUdMul+4KscAABmMuI9IQBAz0QIAQDMEEIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMDM/wGB1/R+vec9fQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.imshow(xb[0].view(28,28))\n", "yb[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model,opt = get_model()" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [], "source": [ "def fit():\n", " for epoch in range(epochs):\n", " for xb,yb in train_dl:\n", " preds = model(xb)\n", " loss = loss_func(preds, yb)\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", " report(loss, preds, yb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.11, 0.96\n", "0.08, 0.96\n", "0.06, 0.96\n" ] }, { "data": { "text/plain": [ "(tensor(0.03, grad_fn=), tensor(1.))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fit()\n", "loss_func(model(xb), yb), accuracy(model(xb), yb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Random sampling" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We want our training set to be in a random order, and that order should differ each iteration. But the validation set shouldn't be randomized." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import random" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Sampler():\n", "    def __init__(self, ds, shuffle=False): self.n,self.shuffle = len(ds),shuffle\n", "    def __iter__(self):\n", "        res = list(range(self.n))\n", "        if self.shuffle: random.shuffle(res)\n", "        return iter(res)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from itertools import islice" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ss = Sampler(train_ds)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\n", "1\n", "2\n", "3\n", "4\n" ] } ], "source": [ "it = iter(ss)\n", "for o in range(5): print(next(it))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[0, 1, 2, 3, 4]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(islice(ss, 5))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[23241, 48397, 22315, 6710, 35026]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ss = Sampler(train_ds, shuffle=True)\n", "list(islice(ss, 5))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import fastcore.all as fc" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class BatchSampler():\n", "    def __init__(self, sampler, bs, drop_last=False): fc.store_attr()\n", "    def __iter__(self): yield from fc.chunked(iter(self.sampler), self.bs, drop_last=self.drop_last)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { 
"text/plain": [ "[[12257, 25230, 12990, 40776],\n", " [47107, 32853, 39167, 47750],\n", " [42548, 36662, 15709, 31127],\n", " [7489, 27472, 46202, 18724],\n", " [2249, 38893, 25321, 32710]]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batchs = BatchSampler(ss, 4)\n", "list(islice(batchs, 5))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def collate(b):\n", " xs,ys = zip(*b)\n", " return torch.stack(xs),torch.stack(ys)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class DataLoader():\n", " def __init__(self, ds, batchs, collate_fn=collate): fc.store_attr()\n", " def __iter__(self): yield from (self.collate_fn(self.ds[i] for i in b) for b in self.batchs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_samp = BatchSampler(Sampler(train_ds, shuffle=True ), bs)\n", "valid_samp = BatchSampler(Sampler(valid_ds, shuffle=False), bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dl = DataLoader(train_ds, batchs=train_samp)\n", "valid_dl = DataLoader(valid_ds, batchs=valid_samp)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(3)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZ1UlEQVR4nO3df2zUdx3H8dcB5ShwPVehvbsBtWEQzcBmAwQafkcamozwYyZsi6YYQzb5kWDBZYCGOhNKSIYk1rG4KINsKHFjSAIOqtDCRAwjXYY4sZNiq1AbKt6VXyWMj38QLru1FL7HHe9e+3wkn4T7fr9vvm++fMOLT+97n/M555wAADDQx7oBAEDvRQgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADATD/rBj7v1q1bOn/+vAKBgHw+n3U7AACPnHNqa2tTJBJRnz5dz3W6XQidP39ew4cPt24DAPCAmpqaNGzYsC6P6XY/jgsEAtYtAABS4H7+PU9bCL366qsqLCzUgAEDNG7cOB09evS+6vgRHAD0DPfz73laQmjXrl1auXKl1q1bp7q6Ok2dOlWlpaVqbGxMx+kAABnKl45VtCdOnKgnn3xSW7dujW/7yle+ovnz56uysrLL2lgspmAwmOqWAAAPWTQaVU5OTpfHpHwmdOPGDZ08eVIlJSUJ20tKSnTs2LEOx7e3tysWiyUMAEDvkPIQunjxoj799FPl5+cnbM/Pz1dzc3OH4ysrKxUMBuODJ+MAoPdI24MJn39DyjnX6ZtUa9asUTQajY+mpqZ0tQQA6GZS/jmhIUOGqG/fvh1mPS0tLR1mR5Lk9/vl9/tT3QYAIAOkfCbUv39/jRs3TtXV1Qnbq6urVVxcnOrTAQAyWFpWTCgvL9e3vvUtjR8/XpMnT9bPf/5zNTY26oUXXkjH6QAAGSotIbRo0SK1trbq5Zdf1oULFzRmzBjt379fBQUF6TgdACBDpeVzQg+CzwkBQM9g8jkhAADuFyEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzPSzbgC4l6KiIs813/ve95I618iRIz3XDBw40HPN2rVrPdcEg0HPNb/73e8810hSW1tbUnWAV8yEAABmCCEAgBlCCABghhACAJghhAAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmPE555x1E58Vi8WSWqgRmWHw4MGeaxobGz3XfOELX/Bc0xP9+9//TqoumQVg33777aTOhZ4rGo0qJyeny2OYCQEAzBBCAAAzKQ+hiooK+Xy+hBEKhVJ9GgBAD5CWL7V7/PHH9fvf/z7+um/fvuk4DQAgw6UlhPr168fsBwBwT2l5T6i+vl6RSESFhYV65plndPbs2bse297erlgsljAAAL1DykNo4sSJ2rFjhw4cOKDXX39dzc3NKi4uVmtra6fHV1ZWKhgMxsfw4cNT3RIAoJtKeQiVlpbq6aef1tixY/X1r39d+/btkyRt37690+PXrFmjaDQaH01NTaluCQDQTaXlPaHPGjRokMaOHav6+vpO9/v9fvn9/nS3AQDohtL+OaH29nZ9/PHHCofD6T4VACDDpDyEVq9erdraWjU0NOjPf/6zvvGNbygWi6msrCzVpwIAZLiU/zjuX//6l5599lldvHhRQ4cO1aRJk3T8+HEVFBSk+lQAgAzHAqZ4qAKBgOea/fv3e66529OY91JXV+e55oknnvBck8x/ypJ5cjQ7O9tzjST95z//8VwzefLkh3IeZA4WMAUAdGuEEADADCEEADBDCAE
AzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMpP1L7YDPamtr81wzderUNHSSeYYMGeK55vvf/35S50qmbs6cOZ5r7vaNy+g9mAkBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMywijaQIS5evOi55o9//GNS50pmFe0nnnjCcw2raIOZEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADMsYApkiEceecRzzdq1a9PQSecikchDOxd6DmZCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzLCAKWCgqKjIc81vfvMbzzWPPfaY5xpJ+vvf/+65ZtWqVUmdC70bMyEAgBlCCABgxnMIHTlyRHPnzlUkEpHP59OePXsS9jvnVFFRoUgkouzsbM2YMUOnT59OVb8AgB7EcwhduXJFRUVFqqqq6nT/pk2btHnzZlVVVenEiRMKhUKaPXu22traHrhZAEDP4vnBhNLSUpWWlna6zzmnLVu2aN26dVq4cKEkafv27crPz9fOnTv1/PPPP1i3AIAeJaXvCTU0NKi5uVklJSXxbX6/X9OnT9exY8c6rWlvb1csFksYAIDeIaUh1NzcLEnKz89P2J6fnx/f93mVlZUKBoPxMXz48FS2BADoxtLydJzP50t47ZzrsO2ONWvWKBqNxkdTU1M6WgIAdEMp/bBqKBSSdHtGFA6H49tbWlo6zI7u8Pv98vv9qWwDAJAhUjoTKiwsVCgUUnV1dXzbjRs3VFtbq+Li4lSeCgDQA3ieCV2+fFmffPJJ/HVDQ4M+/PBD5ebmasSIEVq5cqU2bNigUaNGadSoUdqwYYMGDhyo5557LqWNAwAyn+cQ+uCDDzRz5sz46/LycklSWVmZ3njjDb344ou6du2ali5dqkuXLmnixIk6ePCgAoFA6roGAPQIPuecs27is2KxmILBoHUbwH0rKyvzXPPyyy97rknmydFr1655rpGkp556ynPN4cOHkzoXeq5oNKqcnJwuj2HtOACAGUIIAGCGEAIAmCGEAABmCCEAgBlCCABghhACAJghhAAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAmZR+syrQXQwePDiputWrV3uu+cEPfuC5pk8f7///++9//+u5ZsqUKZ5rJOlvf/tbUnWAV8yEAABmCCEAgBlCCABghhACAJghhAAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmGEBU/RIb7zxRlJ1CxcuTG0jd/H22297rtmyZYvnGhYiRXfHTAgAYIYQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmCGEAABmCCEAgBlCCABghhACAJghhAAAZgghAIAZFjBFjzRy5EjrFrq0detWzzXHjh1LQyeALWZCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzLCAKXqkgwcPJlVXVFSU4k46l0x/ySx6unHjRs81knT+/Pmk6gCvmAkBAMwQQgAAM55D6MiRI5o7d64ikYh8Pp/27NmTsH/x4sXy+XwJY9KkSanqFwDQg3gOoStXrqioqEhVVVV3PWbOnDm6cOFCfOzfv/+BmgQA9EyeH0woLS1VaWlpl8f4/X6FQqGkmwIA9A5peU+opqZGeXl5Gj16tJYsWaKWlpa7Htve3q5YLJYwAAC9Q8pDqLS0VG+99ZYOHTqkV155RSdOnNCsWbPU3t7e6fGVlZUKBoPxMXz48FS3BADoplL+OaFFixbFfz1mzBiNHz9eBQUF2rdvnxYuXNjh+DVr1qi8vDz+OhaLEUQA0Euk/cOq4XBYBQUFqq+v73S/3++X3+9PdxsAgG4
o7Z8Tam1tVVNTk8LhcLpPBQDIMJ5nQpcvX9Ynn3wSf93Q0KAPP/xQubm5ys3NVUVFhZ5++mmFw2GdO3dOa9eu1ZAhQ7RgwYKUNg4AyHyeQ+iDDz7QzJkz46/vvJ9TVlamrVu36tSpU9qxY4f+97//KRwOa+bMmdq1a5cCgUDqugYA9Ag+55yzbuKzYrGYgsGgdRvIcNnZ2UnVvfnmm55rxo0b57lmxIgRnmuS0dzcnFTdt7/9bc81Bw4cSOpc6Lmi0ahycnK6PIa14wAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmCGEAABmCCEAgBlCCABghhACAJghhAAAZlhFG/iMAQMGeK7p18/7FxTHYjHPNQ/T9evXPdfc+VoXL1577TXPNcgcrKINAOjWCCEAgBlCCABghhACAJghhAAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmGEBU8DAV7/6Vc81P/nJTzzXzJw503NNshobGz3XfOlLX0p9I+g2WMAUANCtEUIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMMMCpnioBg4c6Lnm6tWraegk8zzyyCOea375y18mda558+YlVefVo48+6rnmwoULaegE6cACpgCAbo0QAgCYIYQAAGYIIQCAGUIIAGCGEAIAmCGEAABmCCEAgBlCCABghhACAJghhAAAZgghAICZftYNIHONHDnSc83777/vuWbfvn2ea/7yl794rpGSWxzzO9/5juearKwszzXJLPb52GOPea5J1j/+8Q/PNSxGCmZCAAAzhBAAwIynEKqsrNSECRMUCASUl5en+fPn68yZMwnHOOdUUVGhSCSi7OxszZgxQ6dPn05p0wCAnsFTCNXW1mrZsmU6fvy4qqurdfPmTZWUlOjKlSvxYzZt2qTNmzerqqpKJ06cUCgU0uzZs9XW1pby5gEAmc3Tgwnvvfdewutt27YpLy9PJ0+e1LRp0+Sc05YtW7Ru3TotXLhQkrR9+3bl5+dr586dev7551PXOQAg4z3Qe0LRaFSSlJubK0lqaGhQc3OzSkpK4sf4/X5Nnz5dx44d6/T3aG9vVywWSxgAgN4h6RByzqm8vFxTpkzRmDFjJEnNzc2SpPz8/IRj8/Pz4/s+r7KyUsFgMD6GDx+ebEsAgAyTdAgtX75cH330kX71q1912Ofz+RJeO+c6bLtjzZo1ikaj8dHU1JRsSwCADJPUh1VXrFihvXv36siRIxo2bFh8eygUknR7RhQOh+PbW1paOsyO7vD7/fL7/cm0AQDIcJ5mQs45LV++XLt379ahQ4dUWFiYsL+wsFChUEjV1dXxbTdu3FBtba2Ki4tT0zEAoMfwNBNatmyZdu7cqd/+9rcKBALx93mCwaCys7Pl8/m0cuVKbdiwQaNGjdKoUaO0YcMGDRw4UM8991xa/gAAgMzlKYS2bt0qSZoxY0bC9m3btmnx4sWSpBdffFHXrl3T0qVLdenSJU2cOFEHDx5UIBBIScMAgJ7D55xz1k18ViwWUzAYtG4D9+Gll17yXFNZWem5ppvdoilxtwd1uvIwr8Ply5c91yxYsMBzzR/+8AfPNcgc0WhUOTk5XR7D2nEAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADNJfbMqIElf/OIXrVvoVd555x3PNT/+8Y+TOldLS4vnmjvfLwZ4wUwIAGCGEAIAmCGEAABmCCEAgBlCCABghhACAJghhAAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAGZ9zzlk38VmxWEzBYNC6DdyHrKwszzWzZs3yXPPNb37Tc00kEvFcI0nRaDSpOq9++tOfeq45evSo55qbN296rgFSJRqNKicnp8tjmAkBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEI
AADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwwwKmAIC0YAFTAEC3RggBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM55CqLKyUhMmTFAgEFBeXp7mz5+vM2fOJByzePFi+Xy+hDFp0qSUNg0A6Bk8hVBtba2WLVum48ePq7q6Wjdv3lRJSYmuXLmScNycOXN04cKF+Ni/f39KmwYA9Az9vBz83nvvJbzetm2b8vLydPLkSU2bNi2+3e/3KxQKpaZDAECP9UDvCUWjUUlSbm5uwvaamhrl5eVp9OjRWrJkiVpaWu76e7S3tysWiyUMAEDv4HPOuWQKnXOaN2+eLl26pKNHj8a379q1S4MHD1ZBQYEaGhr0wx/+UDdv3tTJkyfl9/s7/D4VFRX60Y9+lPyfAADQLUWjUeXk5HR9kEvS0qVLXUFBgWtqauryuPPnz7usrCz3zjvvdLr/+vXrLhqNxkdTU5OTxGAwGIwMH9Fo9J5Z4uk9oTtWrFihvXv36siRIxo2bFiXx4bDYRUUFKi+vr7T/X6/v9MZEgCg5/MUQs45rVixQu+++65qampUWFh4z5rW1lY1NTUpHA4n3SQAoGfy9GDCsmXL9Oabb2rnzp0KBAJqbm5Wc3Ozrl27Jkm6fPmyVq9erT/96U86d+6campqNHfuXA0ZMkQLFixIyx8AAJDBvLwPpLv83G/btm3OOeeuXr3qSkpK3NChQ11WVpYbMWKEKysrc42Njfd9jmg0av5zTAaDwWA8+Lif94SSfjouXWKxmILBoHUbAIAHdD9Px7F2HADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADATLcLIeecdQsAgBS4n3/Pu10ItbW1WbcAAEiB+/n33Oe62dTj1q1bOn/+vAKBgHw+X8K+WCym4cOHq6mpSTk5OUYd2uM63MZ1uI3rcBvX4bbucB2cc2pra1MkElGfPl3Pdfo9pJ7uW58+fTRs2LAuj8nJyenVN9kdXIfbuA63cR1u4zrcZn0dgsHgfR3X7X4cBwDoPQghAICZjAohv9+v9evXy+/3W7diiutwG9fhNq7DbVyH2zLtOnS7BxMAAL1HRs2EAAA9CyEEADBDCAEAzBBCAAAzGRVCr776qgoLCzVgwACNGzdOR48etW7poaqoqJDP50sYoVDIuq20O3LkiObOnatIJCKfz6c9e/Yk7HfOqaKiQpFIRNnZ2ZoxY4ZOnz5t02wa3es6LF68uMP9MWnSJJtm06SyslITJkxQIBBQXl6e5s+frzNnziQc0xvuh/u5DplyP2RMCO3atUsrV67UunXrVFdXp6lTp6q0tFSNjY3WrT1Ujz/+uC5cuBAfp06dsm4p7a5cuaKioiJVVVV1un/Tpk3avHmzqqqqdOLECYVCIc2ePbvHrUN4r+sgSXPmzEm4P/bv3/8QO0y/2tpaLVu2TMePH1d1dbVu3rypkpISXblyJX5Mb7gf7uc6SBlyP7gM8bWvfc298MILCdu+/OUvu5deesmoo4dv/fr1rqioyLoNU5Lcu+++G39969YtFwqF3MaNG+Pbrl+/7oLBoHvttdcMOnw4Pn8dnHOurKzMzZs3z6QfKy0tLU6Sq62tdc713vvh89fBucy5HzJiJnTjxg2dPHlSJSUlCdtLSkp07Ngxo65s1NfXKxKJqLCwUM8884zOnj1r3ZKphoYGNTc3J9wbfr9f06dP73X3hiTV1NQoLy9Po0eP1pIlS9TS0mLdUlpFo1FJUm5urqTeez98/jrckQn3Q0aE0MWLF/Xpp58qPz8/YXt+fr6am5uNunr4Jk6cqB07dujAgQN6/fXX1dz
crOLiYrW2tlq3ZubO339vvzckqbS0VG+99ZYOHTqkV155RSdOnNCsWbPU3t5u3VpaOOdUXl6uKVOmaMyYMZJ65/3Q2XWQMud+6HaraHfl81/t4JzrsK0nKy0tjf967Nixmjx5skaOHKnt27ervLzcsDN7vf3ekKRFixbFfz1mzBiNHz9eBQUF2rdvnxYuXGjYWXosX75cH330kd5///0O+3rT/XC365Ap90NGzISGDBmivn37dvifTEtLS4f/8fQmgwYN0tixY1VfX2/dipk7Twdyb3QUDodVUFDQI++PFStWaO/evTp8+HDCV7/0tvvhbtehM931fsiIEOrfv7/GjRun6urqhO3V1dUqLi426spee3u7Pv74Y4XDYetWzBQWFioUCiXcGzdu3FBtbW2vvjckqbW1VU1NTT3q/nDOafny5dq9e7cOHTqkwsLChP295X6413XoTLe9HwwfivDk17/+tcvKynK/+MUv3F//+le3cuVKN2jQIHfu3Dnr1h6aVatWuZqaGnf27Fl3/Phx99RTT7lAINDjr0FbW5urq6tzdXV1TpLbvHmzq6urc//85z+dc85t3LjRBYNBt3v3bnfq1Cn37LPPunA47GKxmHHnqdXVdWhra3OrVq1yx44dcw0NDe7w4cNu8uTJ7tFHH+1R1+G73/2uCwaDrqamxl24cCE+rl69Gj+mN9wP97oOmXQ/ZEwIOefcz372M1dQUOD69+/vnnzyyYTHEXuDRYsWuXA47LKyslwkEnELFy50p0+ftm4r7Q4fPuwkdRhlZWXOuduP5a5fv96FQiHn9/vdtGnT3KlTp2ybToOursPVq1ddSUmJGzp0qMvKynIjRoxwZWVlrrGx0brtlOrszy/Jbdu2LX5Mb7gf7nUdMul+4KscAABmMuI9IQBAz0QIAQDMEEIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMDM/wGB1/R+vec9fQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "xb,yb = next(iter(valid_dl))\n", "plt.imshow(xb[0].view(28,28))\n", "yb[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([50, 784]), torch.Size([50]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xb.shape,yb.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model,opt = get_model()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.27, 0.04\n", "0.15, 0.02\n", "0.01, 0.12\n" ] } ], "source": [ "fit()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multiprocessing DataLoader" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch.multiprocessing as mp\n", "from fastcore.basics import store_attr" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]]),\n", " tensor([1, 1, 1, 0]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_ds[[3,6,8,1]]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]]),\n", " tensor([1, 1, 1, 0]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_ds.__getitem__([3,6,8,1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"(tensor([[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]]), tensor([1, 1]))\n", "(tensor([[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]]), tensor([1, 0]))\n" ] } ], "source": [ "for o in map(train_ds.__getitem__, ([3,6],[8,1])): print(o)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class DataLoader():\n", " def __init__(self, ds, batchs, n_workers=1, collate_fn=collate): fc.store_attr()\n", " def __iter__(self):\n", " with mp.Pool(self.n_workers) as ex: yield from ex.map(self.ds.__getitem__, iter(self.batchs))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dl = DataLoader(train_ds, batchs=train_samp, n_workers=2)\n", "it = iter(train_dl)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xb,yb = next(it)\n", "xb.shape,yb.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PyTorch DataLoader" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "from torch.utils.data import DataLoader, SequentialSampler, RandomSampler, BatchSampler" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_samp = BatchSampler(RandomSampler(train_ds), bs, drop_last=False)\n", "valid_samp = BatchSampler(SequentialSampler(valid_ds), bs, drop_last=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dl = DataLoader(train_ds, batch_sampler=train_samp, collate_fn=collate)\n", "valid_dl = DataLoader(valid_ds, batch_sampler=valid_samp, collate_fn=collate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.10, 0.06\n", "0.10, 0.04\n", "0.27, 0.06\n" ] }, { "data": { "text/plain": [ "(tensor(0.05, grad_fn=), tensor(0.98))" ] }, 
"execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model,opt = get_model()\n", "fit()\n", "loss_func(model(xb), yb), accuracy(model(xb), yb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch can auto-generate the BatchSampler for us:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dl = DataLoader(train_ds, bs, sampler=RandomSampler(train_ds), collate_fn=collate)\n", "valid_dl = DataLoader(valid_ds, bs, sampler=SequentialSampler(valid_ds), collate_fn=collate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch can also generate the Sequential/RandomSamplers too:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dl = DataLoader(train_ds, bs, shuffle=True, drop_last=True, num_workers=2)\n", "valid_dl = DataLoader(valid_ds, bs, shuffle=False, num_workers=2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[W ParallelNative.cpp:229] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)\n", "[W ParallelNative.cpp:229] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "0.21, 0.14\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[W ParallelNative.cpp:229] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)\n", "[W ParallelNative.cpp:229] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel 
backend (function set_num_threads)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "0.15, 0.16\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[W ParallelNative.cpp:229] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)\n", "[W ParallelNative.cpp:229] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "0.05, 0.10\n" ] }, { "data": { "text/plain": [ "(tensor(0.01, grad_fn=), tensor(1.))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model,opt = get_model()\n", "fit()\n", "\n", "loss_func(model(xb), yb), accuracy(model(xb), yb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our dataset actually already knows how to sample a batch of indices all at once:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]]),\n", " tensor([9, 1, 3]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_ds[[4,6,7]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "...that means that we can actually skip the batch_sampler and collate_fn entirely:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dl = DataLoader(train_ds, sampler=train_samp)\n", "valid_dl = DataLoader(valid_ds, sampler=valid_samp)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 50, 784]), torch.Size([1, 50]))" ] }, "execution_count": null, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "xb,yb = next(iter(train_dl))\n", "xb.shape,yb.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Validation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You **always** should also have a [validation set](http://www.fast.ai/2017/11/13/validation-sets/), in order to identify if you are overfitting.\n", "\n", "We will calculate and print the validation loss at the end of each epoch.\n", "\n", "(Note that we always call `model.train()` before training, and `model.eval()` before inference, because these are used by layers such as `nn.BatchNorm2d` and `nn.Dropout` to ensure appropriate behaviour for these different phases.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def fit(epochs, model, loss_func, opt, train_dl, valid_dl):\n", " for epoch in range(epochs):\n", " model.train()\n", " for xb,yb in train_dl:\n", " loss = loss_func(model(xb), yb)\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", " model.eval()\n", " with torch.no_grad():\n", " tot_loss,tot_acc,count = 0.,0.,0\n", " for xb,yb in valid_dl:\n", " pred = model(xb)\n", " n = len(xb)\n", " count += n\n", " tot_loss += loss_func(pred,yb).item()*n\n", " tot_acc += accuracy (pred,yb).item()*n\n", " print(epoch, tot_loss/count, tot_acc/count)\n", " return tot_loss/count, tot_acc/count" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def get_dls(train_ds, valid_ds, bs, **kwargs):\n", " return (DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs),\n", " DataLoader(valid_ds, batch_size=bs*2, **kwargs))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, our whole process of obtaining the data loaders and fitting the model can be run in 3 lines of code:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dl,valid_dl = 
get_dls(train_ds, valid_ds, bs)\n", "model,opt = get_model()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 0.0951265140157193 0.9733000087738037\n", "1 0.09963013818487525 0.9722000092267991\n", "2 0.1106003435049206 0.969400007724762\n", "3 0.09982822934631258 0.9742000102996826\n", "4 0.12104855466226581 0.9687000066041946\n", "CPU times: user 1.38 s, sys: 487 ms, total: 1.87 s\n", "Wall time: 1.54 s\n" ] } ], "source": [ "%time loss,acc = fit(5, model, loss_func, opt, train_dl, valid_dl)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import nbdev; nbdev.nbdev_export()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }