{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Hook callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This provides both a standalone class and a callback for registering and automatically deregistering [PyTorch hooks](https://pytorch.org/tutorials/beginner/former_torchies/nn_tutorial.html#forward-and-backward-function-hooks), along with some pre-defined hooks. Hooks can be attached to any [`nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), for either the forward or the backward pass.\n", "\n", "We'll start by looking at the pre-defined hook [`ActivationStats`](/callbacks.hooks.html#ActivationStats), then we'll see how to create our own." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.callbacks.hooks import * \n", "from fastai.train import *\n", "from fastai.vision import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class ActivationStats[source]

\n", "\n", "> ActivationStats(**`learn`**:[`Learner`](/basic_train.html#Learner), **`modules`**:`Sequence`\\[[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)\\]=***`None`***, **`do_remove`**:`bool`=***`True`***) :: [`HookCallback`](/callbacks.hooks.html#HookCallback)\n", "\n", "Callback that records the mean and std of activations. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ActivationStats)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`ActivationStats`](/callbacks.hooks.html#ActivationStats) saves the mean and standard deviation of the layer activations in `self.stats` for all `modules` passed to it. By default it will save these statistics for *all* modules. For instance:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:02

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
<table><tr><th>epoch</th><th>train_loss</th><th>valid_loss</th></tr>
<tr><td>1</td><td>0.112384</td><td>0.083544</td></tr></table>
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "#learn = create_cnn(data, models.resnet18, callback_fns=ActivationStats)\n", "learn = Learner(data, simple_cnn((3,16,16,2)), callback_fns=ActivationStats)\n", "learn.fit(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The saved `stats` is a `FloatTensor` of shape `(2,num_modules,num_batches)`. The first axis is `(mean,stdev)`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(193, 3)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(learn.data.train_dl),len(learn.activation_stats.modules)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 3, 193])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.activation_stats.stats.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So this shows the standard deviation (`axis0==1`) of the second-to-last layer (`axis1==-2`) for each batch (`axis2`):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "

" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(learn.activation_stats.stats[1][-2].numpy());" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Internal implementation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

hook[source]

\n", "\n", "> hook(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`i`**:`Tensors`, **`o`**:`Tensors`) → `Tuple`\\[`Rank0Tensor`, `Rank0Tensor`\\]\n", "\n", "Take the mean and std of `o`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ActivationStats.hook)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_begin[source]

\n", "\n", "> on_train_begin(**\\*\\*`kwargs`**)\n", "\n", "Initialize stats. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ActivationStats.on_train_begin)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_batch_end[source]

\n", "\n", "> on_batch_end(**`train`**, **\\*\\*`kwargs`**)\n", "\n", "Take the stored results and put them in `self.stats`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ActivationStats.on_batch_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_end[source]

\n", "\n", "> on_train_end(**\\*\\*`kwargs`**)\n", "\n", "Polish the final result. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ActivationStats.on_train_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class Hook[source]

\n", "\n", "> Hook(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`hook_func`**:`HookFunc`, **`is_forward`**:`bool`=***`True`***, **`detach`**:`bool`=***`True`***)\n", "\n", "Create a hook on `m` with `hook_func`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hook)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Registers and manually deregisters a [PyTorch hook](https://pytorch.org/tutorials/beginner/former_torchies/nn_tutorial.html#forward-and-backward-function-hooks). Your `hook_func` will be called automatically when forward/backward (depending on `is_forward`) for your module `m` is run, and the result of that function is placed in `self.stored`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

remove[source]

\n", "\n", "> remove()\n", "\n", "Remove the hook from the model. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hook.remove)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Deregister the hook, if it has not been removed already." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class Hooks[source]

\n", "\n", "> Hooks(**`ms`**:`ModuleList`, **`hook_func`**:`HookFunc`, **`is_forward`**:`bool`=***`True`***, **`detach`**:`bool`=***`True`***)\n", "\n", "Create several hooks on the modules in `ms` with `hook_func`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hooks)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Acts as a `Collection` (i.e. `len(hooks)` and `hooks[i]`) and an `Iterator` (i.e. `for hook in hooks`) of a group of hooks, one for each module in `ms`, with the ability to remove all as a group. Use `stored` to get all hook results. `hook_func` and `is_forward` behavior is the same as [`Hook`](/callbacks.hooks.html#Hook). See the source code for [`HookCallback`](/callbacks.hooks.html#HookCallback) for a simple example." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

remove[source]

\n", "\n", "> remove()\n", "\n", "Remove the hooks from the model. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hooks.remove)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Deregister all hooks created by this class, if they have not been removed already." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convenience functions for hooks" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

hook_output[source]

\n", "\n", "> hook_output(**`module`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`detach`**:`bool`=***`True`***, **`grad`**:`bool`=***`False`***) → [`Hook`](/callbacks.hooks.html#Hook)\n", "\n", "Return a [`Hook`](/callbacks.hooks.html#Hook) that stores activations of `module` in `self.stored` " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(hook_output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Function that creates a [`Hook`](/callbacks.hooks.html#Hook) for `module` that simply stores the output of the layer." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

hook_outputs[source]

\n", "\n", "> hook_outputs(**`modules`**:`ModuleList`, **`detach`**:`bool`=***`True`***, **`grad`**:`bool`=***`False`***) → [`Hooks`](/callbacks.hooks.html#Hooks)\n", "\n", "Return [`Hooks`](/callbacks.hooks.html#Hooks) that store activations of all `modules` in `self.stored` " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(hook_outputs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Function that creates a [`Hooks`](/callbacks.hooks.html#Hooks) object, with one [`Hook`](/callbacks.hooks.html#Hook) per module in `modules`, that simply stores the output of each layer. For example, the (slightly simplified) source code of [`model_sizes`](/callbacks.hooks.html#model_sizes) is:\n", "\n", "```python\n", "def model_sizes(m, size):\n", "    hooks = hook_outputs(m)  # register the hooks before the forward pass\n", "    x = m(torch.zeros(1, in_channels(m), *size))\n", "    return [o.stored.shape for o in hooks]\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

model_sizes[source]

\n", "\n", "> model_sizes(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`size`**:`tuple`=***`(64, 64)`***) → `Tuple`\\[`Sizes`, `Tensor`, [`Hooks`](/callbacks.hooks.html#Hooks)\\]\n", "\n", "Pass a dummy input through the model `m` to get the various sizes of activations. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(model_sizes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

model_summary[source]

\n", "\n", "> model_summary(**`m`**:[`Learner`](/basic_train.html#Learner), **`n`**:`int`=***`70`***)\n", "\n", "Print a summary of `m` using an output text width of `n` chars. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(model_summary)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

num_features_model[source]

\n", "\n", "> num_features_model(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)) → `int`\n", "\n", "Return the number of output features for `m`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(num_features_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It can be useful to get the size of each layer of a model (e.g. for printing a summary, or for generating cross-connections for a [`DynamicUnet`](/vision.models.unet.html#DynamicUnet)), but these sizes depend on the size of the input. This function calculates the layer sizes by passing a minimal tensor of `size` through the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

dummy_batch[source]

\n", "\n", "> dummy_batch(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`size`**:`tuple`=***`(64, 64)`***) → `Tensor`\n", "\n", "Create a dummy batch to go through `m` with `size`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(dummy_batch)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

dummy_eval[source]

\n", "\n", "> dummy_eval(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`size`**:`tuple`=***`(64, 64)`***)\n", "\n", "Pass a [`dummy_batch`](/callbacks.hooks.html#dummy_batch) in evaluation mode in `m` with `size`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(dummy_eval)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class HookCallback[source]

\n", "\n", "> HookCallback(**`learn`**:[`Learner`](/basic_train.html#Learner), **`modules`**:`Sequence`\\[[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)\\]=***`None`***, **`do_remove`**:`bool`=***`True`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "Callback that can be used to register hooks on `modules`. Implement the corresponding function in `self.hook`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(HookCallback)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For all `modules`, uses a callback to automatically register a method `self.hook` (that you must define in an inherited class) as a hook. This method must have the signature:\n", "\n", "```python\n", "def hook(self, m:Model, input:Tensors, output:Tensors)\n", "```\n", "\n", "If `do_remove` then the hook is automatically deregistered at the end of training. See [`ActivationStats`](/callbacks.hooks.html#ActivationStats) for a simple example of inheriting from this class." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_begin[source]

\n", "\n", "> on_train_begin(**\\*\\*`kwargs`**)\n", "\n", "Register the [`Hooks`](/callbacks.hooks.html#Hooks) on `self.modules`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(HookCallback.on_train_begin)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_end[source]

\n", "\n", "> on_train_end(**\\*\\*`kwargs`**)\n", "\n", "Remove the [`Hooks`](/callbacks.hooks.html#Hooks). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(HookCallback.on_train_end)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [ { "data": { "text/markdown": [ "

remove[source]

\n", "\n", "> remove()" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(HookCallback.remove)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

hook_fn[source]

\n", "\n", "> hook_fn(**`module`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`input`**:`Tensors`, **`output`**:`Tensors`)\n", "\n", "Applies `hook_func` to `module`, `input`, `output`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hook.hook_fn)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## New Methods - Please document or move to the undocumented section" ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Implement callbacks using hooks", "title": "callbacks.hooks" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }