{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Mixup data augmentation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.callbacks.mixup import *\n", "from fastai.vision import *\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What is Mixup?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This module implements a data augmentation technique called [Mixup](https://arxiv.org/abs/1710.09412). It is very effective at regularizing computer vision models (we used it to bring our time to train CIFAR10 to 94% accuracy on one GPU down to 6 minutes). \n", "\n", "As the name suggests, the authors of the mixup paper propose training the model on mixes of the training set images. On CIFAR10, for instance, instead of feeding the model the raw images, we take two of them (from the same class or not) and form a linear combination. In terms of tensors:\n", "\n", "`new_image = t * image1 + (1-t) * image2`\n", "\n", "where `t` is a float between 0 and 1. The target we assign to that image is the same combination of the original targets:\n", "\n", "`new_target = t * target1 + (1-t) * target2`\n", "\n", "assuming your targets are one-hot encoded (which usually isn’t the case in PyTorch). It is as simple as that.\n", "\n", "![mixup](imgs/mixup.png)\n", "\n", "Dog or cat? The right answer here is 70% dog and 30% cat!\n", "\n", "As the picture above shows, the resulting images are hard for a human eye to interpret (although we can make out the shapes of a dog and a cat), yet they make sense to the model, which trains more efficiently. 
The final loss (training or validation) will be higher than when training without mixup even if the accuracy is far better, which means that a model trained like this will make predictions that are a bit less confident." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic Training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To test this method, we will first build a [`simple_cnn`](/layers.html#simple_cnn) and train it like we did with [`basic_train`](/basic_train.html#basic_train) so we can compare its results with a network trained with Mixup." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "model = simple_cnn((3,16,16,2))\n", "learn = Learner(data, model, metrics=[accuracy])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
epoch | train_loss | valid_loss | accuracy | time
---|---|---|---|---
1 | 0.111498 | 0.094612 | 0.965653 | 00:02
2 | 0.079887 | 0.064684 | 0.975466 | 00:02
3 | 0.053950 | 0.042022 | 0.985280 | 00:02
4 | 0.043062 | 0.035917 | 0.986752 | 00:02
5 | 0.030692 | 0.025291 | 0.989205 | 00:02
6 | 0.027065 | 0.024845 | 0.987733 | 00:02
7 | 0.031135 | 0.020047 | 0.990186 | 00:02
8 | 0.025115 | 0.025447 | 0.988714 | 00:02

epoch | train_loss | valid_loss | accuracy | time
---|---|---|---|---
1 | 0.358743 | 0.156058 | 0.961236 | 00:02
2 | 0.334059 | 0.124648 | 0.982336 | 00:02
3 | 0.321510 | 0.105825 | 0.987242 | 00:02
4 | 0.314596 | 0.099804 | 0.988714 | 00:02
5 | 0.314716 | 0.094472 | 0.989205 | 00:02
6 | 0.309679 | 0.095133 | 0.989696 | 00:02
7 | 0.314474 | 0.086767 | 0.990186 | 00:02
8 | 0.309931 | 0.095609 | 0.990186 | 00:02

`class` `MixUpCallback`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`alpha`**:`float`=***`0.4`***, **`stack_x`**:`bool`=***`False`***, **`stack_y`**:`bool`=***`True`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)

`on_batch_begin`(**`last_input`**, **`last_target`**, **`train`**, **\*\*`kwargs`**)
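
To make the mixing step from "What is Mixup?" concrete, here is a minimal pure-Python sketch of how one training batch could be mixed. It is an illustration, not the library's code: the helper name `mixup_batch`, the list-based data layout, and the choice to pair each sample with a shuffled partner from the same batch are all assumptions. It also assumes `t` is drawn from a Beta(`alpha`, `alpha`) distribution, as proposed in the Mixup paper.

```python
import random

def mixup_batch(images, targets, alpha=0.4, rng=None):
    """Hypothetical helper: mix each sample with a random partner from the batch.

    images  : list of flat lists of floats, one per sample
    targets : list of one-hot lists of floats, one per sample
    Returns (mixed_images, mixed_targets, ts), where ts holds the mixing
    coefficient t used for each sample.
    """
    if rng is None:
        rng = random.Random()
    # Pick a partner for every sample by shuffling the batch indices.
    partners = list(range(len(images)))
    rng.shuffle(partners)

    mixed_images, mixed_targets, ts = [], [], []
    for i, j in enumerate(partners):
        # Assumption: t ~ Beta(alpha, alpha), as in the Mixup paper.
        t = rng.betavariate(alpha, alpha)
        # new_image = t * image1 + (1-t) * image2
        mixed_images.append([t * a + (1 - t) * b
                             for a, b in zip(images[i], images[j])])
        # new_target = t * target1 + (1-t) * target2
        mixed_targets.append([t * a + (1 - t) * b
                              for a, b in zip(targets[i], targets[j])])
        ts.append(t)
    return mixed_images, mixed_targets, ts
```

Because each mixed target is a convex combination of two one-hot vectors, its entries still sum to 1, which is why it can be read as "70% dog and 30% cat".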

`class` `MixUpLoss`(**`crit`**, **`reduction`**=***`'mean'`***) :: [`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)

`forward`(**`output`**, **`target`**)
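
The `new_target` formula earlier on this page assumes one-hot targets, while class targets are usually stored as integer indices. One way around this, sketched below in pure Python, is to compute the loss against each original target and combine the two losses with weights `t` and `1-t`; for cross-entropy this gives the same value as using the mixed one-hot target. The function names here are hypothetical, and this is a sketch of the idea rather than the code of `MixUpLoss`.

```python
import math

def cross_entropy(logits, target_idx):
    """Cross-entropy of one sample against an integer class index."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target_idx]

def soft_cross_entropy(logits, soft_target):
    """Cross-entropy of one sample against a probability vector."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(p * (z - log_sum_exp) for p, z in zip(soft_target, logits))

def mixup_loss(logits, target1, target2, t):
    """Combine the losses for the two original integer targets."""
    return (t * cross_entropy(logits, target1)
            + (1 - t) * cross_entropy(logits, target2))

# A mix that is 70% class 0 and 30% class 1.
logits = [2.0, 0.5, -1.0]
combined = mixup_loss(logits, 0, 1, 0.7)
one_hot_mix = soft_cross_entropy(logits, [0.7, 0.3, 0.0])
```

Here `combined` and `one_hot_mix` agree up to floating-point error, so the targets never need to be expanded to one-hot vectors.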