{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp callback.mixup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from local.test import *\n",
    "from local.basics import *\n",
    "from local.callback.progress import *\n",
    "from local.vision.core import *\n",
    "\n",
    "from torch.distributions.beta import Beta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from local.notebook.showdoc import *\n",
    "from local.test_utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mixup callback\n",
    "\n",
    "> Callback to apply MixUp data augmentation to your training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MixupCallback -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def reduce_loss(loss, reduction='mean'):\n",
    "    \"Reduce `loss` according to `reduction`: 'mean' averages, 'sum' totals, any other value returns `loss` unchanged\"\n",
    "    return loss.mean() if reduction=='mean' else loss.sum() if reduction=='sum' else loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class MixUp(Callback):\n",
    "    \"Blend each training batch with a shuffled copy of itself, using per-item weights drawn from Beta(`alpha`, `alpha`)\"\n",
    "    run_after=[Normalize, Cuda]\n",
    "    def __init__(self, alpha=0.4): self.distrib = Beta(tensor(alpha), tensor(alpha))\n",
    "    def begin_fit(self):\n",
    "        # When the loss func sets `y_int`, the loss is wrapped by `self.lf` (loss values are blended);\n",
    "        # otherwise the targets themselves are blended in `begin_batch`.\n",
    "        self.stack_y = getattr(self.learn.loss_func, 'y_int', False)\n",
    "        if self.stack_y: self.old_lf,self.learn.loss_func = self.learn.loss_func,self.lf\n",
    "\n",
    "    def after_fit(self):\n",
    "        # Put back the loss func that `begin_fit` swapped out\n",
    "        if self.stack_y: self.learn.loss_func = self.old_lf\n",
    "\n",
    "    def begin_batch(self):\n",
    "        if not self.training: return\n",
    "        bs = self.y.size(0)\n",
    "        # `reshape(bs)` guarantees a 1-D `lam` (one weight per item), including for a batch of a single item\n",
    "        lam = self.distrib.sample((bs,)).reshape(bs).to(self.x.device)\n",
    "        # Keep the larger of lam and 1-lam: `torch.lerp` below then gives the un-shuffled item a weight >= 0.5\n",
    "        self.lam = torch.max(lam, 1-lam)\n",
    "        shuffle = torch.randperm(bs).to(self.x.device)\n",
    "        xb1,self.yb1 = tuple(L(self.xb).itemgot(shuffle)),tuple(L(self.yb).itemgot(shuffle))\n",
    "        nx_dims = len(self.x.size())\n",
    "        self.learn.xb = tuple(L(xb1,self.xb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=nx_dims-1)))\n",
    "        if not self.stack_y:\n",
    "            ny_dims = len(self.y.size())\n",
    "            self.learn.yb = tuple(L(self.yb1,self.yb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=ny_dims-1)))\n",
    "\n",
    "    def lf(self, pred, *yb):\n",
    "        \"Loss used while `stack_y` is set: per-item blend of the original loss on the shuffled and un-shuffled targets\"\n",
    "        if not self.training: return self.old_lf(pred, *yb)\n",
    "        with NoneReduce(self.old_lf) as lf:\n",
    "            loss = torch.lerp(lf(pred,*self.yb1), lf(pred,*yb), self.lam)\n",
    "        return reduce_loss(loss, getattr(self.old_lf, 'reduction', 'mean'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from local.vision.core import *\n",
    "\n",
    "path = untar_data(URLs.MNIST_TINY)\n",
    "items = get_image_files(path)\n",
    "tds = DataSource(items, [PILImageBW.create, [parent_label, Categorize()]], splits=GrandparentSplitter()(items))\n",
    "dbunch = tds.databunch(after_item=[ToTensor(), IntToFloatTensor()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "
\n", " \n", " \n", " \n", "
\n", " \n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "mixup = MixUp(0.5)\n", "learn = Learner(dbunch, nn.Linear(3,4), loss_func=CrossEntropyLossFlat(), cbs=mixup)\n", "learn._do_begin_fit(1)\n", "learn.epoch,learn.training = 0,True\n", "learn.dl = dbunch.train_dl\n", "b = dbunch.one_batch()\n", "learn._split(b)\n", "learn('begin_batch')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "_,axs = plt.subplots(3,3, figsize=(9,9))\n", "dbunch.show_batch(b=(mixup.x,mixup.y), ctxs=axs.flatten())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_test.ipynb.\n", "Converted 01_core.ipynb.\n", "Converted 01a_utils.ipynb.\n", "Converted 01b_dispatch.ipynb.\n", "Converted 01c_transform.ipynb.\n", "Converted 02_script.ipynb.\n", "Converted 03_torch_core.ipynb.\n", "Converted 03a_layers.ipynb.\n", "Converted 04_dataloader.ipynb.\n", "Converted 05_data_core.ipynb.\n", "Converted 06_data_transforms.ipynb.\n", "Converted 07_data_block.ipynb.\n", "Converted 08_vision_core.ipynb.\n", "Converted 09_vision_augment.ipynb.\n", "Converted 09a_vision_data.ipynb.\n", "Converted 10_pets_tutorial.ipynb.\n", "Converted 11_vision_models_xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_learner.ipynb.\n", "Converted 13a_metrics.ipynb.\n", "Converted 14_callback_schedule.ipynb.\n", "Converted 14a_callback_data.ipynb.\n", "Converted 15_callback_hook.ipynb.\n", "Converted 15a_vision_models_unet.ipynb.\n", "Converted 16_callback_progress.ipynb.\n", "Converted 17_callback_tracker.ipynb.\n", "Converted 18_callback_fp16.ipynb.\n", "Converted 19_callback_mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 21_vision_learner.ipynb.\n", "Converted 22_tutorial_imagenette.ipynb.\n", "Converted 23_tutorial_transfer_learning.ipynb.\n", "Converted 30_text_core.ipynb.\n", "Converted 31_text_data.ipynb.\n", "Converted 32_text_models_awdlstm.ipynb.\n", "Converted 33_text_models_core.ipynb.\n", "Converted 34_callback_rnn.ipynb.\n", "Converted 35_tutorial_wikitext.ipynb.\n", "Converted 36_text_models_qrnn.ipynb.\n", "Converted 37_text_learner.ipynb.\n", "Converted 38_tutorial_ulmfit.ipynb.\n", 
"Converted 40_tabular_core.ipynb.\n", "Converted 41_tabular_model.ipynb.\n", "Converted 42_tabular_rapids.ipynb.\n", "Converted 50_data_block_examples.ipynb.\n", "Converted 60_medical_imaging.ipynb.\n", "Converted 65_medical_text.ipynb.\n", "Converted 90_notebook_core.ipynb.\n", "Converted 91_notebook_export.ipynb.\n", "Converted 92_notebook_showdoc.ipynb.\n", "Converted 93_notebook_export2html.ipynb.\n", "Converted 94_notebook_test.ipynb.\n", "Converted 95_index.ipynb.\n", "Converted 96_data_external.ipynb.\n", "Converted 97_utils_test.ipynb.\n", "Converted notebook2jekyll.ipynb.\n" ] } ], "source": [ "#hide\n", "from local.notebook.export import notebook2script\n", "notebook2script(all_fs=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }