{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai2.data.all import *\n", "from fastai2.optimizer import *\n", "from fastai2.learner import *\n", "from fastai2.callback.core import *\n", "from torch.utils.data import TensorDataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# default_exp test_utils" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Synthetic Learner\n", "\n", "> For quick testing of the training loop and Callbacks" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from torch.utils.data import TensorDataset\n", "\n", "def synth_dbunch(a=2, b=3, bs=16, n_train=10, n_valid=2, cuda=False):\n", " def get_data(n):\n", " x = torch.randn(bs*n, 1)\n", " return TensorDataset(x, a*x + b + 0.1*torch.randn(bs*n, 1))\n", " train_ds = get_data(n_train)\n", " valid_ds = get_data(n_valid)\n", " device = default_device() if cuda else None\n", " train_dl = TfmdDL(train_ds, bs=bs, shuffle=True, num_workers=0)\n", " valid_dl = TfmdDL(valid_ds, bs=bs, num_workers=0)\n", " return DataLoaders(train_dl, valid_dl, device=device)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class RegModel(Module):\n", " def __init__(self): self.a,self.b = nn.Parameter(torch.randn(1)),nn.Parameter(torch.randn(1))\n", " def forward(self, x): return x*self.a + self.b" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "@delegates(Learner.__init__)\n", "def synth_learner(n_trn=10, n_val=2, cuda=False, lr=1e-3, data=None, model=None, **kwargs):\n", " if data is None: data=synth_dbunch(n_train=n_trn,n_valid=n_val, cuda=cuda)\n", " if model is None: model=RegModel()\n", " return Learner(data, model, lr=lr, loss_func=MSELossFlat(),\n", " opt_func=partial(SGD, mom=0.9), **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class VerboseCallback(Callback):\n", " \"Callback that prints the name of each event called\"\n", " def __call__(self, event_name):\n", " print(event_name)\n", " super().__call__(event_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## - Export" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_torch_core.ipynb.\n", "Converted 01_layers.ipynb.\n", "Converted 02_data.load.ipynb.\n", "Converted 03_data.core.ipynb.\n", "Converted 04_data.external.ipynb.\n", "Converted 05_data.transforms.ipynb.\n", "Converted 06_data.block.ipynb.\n", "Converted 07_vision.core.ipynb.\n", "Converted 08_vision.data.ipynb.\n", "Converted 09_vision.augment.ipynb.\n", "Converted 09b_vision.utils.ipynb.\n", "Converted 09c_vision.widgets.ipynb.\n", "Converted 10_tutorial.pets.ipynb.\n", "Converted 11_vision.models.xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_callback.core.ipynb.\n", "Converted 13a_learner.ipynb.\n", "Converted 13b_metrics.ipynb.\n", "Converted 14_callback.schedule.ipynb.\n", "Converted 14a_callback.data.ipynb.\n", "Converted 15_callback.hook.ipynb.\n", "Converted 15a_vision.models.unet.ipynb.\n", "Converted 16_callback.progress.ipynb.\n", "Converted 17_callback.tracker.ipynb.\n", "Converted 18_callback.fp16.ipynb.\n", "Converted 18a_callback.training.ipynb.\n", "Converted 19_callback.mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 21_vision.learner.ipynb.\n", "Converted 22_tutorial.imagenette.ipynb.\n", "Converted 23_tutorial.vision.ipynb.\n", "Converted 24_tutorial.siamese.ipynb.\n", "Converted 24_vision.gan.ipynb.\n", "Converted 30_text.core.ipynb.\n", "Converted 31_text.data.ipynb.\n", "Converted 32_text.models.awdlstm.ipynb.\n", "Converted 33_text.models.core.ipynb.\n", "Converted 34_callback.rnn.ipynb.\n", "Converted 35_tutorial.wikitext.ipynb.\n", "Converted 36_text.models.qrnn.ipynb.\n", "Converted 37_text.learner.ipynb.\n", "Converted 38_tutorial.text.ipynb.\n", "Converted 39_tutorial.transformers.ipynb.\n", "Converted 40_tabular.core.ipynb.\n", "Converted 41_tabular.data.ipynb.\n", "Converted 42_tabular.model.ipynb.\n", "Converted 43_tabular.learner.ipynb.\n", "Converted 44_tutorial.tabular.ipynb.\n", "Converted 45_collab.ipynb.\n", "Converted 46_tutorial.collab.ipynb.\n", "Converted 50_tutorial.datablock.ipynb.\n", "Converted 60_medical.imaging.ipynb.\n", "Converted 61_tutorial.medical_imaging.ipynb.\n", "Converted 65_medical.text.ipynb.\n", "Converted 70_callback.wandb.ipynb.\n", "Converted 71_callback.tensorboard.ipynb.\n", "Converted 72_callback.neptune.ipynb.\n", "Converted 73_callback.captum.ipynb.\n", "Converted 74_callback.cutmix.ipynb.\n", "Converted 97_test_utils.ipynb.\n", "Converted 99_pytorch_doc.ipynb.\n", "Converted index.ipynb.\n", "Converted tutorial.ipynb.\n" ] } ], "source": [ "#hide\n", "from nbdev.export import *\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }