{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The forward and backward passes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 8 video](https://course19.fast.ai/videos/?lesson=8&t=4960)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from exp.nb_01 import *\n", "\n", "def get_data():\n", " path = datasets.download_data(MNIST_URL, ext='.gz')\n", " with gzip.open(path, 'rb') as f:\n", " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')\n", " return map(tensor, (x_train,y_train,x_valid,y_valid))\n", "\n", "def normalize(x, m, s): return (x-m)/s" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x_train,y_train,x_valid,y_valid = get_data()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.1304), tensor(0.3073))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_mean,train_std = x_train.mean(),x_train.std()\n", "train_mean,train_std" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x_train = normalize(x_train, train_mean, train_std)\n", "# NB: Use training, not validation mean for validation set\n", "x_valid = normalize(x_valid, train_mean, train_std)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(3.0614e-05), tensor(1.))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_mean,train_std = x_train.mean(),x_train.std()\n", "train_mean,train_std" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": 
[], "source": [ "#export\n", "def test_near_zero(a,tol=1e-3): assert a.abs()0).float() * out.g" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def lin_grad(inp, out, w, b):\n", " # grad of matmul with respect to input\n", " inp.g = out.g @ w.t()\n", " w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)\n", " b.g = out.g.sum(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def forward_and_backward(inp, targ):\n", " # forward pass:\n", " l1 = inp @ w1 + b1\n", " l2 = relu(l1)\n", " out = l2 @ w2 + b2\n", " # we don't actually need the loss in backward!\n", " loss = mse(out, targ)\n", " \n", " # backward pass:\n", " mse_grad(out, targ)\n", " lin_grad(l2, out, w2, b2)\n", " relu_grad(l1, l2)\n", " lin_grad(inp, l1, w1, b1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "forward_and_backward(x_train, y_train)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Save for testing against later\n", "w1g = w1.g.clone()\n", "w2g = w2.g.clone()\n", "b1g = b1.g.clone()\n", "b2g = b2.g.clone()\n", "ig = x_train.g.clone()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We cheat a little bit and use PyTorch autograd to check our results." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xt2 = x_train.clone().requires_grad_(True)\n", "w12 = w1.clone().requires_grad_(True)\n", "w22 = w2.clone().requires_grad_(True)\n", "b12 = b1.clone().requires_grad_(True)\n", "b22 = b2.clone().requires_grad_(True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def forward(inp, targ):\n", " # forward pass:\n", " l1 = inp @ w12 + b12\n", " l2 = relu(l1)\n", " out = l2 @ w22 + b22\n", " # we don't actually need the loss in backward!\n", " return mse(out, targ)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "loss = forward(xt2, y_train)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "loss.backward()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_near(w22.grad, w2g)\n", "test_near(b22.grad, b2g)\n", "test_near(w12.grad, w1g)\n", "test_near(b12.grad, b1g)\n", "test_near(xt2.grad, ig )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Refactor model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Layers as classes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 8 video](https://course19.fast.ai/videos/?lesson=8&t=7112)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Relu():\n", " def __call__(self, inp):\n", " self.inp = inp\n", " self.out = inp.clamp_min(0.)-0.5\n", " return self.out\n", " \n", " def backward(self): self.inp.g = (self.inp>0).float() * self.out.g" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Lin():\n", " def __init__(self, w, b): self.w,self.b = w,b\n", " \n", " def __call__(self, inp):\n", " self.inp = inp\n", " self.out = inp@self.w + self.b\n", " return self.out\n", " \n", " def 
backward(self):\n", " self.inp.g = self.out.g @ self.w.t()\n", " # Creating a giant outer product, just to sum it, is inefficient!\n", " self.w.g = (self.inp.unsqueeze(-1) * self.out.g.unsqueeze(1)).sum(0)\n", " self.b.g = self.out.g.sum(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Mse():\n", " def __call__(self, inp, targ):\n", " self.inp = inp\n", " self.targ = targ\n", " self.out = (inp.squeeze() - targ).pow(2).mean()\n", " return self.out\n", " \n", " def backward(self):\n", " self.inp.g = 2. * (self.inp.squeeze() - self.targ).unsqueeze(-1) / self.targ.shape[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Model():\n", " def __init__(self, w1, b1, w2, b2):\n", " self.layers = [Lin(w1,b1), Relu(), Lin(w2,b2)]\n", " self.loss = Mse()\n", " \n", " def __call__(self, x, targ):\n", " for l in self.layers: x = l(x)\n", " return self.loss(x, targ)\n", " \n", " def backward(self):\n", " self.loss.backward()\n", " for l in reversed(self.layers): l.backward()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "w1.g,b1.g,w2.g,b2.g = [None]*4\n", "model = Model(w1, b1, w2, b2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 137 ms, sys: 4.95 ms, total: 142 ms\n", "Wall time: 70.7 ms\n" ] } ], "source": [ "%time loss = model(x_train, y_train)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.84 s, sys: 3.86 s, total: 6.71 s\n", "Wall time: 3.4 s\n" ] } ], "source": [ "%time model.backward()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_near(w2g, w2.g)\n", "test_near(b2g, b2.g)\n", "test_near(w1g, w1.g)\n", "test_near(b1g, b1.g)\n", 
"test_near(ig, x_train.g)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Module.forward()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Module():\n", " def __call__(self, *args):\n", " self.args = args\n", " self.out = self.forward(*args)\n", " return self.out\n", " \n", " def forward(self): raise Exception('not implemented')\n", " def backward(self): self.bwd(self.out, *self.args)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Relu(Module):\n", " def forward(self, inp): return inp.clamp_min(0.)-0.5\n", " def bwd(self, out, inp): inp.g = (inp>0).float() * out.g" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Lin(Module):\n", " def __init__(self, w, b): self.w,self.b = w,b\n", " \n", " def forward(self, inp): return inp@self.w + self.b\n", " \n", " def bwd(self, out, inp):\n", " inp.g = out.g @ self.w.t()\n", " self.w.g = torch.einsum(\"bi,bj->ij\", inp, out.g)\n", " self.b.g = out.g.sum(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Mse(Module):\n", " def forward (self, inp, targ): return (inp.squeeze() - targ).pow(2).mean()\n", " def bwd(self, out, inp, targ): inp.g = 2*(inp.squeeze()-targ).unsqueeze(-1) / targ.shape[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Model():\n", " def __init__(self):\n", " self.layers = [Lin(w1,b1), Relu(), Lin(w2,b2)]\n", " self.loss = Mse()\n", " \n", " def __call__(self, x, targ):\n", " for l in self.layers: x = l(x)\n", " return self.loss(x, targ)\n", " \n", " def backward(self):\n", " self.loss.backward()\n", " for l in reversed(self.layers): l.backward()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "w1.g,b1.g,w2.g,b2.g = [None]*4\n", "model = Model()" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 86 ms, sys: 8.25 ms, total: 94.2 ms\n", "Wall time: 46.3 ms\n" ] } ], "source": [ "%time loss = model(x_train, y_train)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 193 ms, sys: 87.6 ms, total: 280 ms\n", "Wall time: 140 ms\n" ] } ], "source": [ "%time model.backward()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_near(w2g, w2.g)\n", "test_near(b2g, b2.g)\n", "test_near(w1g, w1.g)\n", "test_near(b1g, b1.g)\n", "test_near(ig, x_train.g)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Without einsum" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 8 video](https://course19.fast.ai/videos/?lesson=8&t=7484)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Lin(Module):\n", " def __init__(self, w, b): self.w,self.b = w,b\n", " \n", " def forward(self, inp): return inp@self.w + self.b\n", " \n", " def bwd(self, out, inp):\n", " inp.g = out.g @ self.w.t()\n", " self.w.g = inp.t() @ out.g\n", " self.b.g = out.g.sum(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "w1.g,b1.g,w2.g,b2.g = [None]*4\n", "model = Model()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 88.6 ms, sys: 5.04 ms, total: 93.6 ms\n", "Wall time: 46.4 ms\n" ] } ], "source": [ "%time loss = model(x_train, y_train)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 197 ms, sys: 83.9 ms, total: 281 ms\n", "Wall time: 140 ms\n" ] } ], "source": [ "%time 
model.backward()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_near(w2g, w2.g)\n", "test_near(b2g, b2.g)\n", "test_near(w1g, w1.g)\n", "test_near(b1g, b1.g)\n", "test_near(ig, x_train.g)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### nn.Linear and nn.Module" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from torch import nn" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Model(nn.Module):\n", "    # Two-layer MLP (Linear -> ReLU -> Linear) followed by the MSE loss.\n", "    def __init__(self, n_in, nh, n_out):\n", "        super().__init__()\n", "        # nn.ModuleList (not a plain list) registers each layer with this Module,\n", "        # so model.parameters(), .to(device) and .zero_grad() can see their weights.\n", "        self.layers = nn.ModuleList([nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out)])\n", "        self.loss = mse\n", "    \n", "    # Override forward (not __call__) so nn.Module.__call__ still runs its hooks.\n", "    def forward(self, x, targ):\n", "        for l in self.layers: x = l(x)\n", "        return self.loss(x.squeeze(), targ)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Model(m, nh, 1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 85.1 ms, sys: 8.16 ms, total: 93.3 ms\n", "Wall time: 46.3 ms\n" ] } ], "source": [ "%time loss = model(x_train, y_train)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 135 ms, sys: 78.1 ms, total: 213 ms\n", "Wall time: 71.1 ms\n" ] } ], "source": [ "%time loss.backward()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 02_fully_connected.ipynb to nb_02.py\r\n" ] } ], "source": [ "!./notebook2script.py 02_fully_connected.ipynb" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { 
"display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }