{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Hook callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This module provides both a standalone class and a callback for registering and automatically deregistering [PyTorch hooks](https://pytorch.org/tutorials/beginner/former_torchies/nn_tutorial.html#forward-and-backward-function-hooks), along with some predefined hooks. Hooks can be attached to any [`nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), for either the forward or the backward pass.\n", "\n", "We'll start by looking at the predefined hook callback [`ActivationStats`](/callbacks.hooks.html#ActivationStats), then see how to create our own." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.callbacks.hooks import *\n", "from fastai import *\n", "from fastai.train import *\n", "from fastai.vision import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
class ActivationStats(`learn`:[`Learner`](/basic_train.html#Learner), `modules`:`Sequence`\\[[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)\\]=`None`, `do_remove`:`bool`=`True`) :: [`HookCallback`](/callbacks.hooks.html#HookCallback)\n",
"\n",
"Callback that records the activations."
],
"text/plain": [
"class Hook(`m`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `hook_func`:`HookFunc`, `is_forward`:`bool`=`True`)\n",
"\n",
"Create a hook. "
],
"text/plain": [
"remove()"
],
"text/plain": [
"class Hooks(`ms`:`ModuleList`, `hook_func`:`HookFunc`, `is_forward`:`bool`=`True`)\n",
"\n",
"Create several hooks. "
],
"text/plain": [
"remove()"
],
"text/plain": [
"hook_output(`module`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)) → [`Hook`](/callbacks.hooks.html#Hook)"
],
"text/plain": [
"hook_outputs(`modules`:`ModuleList`) → [`Hooks`](/callbacks.hooks.html#Hooks)"
],
"text/plain": [
"model_sizes(`m`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `size`:`tuple`=`(256, 256)`, `full`:`bool`=`True`) → `Tuple`\\[`Sizes`, `Tensor`, [`Hooks`](/callbacks.hooks.html#Hooks)\\]\n",
"\n",
"Pass a dummy input through the model to get the various sizes. "
],
"text/plain": [
"class HookCallback(`learn`:[`Learner`](/basic_train.html#Learner), `modules`:`Sequence`\\[[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)\\]=`None`, `do_remove`:`bool`=`True`) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"Callback that registers the given hooks."
],
"text/plain": [
"remove()"
],
"text/plain": [
"on_train_begin(\\*\\*`kwargs`)\n",
"\n",
"To initialize constants in the callback. "
],
"text/plain": [
"on_train_end(\\*\\*`kwargs`)\n",
"\n",
"Useful for cleaning up things and saving files/models. "
],
"text/plain": [
"hook(`m`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `i`:`Tensors`, `o`:`Tensors`) → `Tuple`\\[`Rank0Tensor`, `Rank0Tensor`\\]"
],
"text/plain": [
"on_batch_end(`train`, \\*\\*`kwargs`)\n",
"\n",
"Called at the end of the batch. "
],
"text/plain": [
"on_train_begin(\\*\\*`kwargs`)\n",
"\n",
"To initialize constants in the callback. "
],
"text/plain": [
"on_train_end(\\*\\*`kwargs`)\n",
"\n",
"Useful for cleaning up things and saving files/models. "
],
"text/plain": [
"hook_fn(`module`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `input`:`Tensors`, `output`:`Tensors`)"
],
"text/plain": [
"