{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Computer Vision Learner" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`vision.learner`](/vision.learner.html#vision.learner) is the module that defines the [`create_cnn`](/vision.learner.html#create_cnn) method, to easily get a model suitable for transfer learning." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.vision import *\n", "from fastai import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transfer learning" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Transfer learning is a technique where you use a model trained on a very large dataset (usually [ImageNet](http://image-net.org/) in computer vision) and then adapt it to your own dataset. The idea is that it has learned to recognize many features on all of this data, and that you will benefit from this knowledge, especially if your dataset is small, compared to starting from a randomly initialized model. It has been shown in [this article](https://arxiv.org/abs/1805.08974) on a wide range of tasks that transfer learning nearly always gives better results.\n", "\n", "In practice, you need to change the last part of your model to be adapted to your own number of classes. Most convolutional models end with a few linear layers (a part we will call the head). The last convolutional layer will have analyzed features in the image that went through the model, and the job of the head is to convert those into predictions for each of our classes. 
In transfer learning we will keep all the convolutional layers (called the body or the backbone of the model) with their weights pretrained on ImageNet but will define a new head initialized randomly.\n", "\n", "Then we will train the model we obtain in two phases: first we freeze the body weights and only train the head (to convert those analyzed features into predictions for our own data), then we unfreeze the layers of the backbone (gradually if necessary) and fine-tune the whole model (possibly using differential learning rates).\n", "\n", "The [`create_cnn`](/vision.learner.html#create_cnn) factory method helps you to automatically get a pretrained model from a given architecture with a custom head that is suitable for your data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

create_cnn[source]

\n", "\n", "> create_cnn(`data`:[`DataBunch`](/basic_data.html#DataBunch), `arch`:`Callable`, `cut`:`Union`\\[`int`, `Callable`\\]=`None`, `pretrained`:`bool`=`True`, `lin_ftrs`:`Optional`\\[`Collection`\\[`int`\\]\\]=`None`, `ps`:`Floats`=`0.5`, `custom_head`:`Optional`\\[[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)\\]=`None`, `split_on`:`Union`\\[`Callable`, `Collection`\\[`ModuleList`\\], `NoneType`\\]=`None`, `bn_final`:`bool`=`False`, `kwargs`:`Any`) → [`Learner`](/basic_train.html#Learner)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(create_cnn, doc_string=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This method creates a [`Learner`](/basic_train.html#Learner) object from the [`data`](/vision.data.html#vision.data) object and model inferred from it with the backbone given in `arch`. Specifically, it will cut the model defined by `arch` (randomly initialized if `pretrained` is False) at the last convolutional layer by default (or as defined in `cut`, see below) and add:\n", "- an [`AdaptiveConcatPool2d`](/layers.html#AdaptiveConcatPool2d) layer,\n", "- a [`Flatten`](/layers.html#Flatten) layer,\n", "- blocks of \\[[`nn.BatchNorm1d`](https://pytorch.org/docs/stable/nn.html#torch.nn.BatchNorm1d), [`nn.Dropout`](https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout), [`nn.Linear`](https://pytorch.org/docs/stable/nn.html#torch.nn.Linear), [`nn.ReLU`](https://pytorch.org/docs/stable/nn.html#torch.nn.ReLU)\\] layers.\n", "\n", "The blocks are defined by the `lin_ftrs` and `ps` arguments. 
Specifically, the first block will have a number of inputs inferred from the backbone `arch` and the last one will have a number of outputs equal to `data.c` (which contains the number of classes of the data) and the intermediate blocks have a number of inputs/outputs determined by `lin_ftrs` (of course a block has a number of inputs equal to the number of outputs of the previous block). The default is to have an intermediate hidden size of 512 (which makes two blocks `model_activation` -> 512 -> `n_classes`). If you pass a float as `ps`, the final dropout layer will use the value `ps` and the remaining ones will use `ps/2`. If you pass a list, its values are used as the dropout probabilities directly.\n", "\n", "Note that the very last block doesn't have a [`nn.ReLU`](https://pytorch.org/docs/stable/nn.html#torch.nn.ReLU) activation, to allow you to use any final activation you want (generally included in the loss function in PyTorch). Also, the backbone will be frozen if you choose `pretrained=True` (so only the head will train if you call [`fit`](/basic_train.html#fit)) so that you can immediately start phase one of training as described above.\n", "\n", "Alternatively, you can define your own `custom_head` to put on top of the backbone. If you want to specify where to split `arch` you should do so in the argument `cut`, which can either be the index of a specific layer (the result will not include that layer) or a function that, when passed the model, will return the backbone you want.\n", "\n", "The final model obtained by stacking the backbone and the head (custom or defined as we saw) is then separated into groups for gradual unfreezing or differential learning rates. You can specify how to split the backbone into groups with the optional argument `split_on` (should be a function that returns those groups when given the backbone). 
\n", "\n", "The `kwargs` will be passed on to [`Learner`](/basic_train.html#Learner), so you can put here anything that [`Learner`](/basic_train.html#Learner) will accept ([`metrics`](/metrics.html#metrics), `loss_func`, `opt_func`...)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:09

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", "
epochtrain_lossvalid_lossaccuracy
10.1272110.0804210.976938
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learner = create_cnn(data, models.resnet18, metrics=[accuracy])\n", "learner.fit_one_cycle(1,1e-3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learner.save('one_epoch')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Get predictions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once you've actually trained your model, you may want to use it on a single image. This is done by using the following method." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

predict[source]

\n", "\n", "> predict(`img`:[`ItemBase`](/core.html#ItemBase), `kwargs`)\n", "\n", "Return prect class, label and probabilities for `img`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.predict)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Category 7, tensor(1), tensor([0.0200, 0.9800]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img = learner.data.train_ds[0][0]\n", "learner.predict(img)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here the predicted class for our image is '7', which corresponds to a label of 1. The probabilities the model found for each class are 2% and 98% respectively, so its confidence is pretty high.\n", "\n", "Note that if you want to load your trained model and use it in inference mode with the previous function, you can call [`create_cnn`](/vision.learner.html#create_cnn) on empty data. First export the relevant bits of your data object by typing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.export()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And then you can load an empty data object that has the same internal state like this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "empty_data = ImageDataBunch.load_empty(path, tfms=get_transforms()).normalize(imagenet_stats)\n", "learn = create_cnn(empty_data, models.resnet18)\n", "learn = learn.load('one_epoch')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Customize your model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can customize [`create_cnn`](/vision.learner.html#create_cnn) for your own model's default `cut` and `split_on` functions by adding them to the dictionary `model_meta`. 
The key should be your model and the value should be a dictionary with the keys `cut` and `split_on` (see the source code for examples). The constructor will call [`create_body`](/vision.learner.html#create_body) and [`create_head`](/vision.learner.html#create_head) for you based on `cut`; you can also call them yourself, which is particularly useful for testing." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

create_body[source]

\n", "\n", "> create_body(`model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `cut`:`Optional`\\[`int`\\]=`None`, `body_fn`:`Callable`\\[[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), [`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)\\]=`None`)\n", "\n", "Cut off the body of a typically pretrained `model` at `cut` or as specified by `body_fn`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(create_body)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

create_head[source]

\n", "\n", "> create_head(`nf`:`int`, `nc`:`int`, `lin_ftrs`:`Optional`\\[`Collection`\\[`int`\\]\\]=`None`, `ps`:`Floats`=`0.5`, `bn_final`:`bool`=`False`)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(create_head, doc_string=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Model head that takes `nf` features, runs through `lin_ftrs`, and ends with `nc` classes. `ps` is the probability of the dropouts, as documented above in [`create_cnn`](/vision.learner.html#create_cnn)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class ClassificationInterpretation[source]

\n", "\n", "> ClassificationInterpretation(`data`:[`DataBunch`](/basic_data.html#DataBunch), `probs`:`Tensor`, `y_true`:`Tensor`, `losses`:`Tensor`, `sigmoid`:`bool`=`None`)\n", "\n", "Interpretation methods for classification models. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This provides a confusion matrix and visualization of the most incorrect images. Pass in your [`data`](/vision.data.html#vision.data), calculated `preds`, actual `y`, and your `losses`, and then use the methods below to view the model interpretation results. For instance:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = create_cnn(data, models.resnet18)\n", "learn.fit(1)\n", "preds,y,losses = learn.get_preds(with_loss=True)\n", "interp = ClassificationInterpretation(data, preds, y, losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following factory method gives a more convenient way to create an instance of this class:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

from_learner[source]

\n", "\n", "> from_learner(`learn`:[`Learner`](/basic_train.html#Learner), `ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=``, `sigmoid`:`bool`=`None`, `tta`=`False`)\n", "\n", "Create an instance of [`ClassificationInterpretation`](/vision.learner.html#ClassificationInterpretation). `tta` indicates if we want to use Test Time Augmentation. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.from_learner)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
Warning: In both those methods `sigmoid` is a deprecated argument and will be removed in a future version.
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "jekyll_warn('In both those methods `sigmoid` is a deprecated argument and will be removed in a future version.')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

plot_top_losses[source]

\n", "\n", "> plot_top_losses(`k`, `largest`=`True`, `figsize`=`(12, 12)`)\n", "\n", "Show images in `top_losses` along with their prediction, actual, loss, and probability of predicted class. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.plot_top_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `k` items are arranged as a square, so it will look best if `k` is a square number (4, 9, 16, etc). The title of each image shows: prediction, actual, loss, probability of actual class." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "interp.plot_top_losses(9, figsize=(7,7))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

top_losses[source]

\n", "\n", "> top_losses(`k`:`int`=`None`, `largest`=`True`)\n", "\n", "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.top_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Returns tuple of *(losses,indices)*." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/plain": [ "(tensor([6.6163, 5.7370, 5.2444, 4.4769, 3.1993, 3.0553, 2.9016, 2.8053, 2.7020]),\n", " tensor([ 515, 1031, 1581, 326, 597, 1509, 877, 674, 1516]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "interp.top_losses(9)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

plot_confusion_matrix[source]

\n", "\n", "> plot_confusion_matrix(`normalize`:`bool`=`False`, `title`:`str`=`'Confusion matrix'`, `cmap`:`Any`=`'Blues'`, `norm_dec`:`int`=`2`, `slice_size`:`int`=`None`, `kwargs`)\n", "\n", "Plot the confusion matrix, with `title` and using `cmap`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.plot_confusion_matrix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If `normalize`, plots the percentages with `norm_dec` digits. `slice_size` can be used to avoid out of memory error if your set is too big. `kwargs` are passed to `plt.figure`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAARoAAAEmCAYAAAC9C19sAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAFbdJREFUeJzt3XeUVIXZx/Hvs6wsS4dgQQQ7rEiUpiYQFbEiFoIRRWzEEmNFX2OMWMBu1PeoiBo9voo1tiCoMWJsUQSkCFaaAkFARJTedpfn/WPukpWw7LLw7F2G3+ecOc7MvXPvM8J+nXtnZjV3R0QkUk7aA4hI9lNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNNsoM8s3s1fMbLGZvbAZ2+ljZiO25GxpMbODzWxK2nNkI9PnaKo3MzsNuAIoAJYCE4Fb3P2DzdzuGcAlQCd3L9rsQas5M3Ngb3efnvYs2yK9oqnGzOwK4B7gVmBHoAXwAHDiFtj8rsDUbSEyFWFmuWnPkNXcXZdqeAEaAMuAkzeyTh6ZEM1NLvcAecmyLsA3wP8A3wHzgL7JsoHAGqAw2cc5wADgqVLb3g1wIDe5fTbwNZlXVTOAPqXu/6DU4zoBY4HFyT87lVr2LnATMDLZzgigSRnPrWT+q0rN3wM4FpgK/ABcU2r9A4FRwKJk3fuBmsmyfyXPZXnyfE8ptf0/At8CT5bclzxmz2Qf7ZPbOwMLgC5p/93YGi+pD6BLGX8wcAxQVPKDXsY6NwKjgR2A7YEPgZuSZV2Sx98IbJf8gK4AGiXL1w9LmaEB6gBLgFbJsqbAvsn1daEBGgM/Amckj+ud3P5Zsvxd4CugJZCf3L69jOdWMv/1yfznJT/ozwD1gH2BlcDuyfodgF8k+90N+BLoV2p7Duy1ge3fQSbY+aVDk6xzHvAFUBt4A7gr7b8XW+tFh07V18+A733jhzZ9gBvd/Tt3X0DmlcoZpZYXJssL3f3vZP5r3qqS86wF2phZvrvPc/fPN7BOd2Cauz/p7kXu/iwwGTi+1DqPuftUd18JPA+03cg+C8mcjyoE/go0Ae5196XJ/r8A9gdw9/HuPjrZ70zgL8ChFXhON7j76mSen3D3R4DpwBgyce1fzvakDApN9bUQaFLOuYOdgVmlbs9K7lu3jfVCtQKou6mDuPtyMocbFwDzzOw1MyuowDwlMzUrdfvbTZhnobsXJ9dLQjC/1PKVJY83s5
Zm9qqZfWtmS8ic12qykW0DLHD3VeWs8wjQBhjk7qvLWVfKoNBUX6OA1WTOS5RlLpmTuiVaJPdVxnIyhwgldiq90N3fcPcjyfyXfTKZH8Dy5imZaU4lZ9oUD5KZa293rw9cA1g5j9noW65mVpfMea9HgQFm1nhLDLotUmiqKXdfTOb8xGAz62Fmtc1sOzPrZmZ/TlZ7FrjWzLY3sybJ+k9VcpcTgUPMrIWZNQD+VLLAzHY0sxPNrA6Z+C0jc9ixvr8DLc3sNDPLNbNTgNbAq5WcaVPUI3MeaVnyauv36y2fD+yxidu8Fxjn7ucCrwEPbfaU2yiFphpz97vJfIbmWjInQmcDFwMvJ6vcDIwDPgE+BSYk91VmX28CzyXbGs9P45CTzDGXzDsxh/LfP8i4+0LgODLvdC0k847Rce7+fWVm2kRXAqeReTfrETLPpbQBwBAzW2RmvcrbmJmdSOaEfMnzvAJob2Z9ttjE2xB9YE9EwukVjYiEU2hEJJxCIyLhFBoRCVetvkhmufluefXTHkMCtCtonvYIEmDWrJl8//335X1eqZqFJq8+eQWnpj2GBBg55t60R5AAnQ/qWKH1dOgkIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJl5v2ANnqot6H0rfHLzGDx4aO4v5n32O/ls0YdE0v8mrmUlS8ln63v8C4z//N5Wd05ZRuHQDIrVGDgt13pPkR/flxyYqUn4VszOzZszm375l89918zIzfnnM+F196GQNvuI5Xhw8jJyeH7XfYgYcffZydd9457XFTZe6e9gzr5NTZ0fMKTk17jM3Wes+mPHHrWRx81t2sKSxm+KALuOTW57n3Tycz6Ol3GfHhlxzduTVXnNmVo393/08ee+zB+3JJny50u2BwStPH+HHMvWmPsMXNmzePb+fNo1379ixdupROB3Xg+Rdfptkuu1C/fn0ABg+6j8lffsGgBx5KedoYnQ/qyPjx46y89XToFKBg9x0Z+9ksVq4qpLh4Le9PmE6Prvvh7tSvUwuABnVrMe/7Jf/12F7HdOD5NyZU9chSCU2bNqVd+/YA1KtXj4KCfZg7d866yACsWLEcs3J/DrOeDp0CfD59HgMu7E7jBrVZubqQYzq3ZsIXs/nDXUN5ZfDvua3fieTkGIf1vecnj8uvtR1H/rKAy+94MaXJpbJmzZzJxIkfc8CBBwFww3X9efqpJ2jQoAH/ePOdlKdLX9grGjOrZWYfmdkkM/vczAZG7au6mTJzPncPeYtXBl/I8EEXMGnqHIrXruX8kztz1d1D2bv7AK7636E8eH3vnzyu+8FtGDVphs7NbGWWLVtG714ncefd96x7NTPwpluYPmM2p/buw0MP3F/OFrJf5KHTaqCru+8PtAWOMbNfBO6vWhkybDSdT7+LI88bxKIlK5n27wX0Oe5AXn57EgAvvTmRjvvu+pPHnHx0e17QYdNWpbCwkN69TuKU3n3o8eue/7X8lN59eHnoSylMVr2EhcYzliU3t0su1efMc7DtG9UFoPlOjTix63489/p45i1YzMEd9gKgywEtmT57wbr169etxa/a78kr736ayryy6dydC847h1YF+3DZ5Vesu3/6tGnrrr86fBgtWxWkMV61EnqOxsxqAOOBvYDB7j4mcn/VybN3/pbGDepQWFRMv9tfZPGylVx083PceWVPcmvksHpNIRff/Nd1659w2H68NXoKK1atSXFq2RQfjhzJM08/SZs2P+egDm
0BGHjzrTz+2KNMmzqFHMuhxa67ct/g7HzHaVNUydvbZtYQGApc4u6frbfsfOB8AGrW61CrTd/weaTqZePb21LN3t5290XAO8AxG1j2sLt3dPeOlptfFeOISBWLfNdp++SVDGaWDxwJTI7an4hUX5HnaJoCQ5LzNDnA8+7+auD+RKSaCguNu38CtIvavohsPfQVBBEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEi43LIWmNkrgJe13N1PCJlIRLJOmaEB7qqyKUQkq5UZGnd/ryoHEZHstbFXNACY2d7AbUBroFbJ/e6+R+BcIpJFKnIy+DHgQaAIOAx4AngqcigRyS4VCU2+u78FmLvPcvcBQPfYsUQkm5R76ASsNrMcYJqZXQzMAerGjiUi2aQir2guA2oDlwIdgDOAsyKHEpHsUu4rGncfm1xdBvSNHUdEslFF3nV6hw18cM/du4ZMJCJZpyLnaK4sdb0WcBKZd6BERCqkIodO49e7a6SZfRQxTNuC5owcfU/EpiVljQ64OO0RJMDqKf+u0HoVOXRqXOpmDpkTwg0qN5aIbIsqcug0nsw5GiNzyDQDOCdyKBHJLhUJzT7uvqr0HWaWFzSPiGShinyO5sMN3DdqSw8iItlrY7+PZiegGZBvZu3IHDoB1CfzAT4RkQrZ2KHT0cDZwC7A3fwnNEuAa2LHEpFssrHfRzMEGGJmJ7n7S1U4k4hkmYqco+lgZg1LbphZIzO7OXAmEckyFQlNN3dfVHLD3X8Ejo0bSUSyTUVCU6P029lmlg/o7W0RqbCKfI7maeAtM3uMzAnhs4EhkUOJSHapyHed7jCzScARZD4h/Aawa/RgIpI9Kvo/kJtPJjInA12BL8MmEpGss7EP7LUEeieX74HnyPze4MOqaDYRyRIbO3SaDLwPHOfu0wHM7PIqmUpEssrGDp16AvOAd8zsETM7nP98OlhEpMLKDI27v+zupwIFwDtAP2AHM3vQzI6qqgFFZOtX7slgd1/u7s+4+/Fkvvf0MfDH8MlEJGtU9F0nIPOpYHd/2N0PjxpIRLLPJoVGRKQyFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuFy0x4g261atYojux7KmtWrKSoqokfPk7juhoHMnDGDM0/vzQ8/LKRduw48+vgT1KxZM+1xpRwX9e5C356dMDMe+9tI7n/mXX7eshmD+p9Knfw8Zs1dSN/+Q1i6fBW5uTk8eH0f2hY0J7dGDk+/9hF3/d+ItJ9CKvSKJlheXh6vj3iLMeMnMnrcx7w54g0+GjOaa6+5mksu7cdnX06jYaOGPP7Yo2mPKuVovWdT+vbsxMFn3MmBp9xGt0PasEfzJjx4/Wlce98wDuh1K8PfmcTlZx0OwElHtCevZi4H9LqVTn3u4NyTOtOiaeOUn0U6FJpgZkbdunUBKCwspLCwEMx47923+fVJvwHg9DPO4tXhw9IcUyqgYPedGPvZTFauKqS4eC3vj59Oj65t2avFDnwwfjoAb4+eTI/D2wLgOLVr1aRGjRzy82qyprCYpctXpfkUUqPQVIHi4mIO6tiOXZvtyOGHH8Eee+xJg4YNyc3NHLk2a7YLc+fMSXlKKc/nX82lc7u9aNygDvm1tuOYX+3LLjs14suv53
F8l/0A6Hlke3bZsREAf/vnx6xYtYYZb97C1Ndv5J4n3uLHJSvSfAqpCQuNmbUys4mlLkvMrF/U/qqzGjVqMGbcx0ybMZtx48YydcrktEeSSpgyYz53P/4mrzxwEcMHX8SkKd9QXLyW3w14mvN7HczIp6+ibu081hQWA3DAvrtRXLyWPY7qzz7db+CyM7qyW7Ofpfws0hF2MtjdpwBtAcysBjAHGBq1v61Bw4YNOeTQLowZPYrFixZRVFREbm4uc+Z8w87NmqU9nlTAkJdHMeTlUQAMvPh45sxfxNSZ8zn+wsEA7NViB7odvC8Avbp1ZMSHX1BUtJYFPy5j1MSv6dC6BTPnLExt/rRU1aHT4cBX7j6rivZXbSxYsIBFixYBsHLlSt5+65+0KtiHQw49jKEvvQjAU08OofvxJ6Q5plTQ9o0y59ua79SIE7vuz3Ovj1t3n5lx9XlH88iLHwDwzbc/0OWAVgDUrlWTA/fbjSkz56czeMqq6u3tU4FnN7TAzM4Hzgdo3qJFFY1Tdb6dN4/zzjmbtcXFrF27lp6/OZljux/HPvu05szTezNwwHXsv387zu57TtqjSgU8e9e5NG5Yh8KiYvrd/jyLl63kot5d+N0phwAw7O2JPDFsNAAPPfcvHh54OuNf7I8ZPDlsNJ9Nm5vm+Kkxd4/dgVlNYC6wr7tvNOftO3T0kaPHhs4j6Wh84CVpjyABVk95nrUrvrPy1quKQ6duwITyIiMi2asqQtObMg6bRGTbEBoaM6sDHAn8LXI/IlK9hZ4MdvflwLb5wQERWUefDBaRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJZ+6e9gzrmNkCYFbac1SRJsD3aQ8hW9y29ue6q7tvX95K1So02xIzG+fuHdOeQ7Ys/blumA6dRCScQiMi4RSa9Dyc9gASQn+uG6BzNCISTq9oRCScQiMi4RQaEQmXm/YA2wozOxBwdx9rZq2BY4DJ7v73lEcTCaeTwVXAzG4AupEJ+5vAQcA7wJHAG+5+S4rjSSWZ2aXAUHefnfYs1Z1CUwXM7FOgLZAHfAvs4u5LzCwfGOPu+6U6oFSKmS0GlgNfAc8CL7j7gnSnqp50jqZqFLl7sbuvAL5y9yUA7r4SWJvuaLIZvgZ2AW4COgBfmNk/zOwsM6uX7mjVi0JTNdaYWe3keoeSO82sAQrN1szdfa27j3D3c4CdgQfInH/7Ot3RqhcdOlUBM8tz99UbuL8J0NTdP01hLNlMZvaxu7crY1nt5BWsoNCIVJqZtXT3qWnPsTVQaEQknM7RiEg4hUZEwik0AoCZFZvZRDP7zMxeKPUuWWW21cXMXk2un2BmV29k3YZmdmEl9jHAzK6s7IxStRQaKbHS3du6extgDXBB6YWWscl/X9x9uLvfvpFVGgKbHBrZuig0siHvA3uZ2W5mNsXMngA+A5qb2VFmNsrMJiSvfOoCmNkxZjbZzCYAPUs2ZGZnm9n9yfUdzWyomU1KLp2A24E9k1dTdybr/cHMxprZJ2Y2sNS2+pvZVDP7AGhVZf82ZLPpS5XyE2aWS+Z7Wf9I7tobOMvdRyef+7kWOMLdl5vZH4ErzOzPwCNAV2A68FwZm78PeM/df21mNYC6wNVAG3dvm+z/qGSfBwIGDDezQ8h81P9UMl/lyAUmAOO37LOXKAqNlMg3s4nJ9feBR8l80nWWu49O7v8F0BoYaWYANYFRQAEww92nAZjZU8
D5G9hHV+BMAHcvBhabWaP11jkquXyc3K5LJjz1yHyBcUWyj+Gb9WylSik0UmJlyauKEklMlpe+C3jT3Xuvt95PHreZDLjN3f+y3j76bcF9SBXTORrZFKOBzma2F4CZ1TGzlsBkYDcz2zNZr3cZj38L+H3y2BrJd72Wknm1UuIN4Lelzv00M7MdgH8BPcwsP/nC4vFb+LlJIIVGKiz5FQhnA8+a2Sckh03uvorModJrycng78rYxGXAYcmvzRgPtHb3hWQOxT4zszvdfQTwDDAqWe9FoJ67TyBz7mcS8DowNuyJyhanryCISDi9ohGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTC/T8iUZsMv9QFLwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "interp.plot_confusion_matrix()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

confusion_matrix[source]

\n", "\n", "> confusion_matrix(`slice_size`:`int`=`None`)\n", "\n", "Confusion matrix as an `np.ndarray`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.confusion_matrix)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[987, 23],\n", " [ 30, 998]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "interp.confusion_matrix()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

most_confused[source]

\n", "\n", "> most_confused(`min_val`:`int`=`1`, `slice_size`:`int`=`None`) → `Collection`\\[`Tuple`\\[`str`, `str`, `int`\\]\\]\n", "\n", "Sorted descending list of largest non-diagonal entries of confusion matrix " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.most_confused)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Working with large datasets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When working with large datasets, memory problems can arise when computing the confusion matrix. For example, an error can look like this:\n", "\n", " RuntimeError: $ Torch: not enough memory: you tried to allocate 64GB. Buy new RAM! at /opt/conda/conda-bld/pytorch-nightly_1540719301766/work/aten/src/TH/THGeneral.cpp:204\n", "\n", "In this case it is possible to force [`ClassificationInterpretation`](/vision.learner.html#ClassificationInterpretation) to compute the confusion matrix on slices of the data and then aggregate the results, by specifying the `slice_size` parameter. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[987, 23],\n", " [ 30, 998]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "interp.confusion_matrix(slice_size=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAARoAAAEmCAYAAAC9C19sAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAFbdJREFUeJzt3XeUVIXZx/Hvs6wsS4dgQQQ7rEiUpiYQFbEiFoIRRWzEEmNFX2OMWMBu1PeoiBo9voo1tiCoMWJsUQSkCFaaAkFARJTedpfn/WPukpWw7LLw7F2G3+ecOc7MvXPvM8J+nXtnZjV3R0QkUk7aA4hI9lNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNNsoM8s3s1fMbLGZvbAZ2+ljZiO25GxpMbODzWxK2nNkI9PnaKo3MzsNuAIoAJYCE4Fb3P2DzdzuGcAlQCd3L9rsQas5M3Ngb3efnvYs2yK9oqnGzOwK4B7gVmBHoAXwAHDiFtj8rsDUbSEyFWFmuWnPkNXcXZdqeAEaAMuAkzeyTh6ZEM1NLvcAecmyLsA3wP8A3wHzgL7JsoHAGqAw2cc5wADgqVLb3g1wIDe5fTbwNZlXVTOAPqXu/6DU4zoBY4HFyT87lVr2LnATMDLZzgigSRnPrWT+q0rN3wM4FpgK/ABcU2r9A4FRwKJk3fuBmsmyfyXPZXnyfE8ptf0/At8CT5bclzxmz2Qf7ZPbOwMLgC5p/93YGi+pD6BLGX8wcAxQVPKDXsY6NwKjgR2A7YEPgZuSZV2Sx98IbJf8gK4AGiXL1w9LmaEB6gBLgFbJsqbAvsn1daEBGgM/Amckj+ud3P5Zsvxd4CugJZCf3L69jOdWMv/1yfznJT/ozwD1gH2BlcDuyfodgF8k+90N+BLoV2p7Duy1ge3fQSbY+aVDk6xzHvAFUBt4A7gr7b8XW+tFh07V18+A733jhzZ9gBvd/Tt3X0DmlcoZpZYXJssL3f3vZP5r3qqS86wF2phZvrvPc/fPN7BOd2Cauz/p7kXu/iwwGTi+1DqPuftUd18JPA+03cg+C8mcjyoE/go0Ae5196XJ/r8A9gdw9/HuPjrZ70zgL8ChFXhON7j76mSen3D3R4DpwBgyce1fzvakDApN9bUQaFLOuYOdgVmlbs9K7lu3jfVCtQKou6mDuPtyMocbFwDzzOw1MyuowDwlMzUrdfvbTZhnobsXJ9dLQjC/1PKVJY83s5Zm9qqZfWtmS8ic12qykW0DLHD3VeWs8wjQBhjk7qvLWVfKoNBUX6OA1WTOS5RlLpmTuiVaJPdVxnIyhwgldiq90N3fcPcjyfyXfTKZH8Dy5imZaU4lZ9oUD5KZa293rw9cA1g5j9noW65mVpfMea9HgQFm1nhLDLotUmiqKXdfTOb8xGAz62Fmtc1sOzPrZmZ/TlZ7FrjWzLY3sybJ+k9VcpcTgUPMrIWZNQD+VLLAzHY0sxPNrA6Z+C0jc9ixvr8DLc3sNDPLNbNTgNbAq5WcaVPUI3MeaVnyauv36y2fD+yxidu8Fxjn7ucCrwEPbfaU2yiFphpz97vJfIbmWjInQmcDFwMvJ6vcDIwDP
gE+BSYk91VmX28CzyXbGs9P45CTzDGXzDsxh/LfP8i4+0LgODLvdC0k847Rce7+fWVm2kRXAqeReTfrETLPpbQBwBAzW2RmvcrbmJmdSOaEfMnzvAJob2Z9ttjE2xB9YE9EwukVjYiEU2hEJJxCIyLhFBoRCVetvkhmufluefXTHkMCtCtonvYIEmDWrJl8//335X1eqZqFJq8+eQWnpj2GBBg55t60R5AAnQ/qWKH1dOgkIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJl5v2ANnqot6H0rfHLzGDx4aO4v5n32O/ls0YdE0v8mrmUlS8ln63v8C4z//N5Wd05ZRuHQDIrVGDgt13pPkR/flxyYqUn4VszOzZszm375l89918zIzfnnM+F196GQNvuI5Xhw8jJyeH7XfYgYcffZydd9457XFTZe6e9gzr5NTZ0fMKTk17jM3Wes+mPHHrWRx81t2sKSxm+KALuOTW57n3Tycz6Ol3GfHhlxzduTVXnNmVo393/08ee+zB+3JJny50u2BwStPH+HHMvWmPsMXNmzePb+fNo1379ixdupROB3Xg+Rdfptkuu1C/fn0ABg+6j8lffsGgBx5KedoYnQ/qyPjx46y89XToFKBg9x0Z+9ksVq4qpLh4Le9PmE6Prvvh7tSvUwuABnVrMe/7Jf/12F7HdOD5NyZU9chSCU2bNqVd+/YA1KtXj4KCfZg7d866yACsWLEcs3J/DrOeDp0CfD59HgMu7E7jBrVZubqQYzq3ZsIXs/nDXUN5ZfDvua3fieTkGIf1vecnj8uvtR1H/rKAy+94MaXJpbJmzZzJxIkfc8CBBwFww3X9efqpJ2jQoAH/ePOdlKdLX9grGjOrZWYfmdkkM/vczAZG7au6mTJzPncPeYtXBl/I8EEXMGnqHIrXruX8kztz1d1D2bv7AK7636E8eH3vnzyu+8FtGDVphs7NbGWWLVtG714ncefd96x7NTPwpluYPmM2p/buw0MP3F/OFrJf5KHTaqCru+8PtAWOMbNfBO6vWhkybDSdT7+LI88bxKIlK5n27wX0Oe5AXn57EgAvvTmRjvvu+pPHnHx0e17QYdNWpbCwkN69TuKU3n3o8eue/7X8lN59eHnoSylMVr2EhcYzliU3t0su1efMc7DtG9UFoPlOjTix63489/p45i1YzMEd9gKgywEtmT57wbr169etxa/a78kr736ayryy6dydC847h1YF+3DZ5Vesu3/6tGnrrr86fBgtWxWkMV61EnqOxsxqAOOBvYDB7j4mcn/VybN3/pbGDepQWFRMv9tfZPGylVx083PceWVPcmvksHpNIRff/Nd1659w2H68NXoKK1atSXFq2RQfjhzJM08/SZs2P+egDm0BGHjzrTz+2KNMmzqFHMuhxa67ct/g7HzHaVNUydvbZtYQGApc4u6frbfsfOB8AGrW61CrTd/weaTqZePb21LN3t5290XAO8AxG1j2sLt3dPeOlptfFeOISBWLfNdp++SVDGaWDxwJTI7an4hUX5HnaJoCQ5LzNDnA8+7+auD+RKSaCguNu38CtIvavohsPfQVBBEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hE
ZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEi43LIWmNkrgJe13N1PCJlIRLJOmaEB7qqyKUQkq5UZGnd/ryoHEZHstbFXNACY2d7AbUBroFbJ/e6+R+BcIpJFKnIy+DHgQaAIOAx4AngqcigRyS4VCU2+u78FmLvPcvcBQPfYsUQkm5R76ASsNrMcYJqZXQzMAerGjiUi2aQir2guA2oDlwIdgDOAsyKHEpHsUu4rGncfm1xdBvSNHUdEslFF3nV6hw18cM/du4ZMJCJZpyLnaK4sdb0WcBKZd6BERCqkIodO49e7a6SZfRQxTNuC5owcfU/EpiVljQ64OO0RJMDqKf+u0HoVOXRqXOpmDpkTwg0qN5aIbIsqcug0nsw5GiNzyDQDOCdyKBHJLhUJzT7uvqr0HWaWFzSPiGShinyO5sMN3DdqSw8iItlrY7+PZiegGZBvZu3IHDoB1CfzAT4RkQrZ2KHT0cDZwC7A3fwnNEuAa2LHEpFssrHfRzMEGGJmJ7n7S1U4k4hkmYqco+lgZg1LbphZIzO7OXAmEckyFQlNN3dfVHLD3X8Ejo0bSUSyTUVCU6P029lmlg/o7W0RqbCKfI7maeAtM3uMzAnhs4EhkUOJSHapyHed7jCzScARZD4h/Aawa/RgIpI9Kvo/kJtPJjInA12BL8MmEpGss7EP7LUEeieX74HnyPze4MOqaDYRyRIbO3SaDLwPHOfu0wHM7PIqmUpEssrGDp16AvOAd8zsETM7nP98OlhEpMLKDI27v+zupwIFwDtAP2AHM3vQzI6qqgFFZOtX7slgd1/u7s+4+/Fkvvf0MfDH8MlEJGtU9F0nIPOpYHd/2N0PjxpIRLLPJoVGRKQyFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuFy0x4g261atYojux7KmtWrKSoqokfPk7juhoHMnDGDM0/vzQ8/LKRduw48+vgT1KxZM+1xpRwX9e5C356dMDMe+9tI7n/mXX7eshmD+p9Knfw8Zs1dSN/+Q1i6fBW5uTk8eH0f2hY0J7dGDk+/9hF3/d+ItJ9CKvSKJlheXh6vj3iLMeMnMnrcx7w54g0+GjOaa6+5mksu7cdnX06jYaOGPP7Yo2mPKuVovWdT+vbsxMFn3MmBp9xGt0PasEfzJjx4/Wlce98wDuh1K8PfmcTlZx0OwElHtCevZi4H9LqVTn3u4NyTOtOiaeOUn0U6FJpgZkbdunUBKCwspLCwEMx47923+fVJvwHg9DPO4tXhw9IcUyqgYPedGPvZTFauKqS4eC3vj59Oj65t2avFDnwwfjoAb4+eTI/D2wLgOLVr1aRGjRzy82qyprCYpctXpfkUUqPQVIHi4mIO6tiOXZvtyOGHH8Eee+xJg4YNyc3NHLk2a7YLc+fMSXlKKc/nX82lc7u9aNygDvm1tuOYX+3LLjs14suv53F8l/0A6Hlke3bZsREAf/vnx6xYtYYZb97C1Ndv5J4n3uLHJSvSfAqpCQuNmbUys4mlLkvMrF/U/qqzGjVqMGbcx0ybMZtx48YydcrktEeSSpgyYz53P/4mrzxwEcMHX8SkKd9QXLyW3w14mvN7HczIp6+ibu081hQWA3DAvrtRXLyWPY7qzz7db+CyM7qyW7Ofpfws0hF2MtjdpwBtAcysBjAHGBq1v61Bw4YNOeTQLowZPYrFixZRVFREbm4uc+Z8w87NmqU9nlTAkJdHMeTlUQAMvPh45sxfxNSZ8zn+wsEA7NViB7odvC8Avbp1ZMSHX1BUtJYFPy5j1MSv6dC6BTPnLExt/rRU1aHT4
cBX7j6rivZXbSxYsIBFixYBsHLlSt5+65+0KtiHQw49jKEvvQjAU08OofvxJ6Q5plTQ9o0y59ua79SIE7vuz3Ovj1t3n5lx9XlH88iLHwDwzbc/0OWAVgDUrlWTA/fbjSkz56czeMqq6u3tU4FnN7TAzM4Hzgdo3qJFFY1Tdb6dN4/zzjmbtcXFrF27lp6/OZljux/HPvu05szTezNwwHXsv387zu57TtqjSgU8e9e5NG5Yh8KiYvrd/jyLl63kot5d+N0phwAw7O2JPDFsNAAPPfcvHh54OuNf7I8ZPDlsNJ9Nm5vm+Kkxd4/dgVlNYC6wr7tvNOftO3T0kaPHhs4j6Wh84CVpjyABVk95nrUrvrPy1quKQ6duwITyIiMi2asqQtObMg6bRGTbEBoaM6sDHAn8LXI/IlK9hZ4MdvflwLb5wQERWUefDBaRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJZ+6e9gzrmNkCYFbac1SRJsD3aQ8hW9y29ue6q7tvX95K1So02xIzG+fuHdOeQ7Ys/blumA6dRCScQiMi4RSa9Dyc9gASQn+uG6BzNCISTq9oRCScQiMi4RQaEQmXm/YA2wozOxBwdx9rZq2BY4DJ7v73lEcTCaeTwVXAzG4AupEJ+5vAQcA7wJHAG+5+S4rjSSWZ2aXAUHefnfYs1Z1CUwXM7FOgLZAHfAvs4u5LzCwfGOPu+6U6oFSKmS0GlgNfAc8CL7j7gnSnqp50jqZqFLl7sbuvAL5y9yUA7r4SWJvuaLIZvgZ2AW4COgBfmNk/zOwsM6uX7mjVi0JTNdaYWe3keoeSO82sAQrN1szdfa27j3D3c4CdgQfInH/7Ot3RqhcdOlUBM8tz99UbuL8J0NTdP01hLNlMZvaxu7crY1nt5BWsoNCIVJqZtXT3qWnPsTVQaEQknM7RiEg4hUZEwik0AoCZFZvZRDP7zMxeKPUuWWW21cXMXk2un2BmV29k3YZmdmEl9jHAzK6s7IxStRQaKbHS3du6extgDXBB6YWWscl/X9x9uLvfvpFVGgKbHBrZuig0siHvA3uZ2W5mNsXMngA+A5qb2VFmNsrMJiSvfOoCmNkxZjbZzCYAPUs2ZGZnm9n9yfUdzWyomU1KLp2A24E9k1dTdybr/cHMxprZJ2Y2sNS2+pvZVDP7AGhVZf82ZLPpS5XyE2aWS+Z7Wf9I7tobOMvdRyef+7kWOMLdl5vZH4ErzOzPwCNAV2A68FwZm78PeM/df21mNYC6wNVAG3dvm+z/qGSfBwIGDDezQ8h81P9UMl/lyAUmAOO37LOXKAqNlMg3s4nJ9feBR8l80nWWu49O7v8F0BoYaWYANYFRQAEww92nAZjZU8D5G9hHV+BMAHcvBhabWaP11jkquXyc3K5LJjz1yHyBcUWyj+Gb9WylSik0UmJlyauKEklMlpe+C3jT3Xuvt95PHreZDLjN3f+y3j76bcF9SBXTORrZFKOBzma2F4CZ1TGzlsBkYDcz2zNZr3cZj38L+H3y2BrJd72Wknm1UuIN4Lelzv00M7MdgH8BPcwsP/nC4vFb+LlJIIVGKiz5FQhnA8+a2Sckh03uvorModJrycng78rYxGXAYcmvzRgPtHb3hWQOxT4zszvdfQTwDDAqWe9FoJ67TyBz7mcS8DowNuyJyhanryCISDi9ohGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTC/T8iU
ZsMv9QFLwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "interp.plot_confusion_matrix(slice_size=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('7', '3', 30), ('3', '7', 23)]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "interp.most_confused(slice_size=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## GANs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

gan_learner[source]

\n", "\n", "> gan_learner(`data`, `generator`, `discriminator`, `loss_funcD`=`None`, `loss_funcG`=`None`, `noise_size`:`int`=`None`, `wgan`:`bool`=`False`, `kwargs`)\n", "\n", "Create a [`GANLearner`](/vision.learner.html#GANLearner) from `data` with a `generator` and a `discriminator`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(gan_learner)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If `noise_size` is set, the GAN generates fakes from random noise of that size; otherwise it uses the inputs in `data`. If `wgan` is set to `True`, the loss functions are overridden with those of a WGAN. `loss_funcD` and `loss_funcG` are the loss functions used for the discriminator and the generator, respectively. `kwargs` are passed to the `Learner` init." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class GANLearner[source]

\n", "\n", "> GANLearner(`data`:[`DataBunch`](/basic_data.html#DataBunch), `model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `opt_func`:`Callable`=`'Adam'`, `loss_func`:`Callable`=`None`, `metrics`:`Collection`\\[`Callable`\\]=`None`, `true_wd`:`bool`=`True`, `bn_wd`:`bool`=`True`, `wd`:`Floats`=`0.01`, `train_bn`:`bool`=`True`, `path`:`str`=`None`, `model_dir`:`str`=`'models'`, `callback_fns`:`Collection`\\[`Callable`\\]=`None`, `callbacks`:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=``, `layer_groups`:`ModuleList`=`None`) :: [`Learner`](/basic_train.html#Learner)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GANLearner, title_level=3, doc_string=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Subclass of [`Learner`](/basic_train.html#Learner) to deal with `predict` and `show_results` for GANs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class GANLearner[source]

\n", "\n", "> GANLearner(`data`:[`DataBunch`](/basic_data.html#DataBunch), `model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `opt_func`:`Callable`=`'Adam'`, `loss_func`:`Callable`=`None`, `metrics`:`Collection`\\[`Callable`\\]=`None`, `true_wd`:`bool`=`True`, `bn_wd`:`bool`=`True`, `wd`:`Floats`=`0.01`, `train_bn`:`bool`=`True`, `path`:`str`=`None`, `model_dir`:`str`=`'models'`, `callback_fns`:`Collection`\\[`Callable`\\]=`None`, `callbacks`:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=``, `layer_groups`:`ModuleList`=`None`) :: [`Learner`](/basic_train.html#Learner)\n", "\n", "Train `model` using `data` to minimize `loss_func` with optimizer `opt_func`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GANLearner)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

show_results[source]

\n", "\n", "> show_results(`rows`:`int`=`5`, `figsize`=`(10, 10)`)\n", "\n", "Show `rows` by `rows` fake images with `figsize`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GANLearner.show_results)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

add_gan_trainer[source]

\n", "\n", "> add_gan_trainer(`cb`)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GANLearner.add_gan_trainer)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

predict[source]

\n", "\n", "> predict()\n", "\n", "Predict one batch of fake images. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GANLearner.predict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## New Methods - Please document or move to the undocumented section" ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "`Learner` support for computer vision", "title": "vision.learner" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }