{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Hook callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This provides both a standalone class and a callback for registering and automatically deregistering [PyTorch hooks](https://pytorch.org/tutorials/beginner/former_torchies/nn_tutorial.html#forward-and-backward-function-hooks), along with some pre-defined hooks. Hooks can be attached to any [`nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), for either the forward or the backward pass.\n", "\n", "We'll start by looking at the pre-defined hook [`ActivationStats`](/callbacks.hooks.html#ActivationStats), then we'll see how to create our own." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.callbacks.hooks import * \n", "from fastai.train import *\n", "from fastai.vision import *" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class ActivationStats

\n", "\n", "> ActivationStats(**`learn`**:[`Learner`](/basic_train.html#Learner), **`modules`**:`Sequence`\\[[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)\\]=***`None`***, **`do_remove`**:`bool`=***`True`***) :: [`HookCallback`](/callbacks.hooks.html#HookCallback)\n", "\n", "

\n", "\n", "Callback that records the mean and std of activations. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ActivationStats)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`ActivationStats`](/callbacks.hooks.html#ActivationStats) saves the mean and standard deviation of the layer activations in `self.stats` for all `modules` passed to it. By default (when `modules` is `None`) it records them for every layer of the model that has a `weight` attribute. For instance:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| epoch | train_loss | valid_loss | time |
|---|---|---|---|
| 0 | 0.142666 | 0.101166 | 00:03 |
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "#learn = cnn_learner(data, models.resnet18, callback_fns=ActivationStats)\n", "learn = Learner(data, simple_cnn((3,16,16,2)), callback_fns=ActivationStats)\n", "learn.fit(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The saved `stats` is a `FloatTensor` of shape `(2,num_modules,num_batches)`. The first axis is `(mean,stdev)`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(193, 3)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(learn.data.train_dl),len(learn.activation_stats.modules)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 3, 193])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.activation_stats.stats.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So this shows the standard deviation (`axis0==1`) of the second-to-last hooked layer (`axis1==-2`) for each batch (`axis2`):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "[PNG image data omitted: line plot of the activation standard deviation `stats[1][-2]` for each training batch]", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(learn.activation_stats.stats[1][-2].numpy());" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Internal implementation" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

hook

\n", "\n", "> hook(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`i`**:`Tensors`, **`o`**:`Tensors`) → `Tuple`\\[`Rank0Tensor`, `Rank0Tensor`\\]\n", "\n", "

\n", "\n", "Take the mean and std of `o`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ActivationStats.hook)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_begin

\n", "\n", "> on_train_begin(**\\*\\*`kwargs`**)\n", "\n", "

\n", "\n", "Initialize stats. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ActivationStats.on_train_begin)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_batch_end

\n", "\n", "> on_batch_end(**`train`**, **\\*\\*`kwargs`**)\n", "\n", "

\n", "\n", "Take the stored results and put them in `self.stats` " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ActivationStats.on_batch_end)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_end

\n", "\n", "> on_train_end(**\\*\\*`kwargs`**)\n", "\n", "

\n", "\n", "Polish the final result. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ActivationStats.on_train_end)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class Hook

\n", "\n", "> Hook(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`hook_func`**:`HookFunc`, **`is_forward`**:`bool`=***`True`***, **`detach`**:`bool`=***`True`***)\n", "\n", "

\n", "\n", "Create a hook on `m` with `hook_func`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hook)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Registers a [PyTorch hook](https://pytorch.org/tutorials/beginner/former_torchies/nn_tutorial.html#forward-and-backward-function-hooks) on `m`, which you can deregister manually by calling `remove`. Your `hook_func` is called automatically whenever the forward or backward pass (depending on `is_forward`) of your module `m` runs, and the result of that function is placed in `self.stored`." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

remove

\n", "\n", "> remove()\n", "\n", "

\n", "\n", "Remove the hook from the model. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hook.remove)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Deregisters the hook, if it has not already been removed." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class Hooks

\n", "\n", "> Hooks(**`ms`**:`ModuleList`, **`hook_func`**:`HookFunc`, **`is_forward`**:`bool`=***`True`***, **`detach`**:`bool`=***`True`***)\n", "\n", "

\n", "\n", "Create several hooks on the modules in `ms` with `hook_func`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hooks)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Acts as a `Collection` (i.e. `len(hooks)` and `hooks[i]`) and an `Iterator` (i.e. `for hook in hooks`) of a group of hooks, one for each module in `ms`, with the ability to remove all as a group. Use `stored` to get all hook results. `hook_func` and `is_forward` behavior is the same as [`Hook`](/callbacks.hooks.html#Hook). See the source code for [`HookCallback`](/callbacks.hooks.html#HookCallback) for a simple example." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

remove

\n", "\n", "> remove()\n", "\n", "

\n", "\n", "Remove the hooks from the model. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hooks.remove)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Deregisters all hooks created by this class, if they have not already been removed." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convenience functions for hooks" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

hook_output

\n", "\n", "> hook_output(**`module`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`detach`**:`bool`=***`True`***, **`grad`**:`bool`=***`False`***) → [`Hook`](/callbacks.hooks.html#Hook)\n", "\n", "

\n", "\n", "Return a [`Hook`](/callbacks.hooks.html#Hook) that stores activations of `module` in `self.stored` " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(hook_output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Function that creates a [`Hook`](/callbacks.hooks.html#Hook) for `module` that simply stores the output of the layer." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

hook_outputs

\n", "\n", "> hook_outputs(**`modules`**:`ModuleList`, **`detach`**:`bool`=***`True`***, **`grad`**:`bool`=***`False`***) → [`Hooks`](/callbacks.hooks.html#Hooks)\n", "\n", "

\n", "\n", "Return [`Hooks`](/callbacks.hooks.html#Hooks) that store activations of all `modules` in `self.stored` " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(hook_outputs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Function that creates a [`Hooks`](/callbacks.hooks.html#Hooks) object for all passed `modules` that simply stores the output of each layer. The hooks must be registered before the forward pass runs, so that `stored` is populated. For example, the (slightly simplified) source code of [`model_sizes`](/callbacks.hooks.html#model_sizes) is:\n", "\n", "```python\n", "def model_sizes(m, size):\n", "    # Register the hooks first, then run a dummy batch through the model\n", "    with hook_outputs(m) as hooks:\n", "        x = m(torch.zeros(1, in_channels(m), *size))\n", "        return [o.stored.shape for o in hooks]\n", "```" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

model_sizes

\n", "\n", "> model_sizes(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`size`**:`tuple`=***`(64, 64)`***) → `Tuple`\\[`Sizes`, `Tensor`, [`Hooks`](/callbacks.hooks.html#Hooks)\\]\n", "\n", "

\n", "\n", "Pass a dummy input through the model `m` to get the various sizes of activations. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(model_sizes)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

model_summary

\n", "\n", "> model_summary(**`m`**:[`Learner`](/basic_train.html#Learner), **`n`**:`int`=***`70`***)\n", "\n", "

\n", "\n", "Print a summary of `m` using a output text width of `n` chars " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(model_summary)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This method only works on a [`Learner`](/basic_train.html#Learner) object with `train_ds` in it. If it was created as a result of [`load_learner`](/basic_train.html#load_learner), there is no [`data`](/vision.data.html#vision.data) to run through the model and therefore it's not possible to create such summary.\n", "\n", "A sample `summary` looks like:\n", "\n", "```\n", "======================================================================\n", "Layer (type) Output Shape Param # Trainable \n", "======================================================================\n", "Conv2d [64, 176, 176] 9,408 False \n", "______________________________________________________________________\n", "BatchNorm2d [64, 176, 176] 128 True \n", "______________________________________________________________________\n", "ReLU [64, 176, 176] 0 False \n", "______________________________________________________________________\n", "MaxPool2d [64, 88, 88] 0 False \n", "______________________________________________________________________\n", "Conv2d [64, 88, 88] 36,864 False \n", "...\n", "```\n", "\n", "Column definition:\n", "\n", "1. **Layer (type)** is the name of the corresponding [`nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module).\n", "\n", "2. **Output Shape** is the shape of the output of the corresponding layer (minus the batch dimension, which is always the same and has no impact on the model params).\n", "\n", "3. **Param #** is the number of weights (and optionally bias), and it will vary for each layer.\n", "\n", " The number of params is calculated differently for each layer type. 
Here is how it's calculated for some of the most common layer types:\n", "\n", " * Conv: `kernel_size*kernel_size*ch_in*ch_out` (plus `ch_out` if the layer has a bias)\n", " * Linear: `(n_in+bias) * n_out`\n", " * Batchnorm: `2 * n_out`\n", " * Embeddings: `n_embed * emb_sz`\n", "\n", "4. **Trainable** indicates whether a layer is trainable or not.\n", "\n", " * Layers with `0` parameters are always untrainable (e.g., `ReLU` and `MaxPool2d`).\n", " * Other layers are either trainable or not, usually depending on whether they are frozen or not. See [Discriminative layer training](https://docs.fast.ai/basic_train.html#Discriminative-layer-training).\n", "\n", "To better understand this summary it helps to also execute `learn.model` and correlate the two outputs.\n", "\n", "Example:\n", "\n", "Let's feed a [`Learner`](/basic_train.html#Learner) a dataset of 3-channel images of size 352x352 and look at the model and its summary:\n", "\n", "```\n", "data.train_ds[0][0].data.shape\n", "learn = cnn_learner(data, models.resnet34, ...)\n", "print(learn.model)\n", "print(learn.summary())\n", "```\n", "Here are the outputs, with everything except the lines relevant to the example removed:\n", "\n", "```\n", "torch.Size([3, 352, 352])\n", "\n", " [...]\n", " (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " [...]\n", " (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", " [...]\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (8): Linear(in_features=512, out_features=37, bias=True)\n", "\n", "\n", "======================================================================\n", "Layer (type) Output Shape Param # Trainable \n", "======================================================================\n", "Conv2d [64, 176, 176] 9,408 False \n", "______________________________________________________________________\n", 
"BatchNorm2d [64, 176, 176] 128 True \n", "______________________________________________________________________\n", "[...]\n", "MaxPool2d [64, 88, 88] 0 False \n", "______________________________________________________________________\n", "Conv2d [64, 88, 88] 36,864 False \n", "[...]\n", "______________________________________________________________________\n", "Linear [37] 18,981 True\n", "\n", "```\n", "\n", "**So let's calculate some params:**\n", "\n", "For the `Conv2d` layers, multiply the first 4 numbers from the corresponding layer definition:\n", "\n", "```\n", "Conv2d(3, 64, kernel_size=(7, 7), ...)\n", "\n", "3*64*7*7 = 9,408\n", "\n", "Conv2d(64, 64, kernel_size=(3, 3), ...)\n", "\n", "64*64*3*3 = 36,864\n", "```\n", "\n", "For the `BatchNorm2d` layer, multiply the first number by 2:\n", "```\n", "BatchNorm2d(64, ...)\n", "64*2 = 128\n", "```\n", "\n", "For `Linear`, add 1 to the number of input features if `bias` is `True`, then multiply by the number of output features:\n", "\n", "```\n", "Linear(in_features=512, out_features=37, bias=True)\n", "\n", "(512+1)*37 = 18,981\n", "```\n", "\n", "**Now let's calculate some output shapes:**\n", "\n", "We started with a 3x352x352 image and ran it through this layer:\n", "\n", "`Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)`\n", "\n", "How did we get `[64, 176, 176]`?\n", "\n", "The number of output channels is `64`, which is the first dimension of the shape above. Our `352x352` image was then convolved down to `176x176` because of the `2x2` stride (`352/2`).\n", "\n", "Then we had:\n", "\n", "`MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)`\n", "\n", "which reduced `[64, 176, 176]` to `[64, 88, 88]`, again because of stride 2.\n", "\n", "And so on, finishing with:\n", "\n", "`Linear(in_features=512, out_features=37, bias=True)`\n", "\n", "which reduced everything to just `[37]`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
Warning: Known issue: `model_summary` and `Learner.summary` don't work with the AWD LSTM in text models.
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "jekyll_warn(\"Known issue: `model_summary` and `Learner.summary` don't work with the AWD LSTM in text models.\")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

num_features_model[source][test]

\n", "\n", "> num_features_model(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)) → `int`\n", "\n", "

\n", "\n", "Return the number of output features for `model`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(num_features_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It can be useful to get the size of each layer of a model (e.g. for printing a summary, or for generating cross-connections for a [`DynamicUnet`](/vision.models.unet.html#DynamicUnet)); however, these sizes depend on the size of the input. This function calculates the layer sizes by passing in a minimal tensor of `size`." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

dummy_batch[source][test]

\n", "\n", "> dummy_batch(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`size`**:`tuple`=***`(64, 64)`***) → `Tensor`\n", "\n", "

\n", "\n", "Create a dummy batch to go through `m` with `size`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(dummy_batch)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

dummy_eval[source][test]

\n", "\n", "> dummy_eval(**`m`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`size`**:`tuple`=***`(64, 64)`***)\n", "\n", "

\n", "\n", "Pass a [`dummy_batch`](/callbacks.hooks.html#dummy_batch) in evaluation mode in `m` with `size`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(dummy_eval)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class HookCallback[source][test]

\n", "\n", "> HookCallback(**`learn`**:[`Learner`](/basic_train.html#Learner), **`modules`**:`Sequence`\\[[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)\\]=***`None`***, **`do_remove`**:`bool`=***`True`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "

\n", "\n", "Callback that can be used to register hooks on `modules`. Implement the corresponding function in `self.hook`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(HookCallback)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For all `modules`, this class uses a callback to automatically register the method `self.hook` (which you must define in a subclass) as a hook. This method must have the signature:\n", "\n", "```python\n", "def hook(self, m:Model, input:Tensors, output:Tensors)\n", "```\n", "\n", "If `do_remove` is `True`, the hook is automatically deregistered at the end of training. See [`ActivationStats`](/callbacks.hooks.html#ActivationStats) for a simple example of inheriting from this class." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_begin[source][test]

\n", "\n", "> on_train_begin(**\\*\\*`kwargs`**)\n", "\n", "

\n", "\n", "Register the [`Hooks`](/callbacks.hooks.html#Hooks) on `self.modules`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(HookCallback.on_train_begin)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_end[source][test]

\n", "\n", "> on_train_end(**\\*\\*`kwargs`**)\n", "\n", "

\n", "\n", "Remove the [`Hooks`](/callbacks.hooks.html#Hooks). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(HookCallback.on_train_end)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "hide_input": false }, "outputs": [ { "data": { "text/markdown": [ "

remove[source][test]

\n", "\n", "> remove()\n", "\n", "

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(HookCallback.remove)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

hook_fn[source][test]

\n", "\n", "> hook_fn(**`module`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`input`**:`Tensors`, **`output`**:`Tensors`)\n", "\n", "

\n", "\n", "Applies `hook_func` to `module`, `input`, `output`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Hook.hook_fn)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## New Methods - Please document or move to the undocumented section" ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Implement callbacks using hooks", "title": "callbacks.hooks" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 2 }