{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp callback.cutmix" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from torch.distributions.beta import Beta\n", "from fastai2.vision.all import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# CutMix Callback\n", "> Callback to apply [CutMix](https://arxiv.org/pdf/1905.04899.pdf) data augmentation technique to the training data." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From the [research paper](https://arxiv.org/pdf/1905.04899.pdf), `CutMix` is a way to combine two images. It comes from `MixUp` and `Cutout`. In this data augmentation technique:\n", "> patches are cut and pasted among training images where the ground truth labels are also mixed proportionally to the area of the patches\n", "\n", "Also, from the paper: \n", "> By making efficient use of training pixels and retaining the regularization effect of regional dropout, CutMix consistently outperforms the state-of-the-art augmentation strategies on CIFAR and ImageNet classification tasks, as well as on the ImageNet weakly-supervised localization task. Moreover, unlike previous augmentation methods, our CutMix-trained ImageNet classifier, when used as a pretrained model, results in consistent performance gains in Pascal detection and MS-COCO image captioning benchmarks. We also show that CutMix improves the model robustness against input corruptions and its out-of-distribution detection performances. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class CutMix(Callback):\n", " \"Implementation of `https://arxiv.org/abs/1905.04899`\"\n", " run_after,run_valid = [Normalize],False\n", " def __init__(self, alpha=1.): self.distrib = Beta(tensor(alpha), tensor(alpha))\n", " def before_fit(self):\n", " self.stack_y = getattr(self.learn.loss_func, 'y_int', False)\n", " if self.stack_y: self.old_lf,self.learn.loss_func = self.learn.loss_func,self.lf\n", "\n", " def after_fit(self):\n", " if self.stack_y: self.learn.loss_func = self.old_lf\n", "\n", " def before_batch(self):\n", " W, H = self.xb[0].size(3), self.xb[0].size(2)\n", " lam = self.distrib.sample((1,)).squeeze().to(self.x.device)\n", " lam = torch.stack([lam, 1-lam])\n", " self.lam = lam.max()\n", " shuffle = torch.randperm(self.y.size(0)).to(self.x.device)\n", " xb1,self.yb1 = tuple(L(self.xb).itemgot(shuffle)),tuple(L(self.yb).itemgot(shuffle))\n", " nx_dims = len(self.x.size())\n", " x1, y1, x2, y2 = self.rand_bbox(W, H, self.lam)\n", " self.learn.xb[0][:, :, x1:x2, y1:y2] = xb1[0][:, :, x1:x2, y1:y2]\n", " self.lam = (1 - ((x2-x1)*(y2-y1))/float(W*H)).item()\n", "\n", " if not self.stack_y:\n", " ny_dims = len(self.y.size())\n", " self.learn.yb = tuple(L(self.yb1,self.yb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=ny_dims-1)))\n", "\n", " def lf(self, pred, *yb):\n", " if not self.training: return self.old_lf(pred, *yb)\n", " with NoneReduce(self.old_lf) as lf:\n", " loss = torch.lerp(lf(pred,*self.yb1), lf(pred,*yb), self.lam)\n", " return reduce_loss(loss, getattr(self.old_lf, 'reduction', 'mean'))\n", "\n", " def rand_bbox(self, W, H, lam):\n", " cut_rat = torch.sqrt(1. 
- lam)\n", " cut_w = (W * cut_rat).type(torch.long)\n", " cut_h = (H * cut_rat).type(torch.long)\n", " # uniform\n", " cx = torch.randint(0, W, (1,)).to(self.x.device)\n", " cy = torch.randint(0, H, (1,)).to(self.x.device)\n", " x1 = torch.clamp(cx - cut_w // 2, 0, W)\n", " y1 = torch.clamp(cy - cut_h // 2, 0, H)\n", " x2 = torch.clamp(cx + cut_w // 2, 0, W)\n", " y2 = torch.clamp(cy + cut_h // 2, 0, H)\n", " return x1, y1, x2, y2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How does the batch with `CutMix` data augmentation technique look like?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's quickly create the `dls` using `ImageDataLoaders.from_name_re` DataBlocks API." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.PETS)\n", "pat = r'([^/]+)_\\d+.*$'\n", "fnames = get_image_files(path/'images')\n", "item_tfms = [Resize(256, method='crop')]\n", "batch_tfms = [*aug_transforms(size=224), Normalize.from_stats(*imagenet_stats)]\n", "dls = ImageDataLoaders.from_name_re(path, fnames, pat, bs=64, item_tfms=item_tfms, \n", " batch_tfms=batch_tfms)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's initialize the callback `CutMix`, create a learner, do one batch and display the images with the labels. `CutMix` inside updates the loss function based on the ratio of the cutout bbox to the complete image." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cutmix = CutMix(alpha=1.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
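  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before building the learner, the following is a small, self-contained sketch (plain PyTorch, with made-up logits, labels and `lam`) of the loss mixing that `CutMix.lf` performs when the targets are integer class indices: the cross-entropy loss against the shuffled labels and the loss against the original labels are linearly interpolated with the adjusted `lam`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of the loss mixing done by CutMix.lf (values are made up for illustration)\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "pred = torch.randn(4, 10)                                   # fake logits: batch of 4, 10 classes\n",
    "y, y_shuffled = torch.randint(0, 10, (4,)), torch.randint(0, 10, (4,))\n",
    "lam = 0.7                                                   # fraction of each image left untouched\n",
    "loss = torch.lerp(F.cross_entropy(pred, y_shuffled, reduction='none'),\n",
    "                  F.cross_entropy(pred, y, reduction='none'), lam).mean()\n",
    "print(loss)"
   ]
  },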
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>00:00</td>\n",
       "