{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Hook callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This provides both a standalone class and a callback for registering and automatically deregistering [PyTorch hooks](https://pytorch.org/tutorials/beginner/former_torchies/nn_tutorial.html#forward-and-backward-function-hooks), along with some pre-defined hooks. Hooks can be attached to any [`nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), for either the forward or the backward pass.\n", "\n", "We'll start by looking at the pre-defined hook [`ActivationStats`](/callbacks.hooks.html#ActivationStats), then we'll see how to create our own." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.callbacks.hooks import * \n", "from fastai.train import *\n", "from fastai.vision import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
class ActivationStats(**`learn`**:[`Learner`](/basic_train.html#Learner), **`modules`**:`Sequence`\\[[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)\\]=***`None`***, **`do_remove`**:`bool`=***`True`***) :: [`HookCallback`](/callbacks.hooks.html#HookCallback)\n",
"\n",
| epoch | train_loss | valid_loss |
|---|---|---|
| 1 | 0.112384 | 0.083544 |
hook(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`i`**:`Tensors`, **`o`**:`Tensors`) → `Tuple`\\[`Rank0Tensor`, `Rank0Tensor`\\]\n",
"\n",
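The `hook` method above receives a module together with its inputs and outputs, and it returns two scalars. The sketch below uses plain PyTorch rather than fastai code. It assumes the two values are the mean and standard deviation of the output, and the names `mean_std_hook` and `recording_hook` are hypothetical. It also shows why such a function is wrapped rather than registered directly. PyTorch treats a non-`None` return value from a forward hook as a replacement for the module's output, so the wrapper stores the statistics and returns `None`.

```python
import torch
from torch import nn

# Hypothetical hook function. PyTorch calls forward hooks as hook(module, input, output),
# where `input` is a tuple of positional inputs and `output` is what forward() returned.
def mean_std_hook(module, inputs, output):
    out = output.detach()  # statistics should not be tracked by autograd
    return out.mean().item(), out.std().item()

results = []

def recording_hook(module, inputs, output):
    # Returning None keeps the module's output unchanged; the stats are stored instead.
    results.append(mean_std_hook(module, inputs, output))

layer = nn.Linear(10, 5)
handle = layer.register_forward_hook(recording_hook)
layer(torch.randn(8, 10))
handle.remove()            # deregister: the next call is not recorded
layer(torch.randn(8, 10))
print(len(results))  # 1
```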
on_train_begin(**\\*\\*`kwargs`**)\n",
"\n",
on_batch_end(**`train`**, **\\*\\*`kwargs`**)\n",
"\n",
on_train_end(**\\*\\*`kwargs`**)\n",
"\n",
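Together, `on_train_begin`, `on_batch_end` and `on_train_end` define a lifecycle. Hooks are registered when training starts, one set of statistics is collected per training batch, and the hooks are deregistered when training ends. The following minimal sketch replaces the `Learner` with a hand-written loop. The class `ActivationStatsSketch`, the default of hooking every module that has a `weight`, and the final `(2, n_modules, n_batches)` layout are all assumptions for illustration.

```python
import torch
from torch import nn

class ActivationStatsSketch:
    """Hypothetical stand-in for the callback lifecycle; not fastai's implementation."""
    def __init__(self, model, modules=None):
        # Assumption: by default, hook every submodule that has a `weight`.
        self.modules = modules or [m for m in model.modules()
                                   if getattr(m, "weight", None) is not None]
        self.handles, self.current, self.stats = [], {}, []

    def _hook(self, module, inputs, output):
        out = output.detach()
        self.current[module] = (out.mean().item(), out.std().item())

    def on_train_begin(self):
        self.stats = []
        self.handles = [m.register_forward_hook(self._hook) for m in self.modules]

    def on_batch_end(self, train):
        if train:  # only training batches are recorded
            self.stats.append([self.current[m] for m in self.modules])

    def on_train_end(self):
        for h in self.handles:
            h.remove()
        self.handles = []
        # (n_batches, n_modules, 2) -> (2, n_modules, n_batches): index 0 = means, 1 = stds
        self.stats = torch.tensor(self.stats).permute(2, 1, 0)

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
cb = ActivationStatsSketch(model)
cb.on_train_begin()
for _ in range(3):
    model(torch.randn(16, 4))  # stands in for a training batch
    cb.on_batch_end(train=True)
cb.on_train_end()
print(cb.stats.shape)  # torch.Size([2, 2, 3])
```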
class Hook(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`hook_func`**:`HookFunc`, **`is_forward`**:`bool`=***`True`***, **`detach`**:`bool`=***`True`***)\n",
"\n",
remove()\n",
"\n",
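A `Hook` pairs a module with a hook function, keeps the function's latest result, and can be removed. The following minimal sketch rests on several assumptions. `HookSketch` is a hypothetical name, results go to `stored`, and `is_forward=False` switches to PyTorch's `register_full_backward_hook`. In addition, `remove` is safe to call twice, which lets the class also work as a context manager.

```python
import torch
from torch import nn

def _detach(x):
    # Detach a tensor, or every tensor in a flat tuple/list (backward hooks may pass None entries).
    if isinstance(x, (tuple, list)):
        return tuple(t.detach() if t is not None else None for t in x)
    return x.detach()

class HookSketch:
    """Hypothetical minimal hook wrapper; not fastai's implementation."""
    def __init__(self, m, hook_func, is_forward=True, detach=True):
        self.hook_func, self.detach, self.stored = hook_func, detach, None
        register = m.register_forward_hook if is_forward else m.register_full_backward_hook
        self.handle = register(self.hook_fn)
        self.removed = False

    def hook_fn(self, module, inputs, outputs):
        if self.detach:
            inputs, outputs = _detach(inputs), _detach(outputs)
        # Store the result and return None, so PyTorch leaves outputs/gradients unchanged.
        self.stored = self.hook_func(module, inputs, outputs)

    def remove(self):
        # Idempotent: removing twice is harmless.
        if not self.removed:
            self.handle.remove()
            self.removed = True

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.remove()

layer = nn.Linear(3, 2)
with HookSketch(layer, lambda m, i, o: o.shape) as h:
    layer(torch.randn(5, 3))
print(h.stored, h.removed)  # torch.Size([5, 2]) True
```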
class Hooks(**`ms`**:`ModuleList`, **`hook_func`**:`HookFunc`, **`is_forward`**:`bool`=***`True`***, **`detach`**:`bool`=***`True`***)\n",
"\n",
remove()\n",
"\n",
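`Hooks` applies the same idea to a list of modules. The sketch below uses a hypothetical `HooksSketch`, registers forward hooks only, and assumes tensor outputs. It keeps one result slot per module. The small factory `_make_hook` binds each module's index at registration time, which avoids Python's late binding of loop variables in closures.

```python
import torch
from torch import nn

class HooksSketch:
    """Hypothetical container of forward hooks, one per module; assumes tensor outputs."""
    def __init__(self, modules, hook_func):
        self.modules = list(modules)
        self.stored = [None] * len(self.modules)
        self.handles = [m.register_forward_hook(self._make_hook(i, hook_func))
                        for i, m in enumerate(self.modules)]

    def _make_hook(self, idx, hook_func):
        # Bind idx now; a bare closure over the loop variable would see only its last value.
        def fn(module, inputs, output):
            self.stored[idx] = hook_func(module, inputs, output.detach())
        return fn

    def remove(self):
        for h in self.handles:
            h.remove()
        self.handles = []

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.remove()

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
with HooksSketch(model, lambda m, i, o: tuple(o.shape)) as hooks:
    model(torch.randn(3, 4))
print(hooks.stored)  # [(3, 8), (3, 8), (3, 2)]
```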
hook_output(**`module`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`detach`**:`bool`=***`True`***, **`grad`**:`bool`=***`False`***) → [`Hook`](/callbacks.hooks.html#Hook)\n",
"\n",
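`hook_output` captures what a single module produces, and its `grad` flag suggests capturing the gradient with respect to that output instead. The sketch below shows both modes in plain PyTorch. It assumes a hypothetical helper `hook_output_sketch` that returns a dict and a removable handle. The forward mode stores the detached output, and the backward mode stores `grad_output[0]`.

```python
import torch
from torch import nn

def hook_output_sketch(module, grad=False):
    """Hypothetical helper. Returns (store, handle): store["value"] holds the module's
    detached output after a forward pass or, with grad=True, the gradient of the
    loss with respect to that output after a backward pass."""
    store = {}
    if grad:
        def fn(m, grad_input, grad_output):
            store["value"] = grad_output[0].detach()
        handle = module.register_full_backward_hook(fn)
    else:
        def fn(m, inputs, output):
            store["value"] = output.detach()
        handle = module.register_forward_hook(fn)
    return store, handle

layer = nn.Linear(4, 3)
x = torch.randn(2, 4, requires_grad=True)

acts, h_fwd = hook_output_sketch(layer)
grads, h_bwd = hook_output_sketch(layer, grad=True)
layer(x).sum().backward()
h_fwd.remove()
h_bwd.remove()
print(acts["value"].shape, grads["value"].shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```

Because the loss is a plain sum, the gradient with respect to every output element is 1, which makes the backward mode easy to check.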
hook_outputs(**`modules`**:`ModuleList`, **`detach`**:`bool`=***`True`***, **`grad`**:`bool`=***`False`***) → [`Hooks`](/callbacks.hooks.html#Hooks)\n",
"\n",
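`hook_outputs` does the same for several modules at once. With `grad=True`, this is a common way to inspect how gradients flow through a network. The sketch below uses plain PyTorch, not the fastai helper, and labels such as `"0:Linear"` are illustrative. It records the norm of the gradient that reaches each layer's output.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
grad_norms = {}

def make_grad_hook(name):
    # Backward hooks receive (module, grad_input, grad_output);
    # grad_output[0] is the gradient of the loss w.r.t. the module's output.
    def fn(module, grad_input, grad_output):
        grad_norms[name] = grad_output[0].detach().norm().item()
    return fn

handles = [m.register_full_backward_hook(make_grad_hook(f"{i}:{type(m).__name__}"))
           for i, m in enumerate(model)]
x = torch.randn(5, 4, requires_grad=True)
model(x).pow(2).mean().backward()
for h in handles:
    h.remove()  # deregister so later backward passes are not recorded
print(grad_norms)
```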
model_sizes(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`size`**:`tuple`=***`(64, 64)`***) → `Tuple`\\[`Sizes`, `Tensor`, [`Hooks`](/callbacks.hooks.html#Hooks)\\]\n",
"\n",
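`model_sizes` reads a model's output shapes from hooks during a single dummy forward pass. The sketch below makes two assumptions. The input is image-like, with shape `(1, in_channels, *size)`, and the model is a container whose children run in order, which holds for `nn.Sequential`. Both `model_sizes_sketch` and its `in_channels` parameter are hypothetical.

```python
import torch
from torch import nn

def model_sizes_sketch(model, size=(64, 64), in_channels=3):
    """Hypothetical helper: output shape of each direct child of `model` for a
    dummy batch of shape (1, in_channels, *size), in execution order."""
    sizes = []
    handles = [m.register_forward_hook(lambda mod, inp, out: sizes.append(tuple(out.shape)))
               for m in model.children()]
    try:
        model.eval()
        with torch.no_grad():
            model(torch.zeros(1, in_channels, *size))
    finally:
        # Always deregister, even if the forward pass fails.
        for h in handles:
            h.remove()
    return sizes

body = nn.Sequential(nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),
                     nn.Conv2d(16, 32, 3, stride=2, padding=1))
print(model_sizes_sketch(body))  # [(1, 16, 32, 32), (1, 16, 32, 32), (1, 32, 16, 16)]
```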
model_summary(**`m`**:[`Learner`](/basic_train.html#Learner), **`n`**:`int`=***`70`***)\n",
"\n",
Tests found for `model_summary`:

- `pytest -sv tests/test_basic_train.py::test_export_load_learner`
- `pytest -sv tests/test_callbacks_hooks.py::test_model_summary_collab`
- `pytest -sv tests/test_callbacks_hooks.py::test_model_summary_tabular`
- `pytest -sv tests/test_callbacks_hooks.py::test_model_summary_text`
- `pytest -sv tests/test_callbacks_hooks.py::test_model_summary_vision`
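`model_summary` builds a printable table with one row per layer. The sketch below takes a plain module rather than a `Learner` and hooks every leaf module. For each one, it reports the type, output shape, parameter count and whether the layer is trainable. Three details are assumptions: treating `n` (here `width`) as the width of the separator lines, the column layout, and the name `model_summary_sketch`.

```python
import torch
from torch import nn

def model_summary_sketch(model, input_shape=(1, 3, 32, 32), width=70):
    """Hypothetical helper: one row per leaf module with type, output shape,
    parameter count and trainability."""
    rows = []

    def hook(module, inputs, output):
        params = list(module.parameters(recurse=False))
        rows.append((type(module).__name__, str(list(output.shape)),
                     sum(p.numel() for p in params), any(p.requires_grad for p in params)))

    leaves = [m for m in model.modules() if not list(m.children())]
    handles = [m.register_forward_hook(hook) for m in leaves]
    try:
        model.eval()
        with torch.no_grad():
            model(torch.zeros(*input_shape))
    finally:
        for h in handles:
            h.remove()

    header = f"{'Layer (type)':<20}{'Output Shape':<24}{'Param #':<12}Trainable"
    lines = ["=" * width, header, "=" * width]
    lines += [f"{name:<20}{shape:<24}{n:<12}{train}" for name, shape, n, train in rows]
    lines += ["=" * width, f"Total params: {sum(r[2] for r in rows)}"]
    return "\n".join(lines)

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Flatten(),
                    nn.Linear(8 * 32 * 32, 10))
print(model_summary_sketch(net))
```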
num_features_model(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)) → `int`\n",
"\n",
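`num_features_model` returns a single integer. For a convolutional body, this is presumably the number of channels in its output, which is what a head needs to know. The sketch below works under that assumption. It runs a dummy batch and reads dimension 1 of the output. If the input is too small for the model's downsampling, it doubles the spatial size and tries again. The starting size of 64, the 2048 limit and the name `num_features_sketch` are hypothetical.

```python
import torch
from torch import nn

def num_features_sketch(model, in_channels=3, start=64, max_size=2048):
    """Hypothetical helper: channel count (dim 1) of the model's output for a
    square dummy image, doubling the size until the forward pass succeeds."""
    size = start
    model.eval()
    while True:
        try:
            with torch.no_grad():
                out = model(torch.zeros(1, in_channels, size, size))
            return out.shape[1]
        except Exception:
            # An input that is too small for the downsampling raises (typically a
            # RuntimeError); catching broadly keeps the sketch simple.
            size *= 2
            if size > max_size:
                raise

deep = nn.Sequential(nn.Conv2d(3, 16, 3, stride=2), nn.MaxPool2d(32))
print(num_features_sketch(deep))  # 16
```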
dummy_batch(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`size`**:`tuple`=***`(64, 64)`***) → `Tensor`\n",
"\n",
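`dummy_batch` produces an input that fits a model without needing real data. The sketch below assumes the channel count comes from the first `nn.Conv2d` it finds and fills the batch uniformly in [-1, 1]. It also matches the device and dtype of the model's first parameter, so the batch works on a GPU model too. `first_in_channels` and `dummy_batch_sketch` are hypothetical names.

```python
import torch
from torch import nn

def first_in_channels(model):
    # Assumption: the first nn.Conv2d in module order defines the expected input channels.
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            return m.in_channels
    raise ValueError("model has no nn.Conv2d layer")

def dummy_batch_sketch(model, size=(64, 64)):
    """Hypothetical helper: a (1, C, H, W) batch, uniform in [-1, 1], on the same
    device and with the same dtype as the model's parameters."""
    p = next(model.parameters())
    batch = torch.empty(1, first_in_channels(model), *size, device=p.device, dtype=p.dtype)
    return batch.uniform_(-1.0, 1.0)

x = dummy_batch_sketch(nn.Sequential(nn.Conv2d(1, 4, 3)), size=(28, 28))
print(x.shape)  # torch.Size([1, 1, 28, 28])
```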
dummy_eval(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`size`**:`tuple`=***`(64, 64)`***)\n",
"\n",
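`dummy_eval` runs a dummy batch through the model in evaluation mode. Evaluation mode matters with a batch of one. In training mode, a `BatchNorm1d` layer rejects a single sample because it cannot compute batch statistics. In the sketch below, `dummy_eval_sketch` is hypothetical. It disables gradient tracking and restores the model's previous training mode afterwards, which is a design choice rather than documented behavior.

```python
import torch
from torch import nn

def dummy_eval_sketch(model, size=(64, 64), in_channels=3):
    """Hypothetical helper: forward a random (1, in_channels, *size) batch in eval
    mode without tracking gradients, then restore the previous mode."""
    was_training = model.training
    model.eval()  # e.g. BatchNorm uses running stats instead of batch stats
    try:
        with torch.no_grad():
            return model(torch.empty(1, in_channels, *size).uniform_(-1.0, 1.0))
    finally:
        model.train(was_training)

net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10), nn.BatchNorm1d(10))
out = dummy_eval_sketch(net, size=(8, 8))
print(out.shape, net.training)  # torch.Size([1, 10]) True
```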
class HookCallback(**`learn`**:[`Learner`](/basic_train.html#Learner), **`modules`**:`Sequence`\\[[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)\\]=***`None`***, **`do_remove`**:`bool`=***`True`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
on_train_begin(**\\*\\*`kwargs`**)\n",
"\n",
on_train_end(**\\*\\*`kwargs`**)\n",
"\n",
remove()\n",
"\n",
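`HookCallback` is the base class that `ActivationStats` builds on. It registers hooks on `modules` when training starts and, if `do_remove` is true, removes them when training ends. The sketch below shows that pattern outside fastai, with the hypothetical names `HookCallbackSketch` and `MaxActivation`. Subclasses only implement `hook`. The default module selection (leaf modules that have a `weight`) is an assumption.

```python
import torch
from torch import nn

class HookCallbackSketch:
    """Hypothetical base class, not fastai's: subclasses implement `hook`, whose
    result is stored per module after every forward pass."""
    def __init__(self, model, modules=None, do_remove=True):
        self.model, self.modules, self.do_remove = model, modules, do_remove
        self.handles, self.stored = [], {}

    def on_train_begin(self):
        if not self.modules:
            # Assumption: default to leaf modules that own a `weight`.
            self.modules = [m for m in self.model.modules()
                            if not list(m.children()) and getattr(m, "weight", None) is not None]
        def fn(module, inputs, output):
            self.stored[module] = self.hook(module, inputs, output.detach())
        self.handles = [m.register_forward_hook(fn) for m in self.modules]

    def on_train_end(self):
        if self.do_remove:
            self.remove()

    def remove(self):
        for h in self.handles:
            h.remove()
        self.handles = []

    def hook(self, module, inputs, output):
        raise NotImplementedError

class MaxActivation(HookCallbackSketch):
    def hook(self, module, inputs, output):
        return output.max().item()

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
cb = MaxActivation(model, do_remove=False)
cb.on_train_begin()
model(torch.randn(3, 4))
cb.on_train_end()        # hooks survive because do_remove=False
n_live = len(cb.handles)
cb.remove()              # caller cleans up explicitly
print(len(cb.stored), n_live)  # 2 2
```

With `do_remove=False`, the hooks outlive `on_train_end`, so the caller becomes responsible for calling `remove`.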
hook_fn(**`module`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`input`**:`Tensors`, **`output`**:`Tensors`)\n",
"\n",
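`hook_fn` has the same `(module, input, output)` signature that PyTorch passes to a forward hook. The sketch below, built around the hypothetical factory `make_hook_fn`, shows what the `detach` option controls. A detached result no longer requires gradients. An undetached result stays attached to the autograd graph, which keeps that graph alive for as long as the result is stored.

```python
import torch
from torch import nn

def make_hook_fn(hook_func, store, detach=True):
    """Hypothetical factory: a forward hook that optionally detaches inputs and
    output, then keeps hook_func's result in store["value"]."""
    def hook_fn(module, inputs, output):
        if detach:
            inputs = tuple(t.detach() for t in inputs)
            output = output.detach()
        store["value"] = hook_func(module, inputs, output)
    return hook_fn

layer = nn.Linear(3, 2)
x = torch.randn(4, 3)
flags = {}
for detach in (True, False):
    store = {}
    handle = layer.register_forward_hook(make_hook_fn(lambda m, i, o: o, store, detach))
    layer(x)
    handle.remove()
    flags[detach] = store["value"].requires_grad
print(flags)  # {True: False, False: True}
```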