{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from local.test import *\n", "from local.basics import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from local.notebook.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp callback.hook" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model hooks\n", "\n", "> Callback and helper function to add hooks in models" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from local.test_utils import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What are hooks?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Hooks are functions you can attach to a particular layer in your model and that will be executed in the foward pass (for forward hooks) or backward pass (for backward hooks). Here we begin with an introduction around hooks, but you should jump to `HookCallback` if you quickly want to implement one (and read the following example `ActivationStats`).\n", "\n", "Forward hooks are functions that take three arguments: the layer it's applied to, the input of that layer and the output of that layer." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Linear(in_features=5, out_features=3, bias=True) (tensor([[-1.8820, 0.7021, -0.6919, -0.8470, 0.4694],\n", " [ 1.6047, 1.1505, 1.9210, -0.4393, 1.6700],\n", " [ 1.1959, 0.5682, -1.0785, -0.5261, -0.2628],\n", " [-0.9082, 1.5110, 0.3545, 0.3456, 0.7868]]),) tensor([[-0.4622, 0.4920, 0.0128],\n", " [ 0.5634, -1.1016, 1.5529],\n", " [-0.1620, 0.0452, 0.6342],\n", " [ 0.1072, 0.1193, 0.0884]], grad_fn=)\n" ] } ], "source": [ "tst_model = nn.Linear(5,3)\n", "def example_forward_hook(m,i,o): print(m,i,o)\n", " \n", "x = torch.randn(4,5)\n", "hook = tst_model.register_forward_hook(example_forward_hook)\n", "y = tst_model(x)\n", "hook.remove()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Backward hooks are functions that take three arguments: the layer it's applied to, the gradients of the loss with respect to the input, and the gradients with respect to the output." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Linear(in_features=5, out_features=3, bias=True) (tensor([ 0.2188, -0.4699, 0.4320]), None, tensor([[-0.1662, -0.1490, 0.0825],\n", " [ 0.5101, 0.0921, 0.0596],\n", " [ 0.5675, -0.5542, 0.4907],\n", " [-0.2050, -0.0900, -0.0019],\n", " [-0.0335, 0.0810, -0.0717]])) (tensor([[ 0.0437, -0.1476, 0.1374],\n", " [-0.1166, -0.0519, 0.0458],\n", " [ 0.1882, -0.0790, 0.1273],\n", " [ 0.1035, -0.1914, 0.1215]]),)\n" ] } ], "source": [ "def example_backward_hook(m,gi,go): print(m,gi,go)\n", "hook = tst_model.register_backward_hook(example_backward_hook)\n", "\n", "x = torch.randn(4,5)\n", "y = tst_model(x)\n", "loss = y.pow(2).mean()\n", "loss.backward()\n", "hook.remove()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Hooks can change the input/output of a layer, or the gradients, print values or shapes. If you want to store something related to theses inputs/outputs, it's best to have your hook associated to a class so that it can put it in the state of an instance of that class." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hook -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@docs\n", "class Hook():\n", " \"Create a hook on `m` with `hook_func`.\"\n", " def __init__(self, m, hook_func, is_forward=True, detach=True, cpu=False):\n", " self.hook_func,self.detach,self.cpu,self.stored = hook_func,detach,cpu,None\n", " f = m.register_forward_hook if is_forward else m.register_backward_hook\n", " self.hook = f(self.hook_fn)\n", " self.removed = False\n", "\n", " def hook_fn(self, module, input, output):\n", " \"Applies `hook_func` to `module`, `input`, `output`.\"\n", " if self.detach: input,output = to_detach(input, cpu=self.cpu),to_detach(output, cpu=self.cpu)\n", " self.stored = self.hook_func(module, input, output)\n", "\n", " def remove(self):\n", " \"Remove the hook from the model.\"\n", " if not self.removed:\n", " self.hook.remove()\n", " self.removed=True\n", "\n", " def __enter__(self, *args): return self\n", " def __exit__(self, *args): self.remove()\n", "\n", " _docs = dict(__enter__=\"Register the hook\",\n", " __exit__=\"Remove the hook\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This will be called during the forward pass if `is_forward=True`, the backward pass otherwise, and will optionally `detach` and put on the `cpu` the (gradient of the) input/output of the model before passing them to `hook_func`. The result of `hook_func` will be stored in the `stored` attribute of the `Hook`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst_model = nn.Linear(5,3)\n", "hook = Hook(tst_model, lambda m,i,o: o)\n", "y = tst_model(x)\n", "test_eq(hook.stored, y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Hook.hook_fn[source]

\n", "\n", "> Hook.hook_fn(**`module`**, **`input`**, **`output`**)\n", "\n", "Applies `hook_func` to `module`, `input`, `output`." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hook.hook_fn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

#### Hook.remove
\n", "\n", "> Hook.remove()\n", "\n", "Remove the hook from the model." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hook.remove)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: It's important to properly remove your hooks for your model when you're done to avoid them being called again next time your model is applied to some inputs, and to free the memory that go with their state." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst_model = nn.Linear(5,10)\n", "x = torch.randn(4,5)\n", "y = tst_model(x)\n", "hook = Hook(tst_model, example_forward_hook)\n", "test_stdout(lambda: tst_model(x), f\"{tst_model} ({x},) {y.detach()}\")\n", "hook.remove()\n", "test_stdout(lambda: tst_model(x), \"\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Context Manager" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since it's very important to remove your `Hook` even if your code is interrupted by some bug, `Hook` can be used as context managers." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Hook.__enter__[source]

\n", "\n", "> Hook.__enter__(**\\*`args`**)\n", "\n", "Register the hook" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hook.__enter__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

#### Hook.__exit__
\n", "\n", "> Hook.__exit__(**\\*`args`**)\n", "\n", "Remove the hook" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hook.__exit__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst_model = nn.Linear(5,10)\n", "x = torch.randn(4,5)\n", "y = tst_model(x)\n", "with Hook(tst_model, example_forward_hook) as h:\n", " test_stdout(lambda: tst_model(x), f\"{tst_model} ({x},) {y.detach()}\")\n", "test_stdout(lambda: tst_model(x), \"\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _hook_inner(m,i,o): return o if isinstance(o,Tensor) or is_listy(o) else list(o)\n", "\n", "def hook_output(module, detach=True, cpu=False, grad=False):\n", " \"Return a `Hook` that stores activations of `module` in `self.stored`\"\n", " return Hook(module, _hook_inner, detach=detach, cpu=cpu, is_forward=not grad)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The activations stored are the gradients if `grad=True`, otherwise the output of `module`. If `detach=True` they are detached from their history, and if `cpu=True`, they're put on the CPU." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst_model = nn.Linear(5,10)\n", "x = torch.randn(4,5)\n", "with hook_output(tst_model) as h:\n", " y = tst_model(x)\n", " test_eq(y, h.stored)\n", " assert not h.stored.requires_grad\n", " \n", "with hook_output(tst_model, grad=True) as h:\n", " y = tst_model(x)\n", " loss = y.pow(2).mean()\n", " loss.backward()\n", " test_close(2*y / y.numel(), h.stored[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#cuda\n", "with hook_output(tst_model, cpu=True) as h:\n", " y = tst_model.cuda()(x.cuda())\n", " test_eq(h.stored.device, torch.device('cpu'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hooks -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@docs\n", "class Hooks():\n", " \"Create several hooks on the modules in `ms` with `hook_func`.\"\n", " def __init__(self, ms, hook_func, is_forward=True, detach=True, cpu=False):\n", " self.hooks = [Hook(m, hook_func, is_forward, detach, cpu) for m in ms]\n", "\n", " def __getitem__(self,i): return self.hooks[i]\n", " def __len__(self): return len(self.hooks)\n", " def __iter__(self): return iter(self.hooks)\n", " @property\n", " def stored(self): return L(o.stored for o in self)\n", "\n", " def remove(self):\n", " \"Remove the hooks from the model.\"\n", " for h in self.hooks: h.remove()\n", "\n", " def __enter__(self, *args): return self\n", " def __exit__ (self, *args): self.remove()\n", "\n", " _docs = dict(stored = \"The states saved in each hook.\",\n", " __enter__=\"Register the hooks\",\n", " __exit__=\"Remove the hooks\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]\n", "tst_model = nn.Sequential(*layers)\n", "hooks = Hooks(tst_model, lambda m,i,o: o)\n", "y = tst_model(x)\n", "test_eq(hooks.stored[0], layers[0](x))\n", "test_eq(hooks.stored[1], F.relu(layers[0](x)))\n", "test_eq(hooks.stored[2], y)\n", "hooks.remove()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

#### Hooks.stored
\n", "\n", "The states saved in each hook." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hooks.stored, name='Hooks.stored')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

#### Hooks.remove
\n", "\n", "> Hooks.remove()\n", "\n", "Remove the hooks from the model." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hooks.remove)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Context Manager" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Like `Hook` , you can use `Hooks` as context managers." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Hooks.__enter__[source]

\n", "\n", "> Hooks.__enter__(**\\*`args`**)\n", "\n", "Register the hooks" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hooks.__enter__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

#### Hooks.__exit__
\n", "\n", "> Hooks.__exit__(**\\*`args`**)\n", "\n", "Remove the hooks" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hooks.__exit__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]\n", "tst_model = nn.Sequential(*layers)\n", "with Hooks(layers, lambda m,i,o: o) as h:\n", " y = tst_model(x)\n", " test_eq(h.stored[0], layers[0](x))\n", " test_eq(h.stored[1], F.relu(layers[0](x)))\n", " test_eq(h.stored[2], y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def hook_outputs(modules, detach=True, cpu=False, grad=False):\n", " \"Return `Hooks` that store activations of all `modules` in `self.stored`\"\n", " return Hooks(modules, _hook_inner, detach=detach, cpu=cpu, is_forward=not grad)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The activations stored are the gradients if `grad=True`, otherwise the output of `modules`. If `detach=True` they are detached from their history, and if `cpu=True`, they're put on the CPU." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]\n", "tst_model = nn.Sequential(*layers)\n", "x = torch.randn(4,5)\n", "with hook_outputs(layers) as h:\n", " y = tst_model(x)\n", " test_eq(h.stored[0], layers[0](x))\n", " test_eq(h.stored[1], F.relu(layers[0](x)))\n", " test_eq(h.stored[2], y)\n", " for s in h.stored: assert not s.requires_grad\n", " \n", "with hook_outputs(layers, grad=True) as h:\n", " y = tst_model(x)\n", " loss = y.pow(2).mean()\n", " loss.backward()\n", " g = 2*y / y.numel()\n", " test_close(g, h.stored[2][0])\n", " g = g @ layers[2].weight.data\n", " test_close(g, h.stored[1][0])\n", " g = g * (layers[0](x) > 0).float()\n", " test_close(g, h.stored[0][0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#cuda\n", "with hook_outputs(tst_model, cpu=True) as h:\n", " y = tst_model.cuda()(x.cuda())\n", " for s in h.stored: test_eq(s.device, torch.device('cpu'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def dummy_eval(m, size=(64,64)):\n", " \"Evaluate `m` on a dummy input of a certain `size`\"\n", " ch_in = in_channels(m)\n", " x = one_param(m).new(1, ch_in, *size).requires_grad_(False).uniform_(-1.,1.)\n", " with torch.no_grad(): return m.eval()(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def model_sizes(m, size=(64,64)):\n", " \"Pass a dummy input through the model `m` to get the various sizes of activations.\"\n", " with hook_outputs(m) as hooks:\n", " _ = dummy_eval(m, size=size)\n", " return [o.stored.shape for o in hooks]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m = nn.Sequential(ConvLayer(3, 16), ConvLayer(16, 32, stride=2), ConvLayer(32, 32))\n", "test_eq(model_sizes(m), [[1, 16, 64, 64], [1, 32, 32, 32], [1, 32, 32, 32]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def num_features_model(m):\n", " \"Return the number of output features for `m`.\"\n", " sz,ch_in = 32,in_channels(m)\n", " while True:\n", " #Trying for a few sizes in case the model requires a big input size.\n", " try:\n", " return model_sizes(m, 
(sz,sz))[-1][1]\n", "        except Exception as e:\n", "            sz *= 2\n", "            if sz > 2048: raise e" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m = nn.Sequential(nn.Conv2d(5,4,3), nn.Conv2d(4,3,3))\n", "test_eq(num_features_model(m), 3)\n", "m = nn.Sequential(ConvLayer(3, 16), ConvLayer(16, 32, stride=2), ConvLayer(32, 32))\n", "test_eq(num_features_model(m), 32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## HookCallback -" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make hooks easy to use, we wrapped a version in a `Callback` where you only have to implement a `hook` function (plus any elements you might need)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def has_params(m):\n", "    \"Check if `m` has at least one parameter\"\n", "    return len(list(m.parameters())) > 0" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert has_params(nn.Linear(3,4))\n", "assert has_params(nn.LSTM(4,5,2))\n", "assert not has_params(nn.ReLU())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@funcs_kwargs\n", "class HookCallback(Callback):\n", "    \"`Callback` that can be used to register hooks on `modules`\"\n", "    _methods = [\"hook\"]\n", "    hook = noops\n", "    def __init__(self, modules=None, every=None, remove_end=True, is_forward=True, detach=True, cpu=True, **kwargs):\n", "        store_attr(self, 'modules,every,remove_end,is_forward,detach,cpu')\n", "        assert not kwargs\n", "\n", "    def begin_fit(self):\n", "        \"Register the `Hooks` on `self.modules`.\"\n", "        if self.modules is None: self.modules = [m for m in flatten_model(self.model) if has_params(m)]\n", "        if self.every is None: self._register()\n", "\n", "    def begin_batch(self):\n", "        if self.every is None: return\n", "        if self.training and self.train_iter%self.every==0: self._register()\n", "\n", "    def after_batch(self):\n", "        if self.every is None: return\n", "        if self.training and self.train_iter%self.every==0: self._remove()\n", "\n", "    def after_fit(self):\n", "        \"Remove the `Hooks`.\"\n", "        if self.remove_end: self._remove()\n", "\n", "    def _register(self): self.hooks = Hooks(self.modules, self.hook, self.is_forward, self.detach, self.cpu)\n", "    def _remove(self):\n", "        if getattr(self, 'hooks', None): self.hooks.remove()\n", "\n", "    def __del__(self): self._remove()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can either subclass `HookCallback` and implement a `hook` function (along with any events you want), or pass a `hook` function when initializing (as sketched below). Such a function needs to take three arguments: a layer, its input and its output (for a backward hook, the input is the gradients of the loss with respect to the layer's inputs and the output is the gradients with respect to its output), and it can either modify them or update some state according to them.\n", "\n", "If not provided, `modules` will default to the layers of `self.model` that have at least one parameter. Depending on `remove_end`, the hooks will be properly removed at the end of training (or in case of error). `is_forward`, `detach` and `cpu` are passed along to `Hooks`.\n", "\n", "The function called at each forward (or backward) pass is `self.hook`; it must be implemented when subclassing this callback."
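] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a minimal sketch of the second option (the name `_tst_hook` is just for illustration), a `hook` function can be passed at init and record state in a plain list:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "shapes = []\n", "#Illustrative hook: record the shape of each hooked module's output\n", "def _tst_hook(m, i, o): shapes.append(o.shape)\n", "\n", "learn = synth_learner(n_trn=5, cbs = HookCallback(hook=_tst_hook))\n", "learn.fit(1)\n", "assert len(shapes) > 0"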
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(#4) [0,11.155611038208008,11.110326766967773,00:00]\n" ] } ], "source": [ "class TstCallback(HookCallback):\n", " def hook(self, m, i, o): return o\n", " def after_batch(self): test_eq(self.hooks.stored[0], self.pred)\n", " \n", "learn = synth_learner(n_trn=5, cbs = TstCallback())\n", "learn.fit(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(#4) [0,14.338729858398438,9.407218933105469,00:00]\n" ] } ], "source": [ "class TstCallback(HookCallback):\n", " def __init__(self, modules=None, remove_end=True, detach=True, cpu=False):\n", " super().__init__(modules, None, remove_end, False, detach, cpu)\n", " def hook(self, m, i, o): return o\n", " def after_batch(self):\n", " if self.training:\n", " test_eq(self.hooks.stored[0][0], 2*(self.pred-self.y)/self.pred.shape[0])\n", " \n", "learn = synth_learner(n_trn=5, cbs = TstCallback())\n", "learn.fit(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

#### HookCallback.begin_fit
\n", "\n", "> HookCallback.begin_fit()\n", "\n", "Register the [`Hooks`](/callback.hook.html#Hooks) on `self.modules`." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(HookCallback.begin_fit)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

#### HookCallback.after_fit
\n", "\n", "> HookCallback.after_fit()\n", "\n", "Remove the [`Hooks`](/callback.hook.html#Hooks)." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(HookCallback.after_fit)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model summary" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def total_params(m):\n", " \"Give the number of parameters of a module and if it's trainable or not\"\n", " params = sum([p.numel() for p in m.parameters()])\n", " trains = [p.requires_grad for p in m.parameters()]\n", " return params, (False if len(trains)==0 else trains[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(total_params(nn.Linear(10,32)), (32*10+32,True))\n", "test_eq(total_params(nn.Linear(10,32, bias=False)), (32*10,True))\n", "test_eq(total_params(nn.BatchNorm2d(20)), (20*2, True))\n", "test_eq(total_params(nn.BatchNorm2d(20, affine=False)), (0,False))\n", "test_eq(total_params(nn.Conv2d(16, 32, 3)), (16*32*3*3 + 32, True))\n", "test_eq(total_params(nn.Conv2d(16, 32, 3, bias=False)), (16*32*3*3, True))\n", "#First ih layer 20--10, all else 10--10. *4 for the four gates\n", "test_eq(total_params(nn.LSTM(20, 10, 2)), (4 * (20*10 + 10) + 3 * 4 * (10*10 + 10), True))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def layer_info(learn):\n", " def _track(m, i, o):\n", " return (m.__class__.__name__,)+total_params(m)+(apply(lambda x:x.shape, o),)\n", " layers = [m for m in flatten_model(learn.model)]\n", " xb,_ = learn.dbunch.train_dl.one_batch()\n", " with Hooks(layers, _track) as h:\n", " _ = learn.model.eval()(apply(lambda o:o[:1], xb))\n", " return xb,h.stored" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m = nn.Sequential(nn.Linear(1,50), nn.ReLU(), nn.BatchNorm1d(50), nn.Linear(50, 1))\n", "learn = synth_learner()\n", "learn.model=m" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(layer_info(learn)[1], [\n", " ('Linear', 100, True, [1, 50]),\n", " ('ReLU', 0, False, [1, 50]),\n", " ('BatchNorm1d', 100, True, [1, 50]),\n", " ('Linear', 51, True, [1, 1])\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _print_shapes(o, bs):\n", " if isinstance(o, torch.Size): return ' x '.join([str(bs)] + [str(t) for t in o[1:]])\n", " else: return [_print_shapes(x, bs) for x in o]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@patch\n", "def summary(self:Learner):\n", " \"Print a summary of the model, optimizer and loss function.\"\n", " xb,infos = layer_info(self)\n", " n,bs = 64,find_bs(xb)\n", " inp_sz = _print_shapes(apply(lambda x:x.shape, xb), bs)\n", " res = f\"{self.model.__class__.__name__} (Input shape: {inp_sz})\\n\"\n", " res += \"=\" * n + \"\\n\"\n", " res += f\"{'Layer (type)':<20} {'Output Shape':<20} {'Param #':<10} {'Trainable':<10}\\n\"\n", " res += \"=\" * n + \"\\n\"\n", " ps,trn_ps = 0,0\n", " for typ,np,trn,sz in infos:\n", " if sz is None: continue\n", " ps += np\n", " if trn: trn_ps += np\n", " res += f\"{typ:<20} {_print_shapes(sz, bs):<20} {np:<10,} {str(trn):<10}\\n\"\n", " res += \"_\" * n + \"\\n\"\n", " res += f\"\\nTotal params: {ps:,}\\n\"\n", " res += f\"Total trainable params: 
{trn_ps:,}\\n\"\n", "    res += f\"Total non-trainable params: {ps - trn_ps:,}\\n\\n\"\n", "    res += f\"Optimizer used: {self.opt_func}\\nLoss function: {self.loss_func}\\n\\n\"\n", "    if self.opt is not None:\n", "        res += f\"Model \" + (\"unfrozen\\n\\n\" if self.opt.frozen_idx==0 else f\"frozen up to parameter group number {self.opt.frozen_idx}\\n\\n\")\n", "    res += \"Callbacks:\\n\" + '\\n'.join(f\" - {cb}\" for cb in sort_by_run(self.cbs))\n", "    return PrettyString(res)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Sequential (Input shape: 16 x 1)\n", "================================================================\n", "Layer (type)         Output Shape         Param #    Trainable \n", "================================================================\n", "Linear               16 x 50              100        False     \n", "________________________________________________________________\n", "ReLU                 16 x 50              0          False     \n", "________________________________________________________________\n", "BatchNorm1d          16 x 50              100        True      \n", "________________________________________________________________\n", "Linear               16 x 1               51         True      \n", "________________________________________________________________\n", "\n", "Total params: 251\n", "Total trainable params: 151\n", "Total non-trainable params: 100\n", "\n", "Optimizer used: functools.partial(<function SGD at 0x...>, mom=0.9)\n", "Loss function: FlattenedLoss of MSELoss()\n", "\n", "Model unfrozen\n", "\n", "Callbacks:\n", " - TrainEvalCallback\n", " - Recorder" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m = nn.Sequential(nn.Linear(1,50), nn.ReLU(), nn.BatchNorm1d(50), nn.Linear(50, 1))\n", "for p in m[0].parameters(): p.requires_grad_(False)\n", "learn = synth_learner()\n", "learn.create_opt()\n", "learn.model=m\n", "learn.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Activation graphs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is an example of a `HookCallback` that stores the means, standard deviations and histograms of the activations that go through the network."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#exports\n", "@delegates()\n", "class ActivationStats(HookCallback):\n", " \"Callback that record the mean and std of activations.\"\n", " run_before=TrainEvalCallback\n", " def __init__(self, with_hist=False, **kwargs):\n", " super().__init__(**kwargs)\n", " self.with_hist = with_hist\n", " \n", " def begin_fit(self):\n", " \"Initialize stats.\"\n", " super().begin_fit()\n", " self.stats = L()\n", " \n", " def hook(self, m, i, o): \n", " o = o.float()\n", " res = {'mean': o.mean().item(), 'std': o.std().item(), 'percent_null': (o<=0.05).long().sum().item()/o.numel()}\n", " if self.with_hist: res['hist'] = o.histc(40,0,10)\n", " return res\n", " \n", " def after_batch(self):\n", " \"Take the stored results and puts it in `self.stats`\"\n", " if self.training and (self.every is None or self.train_iter%self.every != 0): self.stats.append(self.hooks.stored)\n", " super().after_batch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(#4) [0,15.506263732910156,14.684226989746094,00:00]\n" ] } ], "source": [ "learn = synth_learner(n_trn=5, cbs = ActivationStats(every=4))\n", "learn.fit(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#3) [(#1) [{'mean': -0.9653729200363159, 'std': 1.5585802793502808, 'percent_null': 0.8125}],(#1) [{'mean': -0.9653729200363159, 'std': 1.5585802793502808, 'percent_null': 0.8125}],(#1) [{'mean': -0.9653729200363159, 'std': 1.5585802793502808, 'percent_null': 0.8125}]]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.activation_stats.stats" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first line contains the means of the outputs of the model for each batch in the training set, the second line their standard deviations." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "40" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(learn.activation_stats.stats[0][0]['hist'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(#4) [0,4.010986804962158,3.2662320137023926,00:00]\n" ] } ], "source": [ "#hide\n", "class TstCallback(HookCallback):\n", " def hook(self, m, i, o): return o\n", " def begin_fit(self):\n", " super().begin_fit()\n", " self.means,self.stds = [],[]\n", " \n", " def after_batch(self):\n", " if self.training:\n", " self.means.append(self.hooks.stored[0].mean().item())\n", " self.stds.append (self.hooks.stored[0].std() .item())\n", "\n", "learn = synth_learner(n_trn=5, cbs = [TstCallback(), ActivationStats()])\n", "learn.fit(1)\n", "test_eq(learn.activation_stats.stats.itemgot(0).itemgot(\"mean\"), learn.tst.means)\n", "test_eq(learn.activation_stats.stats.itemgot(0).itemgot(\"std\"), learn.tst.stds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_test.ipynb.\n", "Converted 01_core_foundation.ipynb.\n", "Converted 01a_core_utils.ipynb.\n", "Converted 01b_core_dispatch.ipynb.\n", "Converted 01c_core_transform.ipynb.\n", "Converted 02_core_script.ipynb.\n", "Converted 03_torchcore.ipynb.\n", "Converted 03a_layers.ipynb.\n", "Converted 04_data_load.ipynb.\n", "Converted 05_data_core.ipynb.\n", "Converted 06_data_transforms.ipynb.\n", "Converted 07_data_block.ipynb.\n", "Converted 08_vision_core.ipynb.\n", "Converted 09_vision_augment.ipynb.\n", "Converted 09a_vision_data.ipynb.\n", "Converted 10_pets_tutorial.ipynb.\n", "Converted 11_vision_models_xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_learner.ipynb.\n", "Converted 13a_metrics.ipynb.\n", "Converted 14_callback_schedule.ipynb.\n", "Converted 14a_callback_data.ipynb.\n", "Converted 15_callback_hook.ipynb.\n", "Converted 15a_vision_models_unet.ipynb.\n", "Converted 16_callback_progress.ipynb.\n", "Converted 17_callback_tracker.ipynb.\n", "Converted 18_callback_fp16.ipynb.\n", "Converted 19_callback_mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 21_vision_learner.ipynb.\n", "Converted 22_tutorial_imagenette.ipynb.\n", "Converted 23_tutorial_transfer_learning.ipynb.\n", "Converted 30_text_core.ipynb.\n", "Converted 31_text_data.ipynb.\n", "Converted 32_text_models_awdlstm.ipynb.\n", "Converted 33_text_models_core.ipynb.\n", "Converted 34_callback_rnn.ipynb.\n", "Converted 35_tutorial_wikitext.ipynb.\n", "Converted 36_text_models_qrnn.ipynb.\n", "Converted 37_text_learner.ipynb.\n", "Converted 38_tutorial_ulmfit.ipynb.\n", "Converted 40_tabular_core.ipynb.\n", "Converted 41_tabular_model.ipynb.\n", "Converted 42_tabular_rapids.ipynb.\n", "Converted 50_data_block_examples.ipynb.\n", "Converted 60_medical_imaging.ipynb.\n", "Converted 65_medical_text.ipynb.\n", "Converted 70_callback_wandb.ipynb.\n", "Converted 71_callback_tensorboard.ipynb.\n", "Converted 90_notebook_core.ipynb.\n", "Converted 91_notebook_export.ipynb.\n", "Converted 92_notebook_showdoc.ipynb.\n", "Converted 93_notebook_export2html.ipynb.\n", "Converted 94_notebook_test.ipynb.\n", "Converted 95_index.ipynb.\n", "Converted 
96_data_external.ipynb.\n", "Converted 97_utils_test.ipynb.\n", "Converted notebook2jekyll.ipynb.\n" ] } ], "source": [ "#hide\n", "from local.notebook.export import notebook2script\n", "notebook2script(all_fs=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }