{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## The forward and backward passes" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl, numpy as np\n", "from pathlib import Path\n", "from torch import tensor\n", "\n", "mpl.rcParams['image.cmap'] = 'gray'\n", "torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)\n", "np.set_printoptions(precision=2, linewidth=140)\n", "\n", "path_data = Path('data')\n", "path_gz = path_data/'mnist.pkl.gz'\n", "with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')\n", "x_train, y_train, x_valid, y_valid = map(tensor, [x_train, y_train, x_valid, y_valid])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Foundations version" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### Basic architecture" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(50000, 784, tensor(10))" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n,m = x_train.shape\n", "c = y_train.max()+1\n", "n,m,c" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "hidden": true }, "outputs": [], "source": [ "# num hidden\n", "nh = 50" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "hidden": true }, "outputs": [], "source": [ "w1 = torch.randn(m,nh)\n", "b1 = torch.zeros(nh)\n", "w2 = torch.randn(nh,1)\n", "b2 = torch.zeros(1)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "hidden": true }, "outputs": [], "source": [ "def lin(x, w, b): return x@w + b" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "torch.Size([10000, 50])" ] }, "execution_count": 10, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "t = lin(x_valid, w1, b1)\n", "t.shape" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "hidden": true }, "outputs": [], "source": [ "def relu(x): return x.clamp_min(0.)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.66, 0.00, 0.00, ..., 0.00, 0.00, 0.00],\n", " [ 0.00, 10.93, 0.00, ..., 0.08, 0.00, 0.00],\n", " [ 5.45, 3.59, 0.00, ..., 10.75, 8.27, 0.00],\n", " ...,\n", " [ 0.00, 3.35, 5.53, ..., 0.00, 0.00, 0.00],\n", " [ 0.00, 4.37, 5.65, ..., 0.97, 9.48, 0.00],\n", " [ 5.59, 0.00, 3.30, ..., 0.00, 0.00, 0.00]])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t = relu(lin(x_valid, w1, b1))\n", "t" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "hidden": true }, "outputs": [], "source": [ "def model(xb):\n", " l1 = lin(xb, w1, b1)\n", " l2 = relu(l1)\n", " return lin(l2, w2, b2)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "torch.Size([10000, 1])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res = model(x_valid)\n", "res.shape" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### Loss function: MSE" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "We need to get rid of that trailing (,1), in order to use `mse`." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "torch.Size([10000])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res[:,0].shape" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "(Of course, `mse` is not a suitable loss function for multi-class classification; we'll use a better loss function soon. 
We'll use `mse` for now to keep things simple.)" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": { "hidden": true }, "outputs": [], "source": [ "def mse(output, targ): return (output[:,0]-targ).pow(2).mean()" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": { "hidden": true }, "outputs": [], "source": [ "y_train,y_valid = y_train.float(),y_valid.float()" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "torch.Size([50000, 1])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds = model(x_train)\n", "preds.shape" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "tensor(1869.67)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mse(preds, y_train)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Gradients and backward pass" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/latex": [ "$\\displaystyle 2 x$" ], "text/plain": [ "2*x" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sympy import symbols,diff\n", "x,y = symbols('x y')\n", "diff(x**2, x)" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "def lin_grad(inp, out, w, b):\n", "    # grad of matmul with respect to input\n", "    inp.g = out.g @ w.t()\n", "    # grad w.r.t. weights: per-sample outer products of inp and out.g, summed over the batch\n", "    w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)\n", "    # b is broadcast over the batch in the forward pass, so its grad sums over dim 0\n", "    b.g = out.g.sum(0)" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "def forward_and_backward(inp, targ):\n", "    # forward pass:\n", "    l1 = inp @ w1 + b1\n", "    l2 = relu(l1)\n", "    out = l2 @ w2 + b2\n", "    diff = out[:,0]-targ\n", "    loss = diff.pow(2).mean()\n", "\n", "    # backward pass:\n", "    # d(mean(diff**2))/d(out) = 2*diff/N; [:,None] restores the column dim that out[:,0] dropped\n", "    out.g = 2.*diff[:,None] / inp.shape[0]\n", "    lin_grad(l2, out, w2, b2)\n", "    # relu backward: gradient only flows where the pre-activation l1 was positive\n", "    l1.g = (l1>0).float() * l2.g\n", "    lin_grad(inp, l1, w1, b1)" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "forward_and_backward(x_train, y_train)" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "# Save for testing against later\n", "w1g = w1.g.clone()\n", "w2g = w2.g.clone()\n", "b1g = b1.g.clone()\n", "b2g = b2.g.clone()\n", "ig = x_train.g.clone()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We cheat a little bit and use PyTorch autograd to check our results." ] },
{ "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "xt2 = x_train.clone().requires_grad_(True)\n", "w12 = w1.clone().requires_grad_(True)\n", "w22 = w2.clone().requires_grad_(True)\n", "b12 = b1.clone().requires_grad_(True)\n", "b22 = b2.clone().requires_grad_(True)" ] },
{ "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "def forward(inp, targ):\n", "    l1 = inp @ w12 + b12\n", "    l2 = relu(l1)\n", "    out = l2 @ w22 + b22\n", "    return mse(out, targ)" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "loss = forward(xt2, y_train)\n", "loss.backward()" ] },
{ "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "from fastcore.test import test_close" ] },
{ "cell_type": "code", "execution_count": 29, "metadata": { "scrolled": false }, "outputs": [], "source": [ "test_close(w22.grad, w2g, eps=0.01)\n", "test_close(b22.grad, b2g, eps=0.01)\n", "test_close(w12.grad, w1g, eps=0.01)\n", "test_close(b12.grad, b1g, eps=0.01)\n", "test_close(xt2.grad, ig , eps=0.01)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Refactor model" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Layers as classes" ] },
{ "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [
"class Relu():\n", " def __call__(self, inp):\n", " self.inp = inp\n", " self.out = inp.clamp_min(0.)\n", " return self.out\n", " \n", " def backward(self): self.inp.g = (self.inp>0).float() * self.out.g" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "class Lin():\n", " def __init__(self, w, b): self.w,self.b = w,b\n", " \n", " def __call__(self, inp):\n", " self.inp = inp\n", " self.out = inp@self.w + self.b\n", " return self.out\n", "\n", " def backward(self):\n", " self.inp.g = self.out.g @ self.w.t()\n", " self.w.g = self.inp.t() @ self.out.g\n", " self.b.g = self.out.g.sum(0)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "class Mse():\n", " def __call__(self, inp, targ):\n", " self.inp = inp\n", " self.targ = targ\n", " self.out = (inp.squeeze() - targ).pow(2).mean()\n", " return self.out\n", " \n", " def backward(self):\n", " self.inp.g = 2. * (self.inp.squeeze() - self.targ).unsqueeze(-1) / self.targ.shape[0]" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "class Model():\n", " def __init__(self, w1, b1, w2, b2):\n", " self.layers = [Lin(w1,b1), Relu(), Lin(w2,b2)]\n", " self.loss = Mse()\n", " \n", " def __call__(self, x, targ):\n", " for l in self.layers: x = l(x)\n", " return self.loss(x, targ)\n", " \n", " def backward(self):\n", " self.loss.backward()\n", " for l in reversed(self.layers): l.backward()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "model = Model(w1, b1, w2, b2)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 677 ms, sys: 42.1 ms, total: 719 ms\n", "Wall time: 22.5 ms\n" ] } ], "source": [ "%time loss = model(x_train, y_train)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "CPU times: user 1.97 s, sys: 157 ms, total: 2.12 s\n", "Wall time: 66.4 ms\n" ] } ], "source": [ "%time model.backward()" ] },
{ "cell_type": "code", "execution_count": 32, "metadata": { "scrolled": false }, "outputs": [], "source": [ "test_close(w2g, w2.g, eps=0.01)\n", "test_close(b2g, b2.g, eps=0.01)\n", "test_close(w1g, w1.g, eps=0.01)\n", "test_close(b1g, b1.g, eps=0.01)\n", "test_close(ig, x_train.g, eps=0.01)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Module.forward()" ] },
{ "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "class Module():\n", "    def __call__(self, *args):\n", "        self.args = args\n", "        self.out = self.forward(*args)\n", "        return self.out\n", "\n", "    # Placeholders accept *args so a subclass that forgets to override them\n", "    # hits this error instead of a TypeError about the argument count.\n", "    def forward(self, *args): raise Exception('not implemented')\n", "    def bwd(self, *args): raise Exception('not implemented')\n", "    def backward(self): self.bwd(self.out, *self.args)" ] },
{ "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "class Relu(Module):\n", "    def forward(self, inp): return inp.clamp_min(0.)\n", "    def bwd(self, out, inp): inp.g = (inp>0).float() * out.g" ] },
{ "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "class Lin(Module):\n", "    def __init__(self, w, b): self.w,self.b = w,b\n", "    def forward(self, inp): return inp@self.w + self.b\n", "    def bwd(self, out, inp):\n", "        inp.g = out.g @ self.w.t()\n", "        self.w.g = inp.t() @ out.g\n", "        self.b.g = out.g.sum(0)" ] },
{ "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "class Mse(Module):\n", "    def forward(self, inp, targ): return (inp.squeeze() - targ).pow(2).mean()\n", "    def bwd(self, out, inp, targ): inp.g = 2*(inp.squeeze()-targ).unsqueeze(-1) / targ.shape[0]" ] },
{ "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "model = Model(w1, b1, w2, b2)" ] },
{ "cell_type": "code", "execution_count":
38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 725 ms, sys: 0 ns, total: 725 ms\n", "Wall time: 22.6 ms\n" ] } ], "source": [ "%time loss = model(x_train, y_train)" ] },
{ "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2 s, sys: 154 ms, total: 2.15 s\n", "Wall time: 67.2 ms\n" ] } ], "source": [ "%time model.backward()" ] },
{ "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "test_close(w2g, w2.g, eps=0.01)\n", "test_close(b2g, b2.g, eps=0.01)\n", "test_close(w1g, w1.g, eps=0.01)\n", "test_close(b1g, b1.g, eps=0.01)\n", "test_close(ig, x_train.g, eps=0.01)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Autograd" ] },
{ "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "from torch import nn\n", "import torch.nn.functional as F" ] },
{ "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "class Linear(nn.Module):\n", "    def __init__(self, n_in, n_out):\n", "        super().__init__()\n", "        # nn.Parameter registers w and b with the module, so model.parameters() finds them\n", "        self.w = nn.Parameter(torch.randn(n_in,n_out))\n", "        self.b = nn.Parameter(torch.zeros(n_out))\n", "    def forward(self, inp): return inp@self.w + self.b" ] },
{ "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [], "source": [ "class Model(nn.Module):\n", "    def __init__(self, n_in, nh, n_out):\n", "        super().__init__()\n", "        # nn.ModuleList (not a plain list) registers the layers as submodules\n", "        self.layers = nn.ModuleList([Linear(n_in,nh), nn.ReLU(), Linear(nh,n_out)])\n", "\n", "    # override forward, not __call__, so nn.Module's call machinery (hooks) still runs\n", "    def forward(self, x, targ):\n", "        for l in self.layers: x = l(x)\n", "        return F.mse_loss(x, targ[:,None])" ] },
{ "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [], "source": [ "model = Model(m, nh, 1)\n", "loss = model(x_train, y_train)\n", "loss.backward()" ] },
{ "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [ { "data": { "text/plain": [
"tensor([ -0.18, 5.17, -13.02, 3.17, -3.52, 8.76, -1.48, 11.78, -0.38, -1.35, -11.37, -2.88, 11.22, 0.80, 2.72, -2.02,\n", " -5.94, 3.50, 0.52, -8.99, -0.72, 2.58, 0.65, 19.21, 0.83, 1.28, 1.27, -3.17, 16.17, 12.27, 1.43, 0.62,\n", " 24.14, 10.15, 8.03, -2.05, 3.81, 1.89, -2.60, 3.19, 8.60, -0.33, -0.06, 5.92, 47.10, -5.18, 3.15, 5.82,\n", " -0.95, -1.54])" ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "l0 = model.layers[0]\n", "l0.b.grad" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.10" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }