{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Mixup data augmentation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.callbacks.mixup import *\n", "from fastai.vision import *\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What is mixup?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This module contains the implementation of a data augmentation technique called [mixup](https://arxiv.org/abs/1710.09412). It is extremely efficient at regularizing models in computer vision (we used it to get our time to train CIFAR10 to 94% on one GPU to 6 minutes). \n", "\n", "As the name kind of suggests, the authors of the mixup article propose training the model on mixes of the training set images. For example, suppose we’re training on CIFAR10. Instead of feeding the model the raw images, we take two images (not necessarily from the same class) and make a linear combination of them: in terms of tensors, we have:\n", "\n", "`new_image = t * image1 + (1-t) * image2`\n", "\n", "where t is a float between 0 and 1. The target we assign to that new image is the same combination of the original targets:\n", "\n", "`new_target = t * target1 + (1-t) * target2`\n", "\n", "assuming the targets are one-hot encoded (which isn’t the case in PyTorch usually). And it's as simple as that.\n", "\n", "\n", "\n", "Dog or cat? The right answer here is 70% dog and 30% cat!\n", "\n", "As the picture above shows, it’s a bit hard for the human eye to make sense of images obtained in this way (although we do see the shapes of a dog and a cat). However, it somehow makes a lot of sense to the model, which trains more efficiently. 
One important side note is that when training with mixup, the final loss (training or validation) will be higher than when training without it, even when the accuracy is far better: a model trained like this will make predictions that are a bit less confident." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic Training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To test this method, we first create a [`simple_cnn`](/layers.html#simple_cnn) and train it like we did with [`basic_train`](/basic_train.html#basic_train) so we can compare its results with a network trained with mixup." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "model = simple_cnn((3,16,16,2))\n", "learn = Learner(data, model, metrics=[accuracy])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Training without mixup (baseline):

| epoch | train_loss | valid_loss | accuracy | time  |
|-------|------------|------------|----------|-------|
| 1     | 0.111498   | 0.094612   | 0.965653 | 00:02 |
| 2     | 0.079887   | 0.064684   | 0.975466 | 00:02 |
| 3     | 0.053950   | 0.042022   | 0.985280 | 00:02 |
| 4     | 0.043062   | 0.035917   | 0.986752 | 00:02 |
| 5     | 0.030692   | 0.025291   | 0.989205 | 00:02 |
| 6     | 0.027065   | 0.024845   | 0.987733 | 00:02 |
| 7     | 0.031135   | 0.020047   | 0.990186 | 00:02 |
| 8     | 0.025115   | 0.025447   | 0.988714 | 00:02 |
Training the same network with mixup (note the higher training loss at comparable or better accuracy, as explained above):

| epoch | train_loss | valid_loss | accuracy | time  |
|-------|------------|------------|----------|-------|
| 1     | 0.358743   | 0.156058   | 0.961236 | 00:02 |
| 2     | 0.334059   | 0.124648   | 0.982336 | 00:02 |
| 3     | 0.321510   | 0.105825   | 0.987242 | 00:02 |
| 4     | 0.314596   | 0.099804   | 0.988714 | 00:02 |
| 5     | 0.314716   | 0.094472   | 0.989205 | 00:02 |
| 6     | 0.309679   | 0.095133   | 0.989696 | 00:02 |
| 7     | 0.314474   | 0.086767   | 0.990186 | 00:02 |
| 8     | 0.309931   | 0.095609   | 0.990186 | 00:02 |
## class MixUpCallback

`MixUpCallback`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`alpha`**:`float`=***`0.4`***, **`stack_x`**:`bool`=***`False`***, **`stack_y`**:`bool`=***`True`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)

Callback that creates the mixed-up input and target for each training batch.
### on_batch_begin

`on_batch_begin`(**`last_input`**, **`last_target`**, **`train`**, **\*\*`kwargs`**)

Applies mixup to `last_input` and `last_target` if `train`.
## class MixUpLoss

`MixUpLoss`(**`crit`**, **`reduction`**=***`'mean'`***) :: [`PrePostInitMeta`](/core.html#PrePostInitMeta) :: [`Module`](/torch_core.html#Module)

Adapts the loss function `crit` to go with mixup.
### forward

`forward`(**`output`**, **`target`**)