{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp callback.mixup" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from local.test import *\n", "from local.basics import *\n", "from local.callback.progress import *\n", "from local.vision.core import *\n", "\n", "from torch.distributions.beta import Beta" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from local.notebook.showdoc import *\n", "from local.test_utils import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Mixup callback\n", "\n", "> Callback to apply MixUp data augmentation to your training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## MixupCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def reduce_loss(loss, reduction='mean'):\n", " return loss.mean() if reduction=='mean' else loss.sum() if reduction=='sum' else loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class MixUp(Callback):\n", " run_after=[Normalize, Cuda]\n", " def __init__(self, alpha=0.4): self.distrib = Beta(tensor(alpha), tensor(alpha))\n", " def begin_fit(self): \n", " self.stack_y = getattr(self.learn.loss_func, 'y_int', False)\n", " if self.stack_y: self.old_lf,self.learn.loss_func = self.learn.loss_func,self.lf\n", " \n", " def after_fit(self): \n", " if self.stack_y: self.learn.loss_func = self.old_lf\n", "\n", " def begin_batch(self):\n", " if not self.training: return\n", " lam = self.distrib.sample((self.y.size(0),)).squeeze().to(self.x.device)\n", " lam = torch.stack([lam, 1-lam], 1)\n", " self.lam = lam.max(1)[0]\n", " shuffle = torch.randperm(self.y.size(0)).to(self.x.device)\n", " xb1,self.yb1 = tuple(L(self.xb).itemgot(shuffle)),tuple(L(self.yb).itemgot(shuffle))\n", " nx_dims = len(self.x.size())\n", " self.learn.xb = tuple(L(xb1,self.xb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=nx_dims-1)))\n", " if not self.stack_y: \n", " ny_dims = len(self.y.size())\n", " self.learn.yb = tuple(L(self.yb1,self.yb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=ny_dims-1)))\n", "\n", " def lf(self, pred, *yb):\n", " if not self.training: return self.old_lf(pred, *yb)\n", " with NoneReduce(self.old_lf) as lf:\n", " loss = torch.lerp(lf(pred,*self.yb1), lf(pred,*yb), self.lam)\n", " return reduce_loss(loss, getattr(self.old_lf, 'reduction', 'mean'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from local.vision.core import *\n", "\n", "path = untar_data(URLs.MNIST_TINY)\n", "items = get_image_files(path)\n", "tds = DataSource(items, [PILImageBW.create, [parent_label, Categorize()]], splits=GrandparentSplitter()(items))\n", "dbunch = tds.databunch(after_item=[ToTensor(), IntToFloatTensor()])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "