{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "#| eval: false\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from __future__ import annotations\n",
    "from fastai.basics import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp callback.rnn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Callback for RNN training\n",
    "\n",
    "> Callback that uses the outputs of language models to add AR and TAR regularization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "@docs\n",
    "class ModelResetter(Callback):\n",
    "    \"`Callback` that resets the model before each training/validation phase and after fitting\"\n",
    "    # Every hook just forwards to `self.model.reset()`, so the model must expose a `reset` method.\n",
    "    def before_train(self):    self.model.reset()\n",
    "    def before_validate(self): self.model.reset()\n",
    "    def after_fit(self):       self.model.reset()\n",
    "    _docs = dict(before_train=\"Reset the model before training\",\n",
    "                 before_validate=\"Reset the model before validation\",\n",
    "                 after_fit=\"Reset the model after fitting\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class RNNCallback(Callback):\n",
    "    \"Save the raw and dropped-out outputs and only keep the true output for loss computation\"\n",
    "    def after_pred(self):\n",
    "        # `self.pred` must unpack into exactly three items; any list-like item is reduced to its last element.\n",
    "        # In order they become `learn.pred`, `self.raw_out` and `self.out`.\n",
    "        self.learn.pred,self.raw_out,self.out = [o[-1] if is_listy(o) else o for o in self.pred]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class RNNRegularizer(Callback):\n",
    "    \"Add AR and TAR regularization\"\n",
    "    # Runs right after `RNNCallback` and is disabled during validation.\n",
    "    order,run_valid = RNNCallback.order+1,False\n",
    "    # `alpha` weights the AR term (mean squared `rnn.out`); `beta` weights the TAR term\n",
    "    # (mean squared difference of `rnn.raw_out` between neighbouring positions along dim 1).\n",
    "    # A weight of 0 switches the corresponding term off.\n",
    "    def __init__(self, alpha=0., beta=0.): store_attr()\n",
    "    def after_loss(self):\n",
    "        if not self.training: return\n",
    "        # NOTE(review): `self.rnn` is assumed to resolve to the `RNNCallback` registered on the same learner,\n",
    "        # which is what provides `out` and `raw_out` -- confirm against `Callback` attribute forwarding.\n",
    "        if self.alpha: self.learn.loss_grad += self.alpha * self.rnn.out.float().pow(2).mean()\n",
    "        if self.beta:\n",
    "            h = self.rnn.raw_out\n",
    "            # The difference below is taken along dim 1, so that is the dim that needs at least two entries.\n",
    "            # Checking `len(h)` (dim 0) instead would let a length-1 dim 1 through, giving an empty difference\n",
    "            # whose mean is NaN, and would skip the penalty whenever dim 0 has a single entry.\n",
    "            if h.shape[1]>1: self.learn.loss_grad += self.beta * (h[:,1:] - h[:,:-1]).float().pow(2).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def rnn_cbs(alpha=0., beta=0.):\n",
    "    \"All callbacks needed for (optionally regularized) RNN training\"\n",
    "    # The regularizer is only appended when at least one of its weights is non-zero.\n",
    "    reg = [RNNRegularizer(alpha=alpha, beta=beta)] if alpha or beta else []\n",
    "    return [ModelResetter(), RNNCallback()] + reg"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 01a_losses.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#|hide\n",
    "from nbdev import nbdev_export\n",
    "nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}