{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Additional training functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`train`](/train.html#train) provides a number of extension methods that are added to [`Learner`](/basic_train.html#Learner) (see below for a list and details), along with four simple callbacks:\n", "\n", "- [`ShowGraph`](/train.html#ShowGraph)\n", "- [`GradientClipping`](/train.html#GradientClipping)\n", "- [`BnFreeze`](/train.html#BnFreeze)\n", "- [`AccumulateScheduler`](/train.html#AccumulateScheduler)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.train import *\n", "from fastai.vision import *\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [`Learner`](/basic_train.html#Learner) extension methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These methods are automatically added to all [`Learner`](/basic_train.html#Learner) objects created after importing this module. They provide convenient access to a number of callbacks, without requiring them to be manually created." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

fit_one_cycle

\n", "\n", "> fit_one_cycle(**`learn`**:[`Learner`](/basic_train.html#Learner), **`cyc_len`**:`int`, **`max_lr`**:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=***`slice(None, 0.003, None)`***, **`moms`**:`Point`=***`(0.95, 0.85)`***, **`div_factor`**:`float`=***`25.0`***, **`pct_start`**:`float`=***`0.3`***, **`final_div`**:`float`=***`None`***, **`wd`**:`float`=***`None`***, **`callbacks`**:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=***`None`***, **`tot_epochs`**:`int`=***`None`***, **`start_epoch`**:`int`=***`None`***)\n", "\n", "
\n", "\n", "Fit a model following the 1cycle policy. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(fit_one_cycle)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

one_cycle_scheduler

\n", "\n", "> one_cycle_scheduler(**`lr_max`**:`float`, **\\*\\*`kwargs`**:`Any`) → [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler)\n", "\n", "
\n", "\n", "Instantiate a [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler) with `lr_max`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(one_cycle_scheduler)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler) for details." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

lr_find

\n", "\n", "> lr_find(**`learn`**:[`Learner`](/basic_train.html#Learner), **`start_lr`**:`Floats`=***`1e-07`***, **`end_lr`**:`Floats`=***`10`***, **`num_it`**:`int`=***`100`***, **`stop_div`**:`bool`=***`True`***, **`wd`**:`float`=***`None`***)\n", "\n", "
\n", "\n", "Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss diverges. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(lr_find)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [`LRFinder`](/callbacks.lr_finder.html#LRFinder) for details." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

to_fp16

\n", "\n", "> to_fp16(**`learn`**:[`Learner`](/basic_train.html#Learner), **`loss_scale`**:`float`=***`None`***, **`max_noskip`**:`int`=***`1000`***, **`dynamic`**:`bool`=***`True`***, **`clip`**:`float`=***`None`***, **`flat_master`**:`bool`=***`False`***, **`max_scale`**:`float`=***`16777216`***, **`loss_fp32`**:`bool`=***`True`***) → [`Learner`](/basic_train.html#Learner)\n", "\n", "
\n", "\n", "Put `learn` in FP16 precision mode. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(to_fp16)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [`MixedPrecision`](/callbacks.fp16.html#MixedPrecision) for details." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

to_fp32

\n", "\n", "> to_fp32(**`learn`**:[`Learner`](/basic_train.html#Learner))\n", "\n", "
\n", "\n", "Put `learn` back to FP32 precision mode. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(to_fp32)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

mixup

\n", "\n", "> mixup(**`learn`**:[`Learner`](/basic_train.html#Learner), **`alpha`**:`float`=***`0.4`***, **`stack_x`**:`bool`=***`False`***, **`stack_y`**:`bool`=***`True`***) → [`Learner`](/basic_train.html#Learner)\n", "\n", "
\n", "\n", "Add mixup https://arxiv.org/abs/1710.09412 to `learn`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(mixup)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [`MixUpCallback`](/callbacks.mixup.html#MixUpCallback) for more details." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class Interpretation

\n", "\n", "> Interpretation(**`learn`**:[`Learner`](/basic_train.html#Learner), **`preds`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***)\n", "\n", "
\n", "\n", "Interpretation base class, can be inherited for task specific Interpretation classes " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Interpretation)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

from_learner

\n", "\n", "> from_learner(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***, **`activ`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)=***`None`***)\n", "\n", "
\n", "\n", "Gets preds, y_true, losses to construct base class from a learner " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Interpretation.from_learner)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

top_losses

\n", "\n", "> top_losses(**`k`**:`int`=***`None`***, **`largest`**=***`True`***)\n", "\n", "
\n", "\n", "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Interpretation.top_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For example, [`ClassificationInterpretation`](/train.html#ClassificationInterpretation) is implemented using argmax on `preds` to set `self.pred_class`, whereas an optional sigmoid is used for `MultiLabelClassificationInterpretation`." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class ClassificationInterpretation

\n", "\n", "> ClassificationInterpretation(**`learn`**:[`Learner`](/basic_train.html#Learner), **`preds`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***) :: [`Interpretation`](/train.html#Interpretation)\n", "\n", "
\n", "\n", "Interpretation methods for classification models. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "learn = cnn_learner(data, models.resnet18)\n", "learn.fit(1)\n", "preds,y,losses = learn.get_preds(with_loss=True)\n", "interp = ClassificationInterpretation(learn, preds, y, losses)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

top_losses

\n", "\n", "> top_losses(**`k`**:`int`=***`None`***, **`largest`**=***`True`***)\n", "\n", "
\n", "\n", "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.top_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Returns a tuple of *(losses, indices)*." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/plain": [ "torch.return_types.topk(\n", "values=tensor([14.2152, 10.3850, 9.1650, 8.7286, 5.8163, 5.6689, 4.9013, 4.5471,\n", " 4.2432]),\n", "indices=tensor([1059, 299, 960, 1831, 1775, 1467, 750, 1892, 634]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "interp.top_losses(9)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

plot_confusion_matrix

\n", "\n", "> plot_confusion_matrix(**`normalize`**:`bool`=***`False`***, **`title`**:`str`=***`'Confusion matrix'`***, **`cmap`**:`Any`=***`'Blues'`***, **`slice_size`**:`int`=***`1`***, **`norm_dec`**:`int`=***`2`***, **`plot_txt`**:`bool`=***`True`***, **`return_fig`**:`bool`=***`None`***, **\\*\\*`kwargs`**) → `Optional`\\[`Figure`\\]\n", "\n", "
\n", "\n", "Plot the confusion matrix, with `title` and using `cmap`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.plot_confusion_matrix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If [`normalize`](/vision.data.html#normalize), plots the percentages with `norm_dec` digits. `slice_size` can be used to avoid out of memory error if your set is too big. `kwargs` are passed to `plt.figure`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAARYAAAEmCAYAAACnN7/iAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAVn0lEQVR4nO3deXgV9b3H8fcXIhAgbAUFUbQim1AJS4GHCoI7rQqubEWlVGsLbdGKpa4oIFjhXrUqVa9VUEERd1y5XnEFFRSxKCCyiCAaEJBdEr73j5nYlCbhgL+TOYHP63nymDMzmfkeCW9n5pxEc3dEREKqkPQAIrL/UVhEJDiFRUSCU1hEJDiFRUSCU1hEJDiF5QBlZtlm9qyZbTSzx37Afvqb2cshZ0uKmXUxs0VJz7E/ML2PJbOZWT/gcqA5sAmYB4x29zd/4H4HAL8HOrt7/g8eNMOZmQNN3H1J0rMcCHTGksHM7HLgVuAm4BCgEXAX0DPA7o8AFh8IUUmFmWUlPcN+xd31kYEfQE1gM3BeKdtUJgrP6vjjVqByvK4b8AXwJ+Br4EtgYLzuBuA7YGd8jEHACOChIvs+EnAgK358EbCU6KxpGdC/yPI3i3xdZ+A9YGP8z85F1s0ERgJvxft5GahbwnMrnP/KIvP3An4OLAa+Aa4qsn0HYBawId72DqBSvO71+LlsiZ9v7yL7/zOwBniwcFn8NY3jY7SNHx8KrAW6Jf29UR4+Eh9AHyX8wcBpQH7hX+wStrkRmA0cDNQD3gZGxuu6xV9/I3BQ/BdyK1A7Xr97SEoMC1AN+BZoFq9rALSMP/8+LEAdYD0wIP66vvHjH8XrZwKfAU2B7Pjx2BKeW+H818XzXwzkAZOBHKAlsB04Kt6+HdApPu6RwCfA0CL7c+DoYvZ/M1Ggs4uGJd7m4ng/VYGXgHFJf1+Ulw9dCmWuHwFrvfRLlf7Aje7+tbvnEZ2JDCiyfme8fqe7P0/0X+tm+zjPLqCVmWW7+5fuvqCYbX4BfOruD7p7vrtPARYCZxTZ5n53X+zu24CpQG4px9xJdD9pJ/AIUBe4zd03xcdfABwL4O5z3X12fNzlwN3A8Sk8p+vdfUc8z79x93uBT4F3iGJ69R72JzGFJXOtA+ru4dr/UGBFkccr4mXf72O3MG0Fqu/tIO6+hejy4VLgSzN7zsyapzBP4UwNizxesxfzrHP3gvjzwr/4XxVZv63w682sqZlNN7M1ZvYt0X2puqXsGyDP3bfvYZt7gVbA39x9xx62lZjCkrlmEZ3q9yplm9VEN2ELNYqX7YstRKf8heoXXenuL7n7yUT/5V5I9BduT/MUzrRqH2faGxOI5mri7jWAqwDbw9eU+pKomVUnum91HzDCz
OqEGPRAoLBkKHffSHR/4U4z62VmVc3sIDPrYWZ/jTebAlxjZvXMrG68/UP7eMh5QFcza2RmNYG/FK4ws0PM7EwzqwbsILqkKihmH88DTc2sn5llmVlv4Bhg+j7OtDdyiO4DbY7Ppn672/qvgKP2cp+3AXPd/dfAc8Dff/CUBwiFJYO5+38RvYflGqIblyuBIcBT8SajgDnAfOAj4P142b4cawbwaLyvufx7DCoQvbq0muiVkuOB3xWzj3XA6fG264he0Tnd3dfuy0x76QqgH9GrTfcSPZeiRgATzWyDmZ2/p52ZWU+iG+iXxosuB9qaWf9gE+/H9AY5EQlOZywiEpzCIiLBKSwiEpzCIiLBZdQPXllWtlulnKTHkDRo06JR0iNIGqxYsZy1a9f+x/uFMisslXKo3GyPrwRKOfTWO3ckPYKkwc86ti92uS6FRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hSVNBvftxpzHrmLutKsZ0q8bAMc2bchrE//E7EeG8+bDV9K+5REA1MrJ5tHxF/Puo3/hjQev4JjGDRKcXFK1cuVKTj2pO7k/aUHb1i254/bbAHh82mO0bd2SqpUqMHfOnISnTEbawmJmVczsXTP70MwWmNkN6TpWpjmmcQMGnt2ZLgNuoUPvMfTo2orGjeoxemgvRt/zAp36jGXkhOmMHtoLgCsHncqHi76gQ+8xDLr2QcYNOzfhZyCpyMrKYuxfxzPvo0947c3Z3P33O/nk449p2bIVj0x9guO6dE16xMSk84xlB3CCu7cGcoHTzKxTGo+XMZr/uD7vfrScbdt3UlCwizfmLqFn99a4Q41qVQCoWT2bL/M2RtsfVZ+Z7y4CYPHyrzji0DocXCcnsfklNQ0aNKBN27YA5OTk0Lx5C1avXkXzFi1o2qxZwtMlK21h8cjm+OFB8Yen63iZZMFnqzmu7dHUqVmN7CoHcdpxLTmsfm2GjZvGTUN78ekLIxlz2Vlc97enAfho8Sp6npgLQPuWR9CoQR0aHlIryacge2nF8uXMm/cBP+3QMelRMkJa77GYWUUzmwd8Dcxw93fSebxMsWjZV4x/YAbTJwzhmTsHM3/xKvLzC7jkvC5cOf4JmvS4livHPc6E6/sDMO7+GdTKqcrsR4bz2z7H8+GiL8gv2JXws5BUbd68mb7nn8Mt42+lRo0aSY+TEdIaFncvcPdc4DCgg5m12n0bM7vEzOaY2RzP35bOccrUxKdm0bnfzZw86FbWb9zCks/z6H96R556ZR4Aj8/44Pubt5u2bOc3Ix6iU5+xDLp2EnVrV2f5qnVJji8p2rlzJ33PP4feffvT66yzkx4nY5TJq0LuvgGYCZxWzLp73L29u7e3rOyyGKdM1KtdHYDD69em5wmtmfriHL7M20iXdk0A6NahKUs+zwOi+y0HZVUEYOBZnXnz/SVs2rI9mcElZe7OpRcPolnzFvzxssuTHiejZKVrx2ZWD9jp7hvMLBs4Cbg5XcfLNFPG/Zo6taqxM7+AoWOnsmHTNgaPnMwtw84lK6sCO3bkM2TUFCC6efs/IwdQULCLhUvXcOkNDyc8vaTi7bfeYvLDD9Kq1U/o2C66R3bDqJvYsWMHlw/9PWvz8ji75y84tnUuzz7/UsLTli1zT8/9VDM7FpgIVCQ6M5rq7jeW9jUVqh7slZudn5Z5JFnr37sj6REkDX7WsT1z586x3Zen7YzF3
ecDbdK1fxHJXHrnrYgEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHBZJa0ws2cBL2m9u5+ZlolEpNwrMSzAuDKbQkT2KyWGxd1fK8tBRGT/UdoZCwBm1gQYAxwDVClc7u5HpXEuESnHUrl5ez8wAcgHugOTgAfTOZSIlG+phCXb3V8BzN1XuPsI4IT0jiUi5dkeL4WA7WZWAfjUzIYAq4CD0zuWiJRnqZyxDAWqAn8A2gEDgAvTOZSIlG97PGNx9/fiTzcDA9M7jojsD1J5VehVinmjnLvrPouIFCuVeyxXFPm8CnAO0StEIiLFSuVSaO5ui94ys7S8eS63RSNef/v2dOxaElb7p0OSHkHSYMeiz4tdnsqlUJ0iDysQ3cCtH2YsEdkfpXIpNJfoHosRXQItAwalcygRKd9SCUsLd99edIGZVU7TPCKyH0jlfSxvF7NsVuhBRGT/UdrvY6kPNASyzawN0aUQQA2iN8yJiBSrtEuhU4GLgMOA8fwrLN8CV6V3LBEpz0r7fSwTgYlmdo67P16GM4lIOZfKPZZ2Zlar8IGZ1TazUWmcSUTKuVTC0sPdNxQ+cPf1wM/TN5KIlHephKVi0ZeXzSwb0MvNIlKiVN7H8hDwipndHz8eCExM30giUt6l8rNCfzWz+cBJRK8MvQgcke7BRKT8SvV/WLYG2EX0k80nAp+kbSIRKfdKe4NcU6AP0BdYBzxK9Htvu5fRbCJSTpV2KbQQeAM4w92XAJjZZWUylYiUa6VdCp1DdAn0qpnda2Yn8q9334qIlKjEsLj7k+7eG2gOzAQuAw4xswlmdkoZzSci5dAeb966+xZ3f9jdTyf6uaF5wPC0TyYi5VaqrwoB4O7fuPvd+kXaIlKavQqLiEgqFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCS4r6QEOFAUFBXTt3IEGhx7KtCefZfmyZQy8oB/rv/mG1m3acO8/JlGpUqWkx5QUDO7bjYFnd8bMuP+Jt7hj8kyObdqQv13dh8qVDyK/YBdDb3qUOQtWUKN6Ff4x6kIOb1CbrIoVuXXSKzz4zOykn0Lape2Mxcyamdm8Ih/fmtnQdB0v0911x+00a9b8+8fXXTOcwb//I/MWLKJWrdpMeuC+BKeTVB3TuAEDz+5MlwG30KH3GHp0bUXjRvUYPbQXo+95gU59xjJywnRGD+0FwG/O78rCpWvo2Hssp158G2MvP4uDsiom/CzSL21hcfdF7p7r7rlAO2Ar8GS6jpfJVn3xBS+98DwXDhwEgLvz2sxX6XX2uQD0++UFTH/m6SRHlBQ1/3F93v1oOdu276SgYBdvzF1Cz+6tcYca1aoAULN6Nl/mbQTAgerVKgNQLbsy6zduJb9gV1Ljl5myuhQ6EfjM3VeU0fEyyp+HXcbIm8ayedMmANatW0etmrXIyor+9TdseBirV69OckRJ0YLPVjNiyBnUqVmNbTu+47TjWvL+x58zbNw0nr1zMGMuO4sKFYzuF40H4O+PvMa0W3/D0pdHk1OtCgP+/A/cPeFnkX5ldfO2DzCluBVmdomZzTGzOWvz8sponLLzwvPTqVfvYNq0bff9suK+scysLMeSfbRo2VeMf2AG0
ycM4Zk7BzN/8Sry8wu45LwuXDn+CZr0uJYrxz3OhOv7A3By5xbMX/QFR51yNR37jOG/h59HTnxmsz9Le1jMrBJwJvBYcevd/R53b+/u7evWq5fuccrc7Lff5vnnnqVl06O46IJ+vD7zVYZfcRkbNm4gPz8fgFWrvqBBgwYJTyqpmvjULDr3u5mTB93K+o1bWPJ5Hv1P78hTr8wD4PEZH9C+5READDizE0//34cALF25luWr1tHsyEMSm72slMUZSw/gfXf/qgyOlXFuGHUTiz77nAWLl/LApMl07dad+yY+RNfju/HUE9MAmPzQJH5xRs+EJ5VU1atdHYDD69em5wmtmfriHL7M20iXdk0A6NahKUs+j86+V65ZT7cOzQA4uE4OTY88hGWr1iYzeBkqi3ssfSnhMuhAduOosQy8oB8jR1zHsbm5XHDRr5IeSVI0ZdyvqVOrGjvzCxg6diobNm1j8MjJ3DLsXLKyKrBjRz5DRkXf8mPvfZF7bvgl7029CjO4+ranWbdhS8LPIP0snTeSzKwqsBI4yt037mn7tu3a++tvv5u2eSQ59Tr9IekRJA12LJrKrq1f/8cNwrSesbj7VuBH6TyGiGQevaVfRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIIzd096hu+ZWR6wIuk5ykhdYG3SQ0hwB9qf6xHuXm/3hRkVlgOJmc1x9/ZJzyFh6c81okshEQlOYRGR4BSW5NyT9ACSFvpzRfdYRCQNdMYiIsEpLCISnMIiIsFlJT3AgcLMOgDu7u+Z2THAacBCd38+4dFEgtPN2zJgZtcDPYhCPgPoCMwETgJecvfRyU0n+8rM/gA86e4rk54l0ygsZcDMPgJygcrAGuAwd//WzLKBd9z92EQHlH1iZhuBLcBnwBTgMXfPS3aqzKB7LGUj390L3H0r8Jm7fwvg7tuAXcmOJj/AUuAwYCTQDvjYzF40swvNLCfZ0ZKlsJSN78ysavx5u8KFZlYThaU8c3ff5e4vu/sg4FDgLqL7Z0uTHS1ZuhQqA2ZW2d13FLO8LtDA3T9KYCz5gczsA3dvU8K67PiM9ICksIjsIzNr6u6Lk54jEyksIhKc7rGISHAKi4gEp7AIAGZWYGbzzOyfZvZYkVex9mVf3cxsevz5mWY2vJRta5nZ7/bhGCPM7Ip9nVHSS2GRQtvcPdfdWwHfAZcWXWmRvf5+cfdn3H1sKZvUAvY6LJLZFBYpzhvA0WZ2pJl9YmZ3Ae8Dh5vZKWY2y8zej89sqgOY2WlmttDM3gTOLtyRmV1kZnfEnx9iZk+a2YfxR2dgLNA4Plu6Jd5umJm9Z2bzzeyGIvu62swWmdn/As3K7N+G7DWFRf6NmWUR/VxT4XtrmgGT4vdrbAGuAU5y97bAHOByM6sC3AucAXQB6pew+9uB19y9NdAWWAAMJ3o3cq67DzOzU4AmQAeiH4NoZ2Zdzawd0AdoQxSunwZ+6hKQfrpZCmWb2bz48zeA+4jeSbrC3WfHyzsBxwBvmRlAJWAW0BxY5u6fApjZQ8AlxRzjBOACAHcvADaaWe3dtjkl/vggflydKDQ5RD/wtzU+xjM/6NlKWiksUmibu+cWXRDHY0vRRcAMd++723a5QKg3RBkwxt3v3u0YQwMeQ9JMl0KyN2YDPzOzowHMrKqZNQUWAj82s8bxdn1L+PpXgN/GX1vRzGoAm4jORgq9BPyqyL2bhmZ2M
PA6cJaZZcc/4HdG4OcmASkskrL4VwJcBEwxs/lEoWnu7tuJLn2ei2/elvR/s/wj0D3+NRJzgZbuvo7o0uqfZnaLu78MTAZmxdtNA3Lc/X3gUWAe8DjR5ZpkKL2lX0SC0xmLiASnsIhIcAqLiASnsIhIcAqLiASnsIhIcAqLiAT3/9uZgtN2PkZ9AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "interp.plot_confusion_matrix()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

confusion_matrix

\n", "\n", "> confusion_matrix(**`slice_size`**:`int`=***`1`***)\n", "\n", "
\n", "\n", "Confusion matrix as an `np.ndarray`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.confusion_matrix)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[989, 21],\n", " [ 40, 988]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "interp.confusion_matrix()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

most_confused

\n", "\n", "> most_confused(**`min_val`**:`int`=***`1`***, **`slice_size`**:`int`=***`1`***) → `Collection`\\[`Tuple`\\[`str`, `str`, `int`\\]\\]\n", "\n", "
\n", "\n", "Sorted descending list of largest non-diagonal entries of confusion matrix, presented as actual, predicted, number of occurrences. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.most_confused)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class MultiLabelClassificationInterpretation

\n", "\n", "> MultiLabelClassificationInterpretation(**`learn`**:[`Learner`](/basic_train.html#Learner), **`preds`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***, **`sigmoid`**:`bool`=***`True`***, **`thresh`**:`float`=***`0.3`***) :: [`Interpretation`](/train.html#Interpretation)\n", "\n", "
\n", "\n", "Interpretation methods for classification models. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(MultiLabelClassificationInterpretation)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
Warning: MultiLabelClassificationInterpretation is not implemented yet. Feel free to implement it :)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "jekyll_warn(\"MultiLabelClassificationInterpretation is not implemented yet. Feel free to implement it :)\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Working with large datasets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When working with large datasets, memory problems can arise when computing the confusion matrix. For example, an error can look like this:\n", "\n", "    RuntimeError: $ Torch: not enough memory: you tried to allocate 64GB. Buy new RAM!\n", "\n", "In this case, you can make [`ClassificationInterpretation`](/train.html#ClassificationInterpretation) compute the confusion matrix on slices of the data and then aggregate the results by setting the `slice_size` parameter. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[989, 21],\n", " [ 40, 988]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "interp.confusion_matrix(slice_size=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAARYAAAEmCAYAAACnN7/iAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAVn0lEQVR4nO3deXgV9b3H8fcXIhAgbAUFUbQim1AJS4GHCoI7rQqubEWlVGsLbdGKpa4oIFjhXrUqVa9VUEERd1y5XnEFFRSxKCCyiCAaEJBdEr73j5nYlCbhgL+TOYHP63nymDMzmfkeCW9n5pxEc3dEREKqkPQAIrL/UVhEJDiFRUSCU1hEJDiFRUSCU1hEJDiF5QBlZtlm9qyZbTSzx37Afvqb2cshZ0uKmXUxs0VJz7E/ML2PJbOZWT/gcqA5sAmYB4x29zd/4H4HAL8HOrt7/g8eNMOZmQNN3H1J0rMcCHTGksHM7HLgVuAm4BCgEXAX0DPA7o8AFh8IUUmFmWUlPcN+xd31kYEfQE1gM3BeKdtUJgrP6vjjVqByvK4b8AXwJ+Br4EtgYLzuBuA7YGd8jEHACOChIvs+EnAgK358EbCU6KxpGdC/yPI3i3xdZ+A9YGP8z85F1s0ERgJvxft5GahbwnMrnP/KIvP3An4OLAa+Aa4qsn0HYBawId72DqBSvO71+LlsiZ9v7yL7/zOwBniwcFn8NY3jY7SNHx8KrAW6Jf29UR4+Eh9AHyX8wcBpQH7hX+wStrkRmA0cDNQD3gZGxuu6xV9/I3BQ/BdyK1A7Xr97SEoMC1AN+BZoFq9rALSMP/8+LEAdYD0wIP66vvHjH8XrZwKfAU2B7Pjx2BKeW+H818XzXwzkAZOBHKAlsB04Kt6+HdApPu6RwCfA0CL7c+DoYvZ/M1Ggs4uGJd7m4ng/VYGXgHFJf1+Ulw9dCmWuHwFrvfRLlf7Aje7+tbvnEZ2JDCiyfme8fqe7P0/0X+tm+zjPLqCVmWW7+5fuvqCYbX4BfOruD7p7vrtPARYCZxTZ5n53X+zu24CpQG4px9xJdD9pJ/AIUBe4zd03xcdfABwL4O5z3X12fNzlwN3A8Sk8p+vdfUc8z79x93uBT4F3iGJ69R72JzGFJXOtA+ru4dr/UGBFkccr4mXf72O3MG0Fqu/tIO6+hejy4VLgSzN7zsyapzBP4UwNizxesxfzrHP3gvjzwr/4XxVZv63w682sqZlNN7M1ZvYt0X2puqXsGyDP3bfvYZt7gVbA39x9xx62lZjCkrlmEZ3q9yplm9VEN2ELNYqX7YstRKf8heoXXenuL7n7yUT/5V5I9BduT/MUzrRqH2faGxOI5mri7jWAqwDbw9eU+pKomVUnum91HzDCzOqEGPRAoLBkKHffSHR/4U4z62VmVc3sIDPrYWZ/jTebAlxjZvXMrG68/UP7eMh5QFcza2RmNYG/FK4ws0PM7EwzqwbsILqkKihmH88DTc2sn5llmVlv4Bhg+j7OtDdyiO4DbY7Ppn672/qvgKP2cp+3AXPd/dfAc8Dff/CUBwiFJYO5+38RvYflGqIblyuBIcBT8SajgDnAfOAj4P142b4cawbwaLyvufx7DCoQvbq0muiVkuOB3xWzj3XA6fG264he0Tnd3dfuy0x76QqgH9GrTfcSPZeiRgATzWyDmZ2/p52ZWU+iG+iXxosuB9qaWf9gE+/H9AY5EQlOZywiEpzCIiLBKSwiEpzCIiLBZdQPXllWtlulnKTHkDRo06JR0iNIGqxYsZy1a9f+x/uFMisslXKo3GyPrwRKOfTWO3ckPYKkwc86ti92uS6FRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ
4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hUVEglNYRCQ4hSVNBvftxpzHrmLutKsZ0q8bAMc2bchrE//E7EeG8+bDV9K+5REA1MrJ5tHxF/Puo3/hjQev4JjGDRKcXFK1cuVKTj2pO7k/aUHb1i254/bbAHh82mO0bd2SqpUqMHfOnISnTEbawmJmVczsXTP70MwWmNkN6TpWpjmmcQMGnt2ZLgNuoUPvMfTo2orGjeoxemgvRt/zAp36jGXkhOmMHtoLgCsHncqHi76gQ+8xDLr2QcYNOzfhZyCpyMrKYuxfxzPvo0947c3Z3P33O/nk449p2bIVj0x9guO6dE16xMSk84xlB3CCu7cGcoHTzKxTGo+XMZr/uD7vfrScbdt3UlCwizfmLqFn99a4Q41qVQCoWT2bL/M2RtsfVZ+Z7y4CYPHyrzji0DocXCcnsfklNQ0aNKBN27YA5OTk0Lx5C1avXkXzFi1o2qxZwtMlK21h8cjm+OFB8Yen63iZZMFnqzmu7dHUqVmN7CoHcdpxLTmsfm2GjZvGTUN78ekLIxlz2Vlc97enAfho8Sp6npgLQPuWR9CoQR0aHlIryacge2nF8uXMm/cBP+3QMelRMkJa77GYWUUzmwd8Dcxw93fSebxMsWjZV4x/YAbTJwzhmTsHM3/xKvLzC7jkvC5cOf4JmvS4livHPc6E6/sDMO7+GdTKqcrsR4bz2z7H8+GiL8gv2JXws5BUbd68mb7nn8Mt42+lRo0aSY+TEdIaFncvcPdc4DCgg5m12n0bM7vEzOaY2RzP35bOccrUxKdm0bnfzZw86FbWb9zCks/z6H96R556ZR4Aj8/44Pubt5u2bOc3Ix6iU5+xDLp2EnVrV2f5qnVJji8p2rlzJ33PP4feffvT66yzkx4nY5TJq0LuvgGYCZxWzLp73L29u7e3rOyyGKdM1KtdHYDD69em5wmtmfriHL7M20iXdk0A6NahKUs+zwOi+y0HZVUEYOBZnXnz/SVs2rI9mcElZe7OpRcPolnzFvzxssuTHiejZKVrx2ZWD9jp7hvMLBs4Cbg5XcfLNFPG/Zo6taqxM7+AoWOnsmHTNgaPnMwtw84lK6sCO3bkM2TUFCC6efs/IwdQULCLhUvXcOkNDyc8vaTi7bfeYvLDD9Kq1U/o2C66R3bDqJvYsWMHlw/9PWvz8ji75y84tnUuzz7/UsLTli1zT8/9VDM7FpgIVCQ6M5rq7jeW9jUVqh7slZudn5Z5JFnr37sj6REkDX7WsT1z586x3Zen7YzF3ecDbdK1fxHJXHrnrYgEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHAKi4gEp7CISHBZJa0ws2cBL2m9u5+ZlolEpNwrMSzAuDKbQkT2KyWGxd1fK8tBRGT/UdoZCwBm1gQYAxwDVClc7u5HpXEuESnHUrl5ez8wAcgHugOTgAfTOZSIlG+phCXb3V8BzN1XuPsI4IT0jiUi5dkeL4WA7WZWAfjUzIYAq4CD0zuWiJRnqZyxDAWqAn8A2gEDgAvTOZSIlG97PGNx9/fiTzcDA9M7jojsD1J5VehVinmjnLvrPouIFCuVeyxXFPm8CnAO0StEIiLFSuVSaO5ui94ys7S8eS63RSNef/v2dOxaElb7p0OSHkHSYMeiz4tdnsqlUJ0iDysQ3cCtH2YsEdk
fpXIpNJfoHosRXQItAwalcygRKd9SCUsLd99edIGZVU7TPCKyH0jlfSxvF7NsVuhBRGT/UdrvY6kPNASyzawN0aUQQA2iN8yJiBSrtEuhU4GLgMOA8fwrLN8CV6V3LBEpz0r7fSwTgYlmdo67P16GM4lIOZfKPZZ2Zlar8IGZ1TazUWmcSUTKuVTC0sPdNxQ+cPf1wM/TN5KIlHephKVi0ZeXzSwb0MvNIlKiVN7H8hDwipndHz8eCExM30giUt6l8rNCfzWz+cBJRK8MvQgcke7BRKT8SvV/WLYG2EX0k80nAp+kbSIRKfdKe4NcU6AP0BdYBzxK9Htvu5fRbCJSTpV2KbQQeAM4w92XAJjZZWUylYiUa6VdCp1DdAn0qpnda2Yn8q9334qIlKjEsLj7k+7eG2gOzAQuAw4xswlmdkoZzSci5dAeb966+xZ3f9jdTyf6uaF5wPC0TyYi5VaqrwoB4O7fuPvd+kXaIlKavQqLiEgqFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCU5hEZHgFBYRCS4r6QEOFAUFBXTt3IEGhx7KtCefZfmyZQy8oB/rv/mG1m3acO8/JlGpUqWkx5QUDO7bjYFnd8bMuP+Jt7hj8kyObdqQv13dh8qVDyK/YBdDb3qUOQtWUKN6Ff4x6kIOb1CbrIoVuXXSKzz4zOykn0Lape2Mxcyamdm8Ih/fmtnQdB0v0911x+00a9b8+8fXXTOcwb//I/MWLKJWrdpMeuC+BKeTVB3TuAEDz+5MlwG30KH3GHp0bUXjRvUYPbQXo+95gU59xjJywnRGD+0FwG/O78rCpWvo2Hssp158G2MvP4uDsiom/CzSL21hcfdF7p7r7rlAO2Ar8GS6jpfJVn3xBS+98DwXDhwEgLvz2sxX6XX2uQD0++UFTH/m6SRHlBQ1/3F93v1oOdu276SgYBdvzF1Cz+6tcYca1aoAULN6Nl/mbQTAgerVKgNQLbsy6zduJb9gV1Ljl5myuhQ6EfjM3VeU0fEyyp+HXcbIm8ayedMmANatW0etmrXIyor+9TdseBirV69OckRJ0YLPVjNiyBnUqVmNbTu+47TjWvL+x58zbNw0nr1zMGMuO4sKFYzuF40H4O+PvMa0W3/D0pdHk1OtCgP+/A/cPeFnkX5ldfO2DzCluBVmdomZzTGzOWvz8sponLLzwvPTqVfvYNq0bff9suK+scysLMeSfbRo2VeMf2AG0ycM4Zk7BzN/8Sry8wu45LwuXDn+CZr0uJYrxz3OhOv7A3By5xbMX/QFR51yNR37jOG/h59HTnxmsz9Le1jMrBJwJvBYcevd/R53b+/u7evWq5fuccrc7Lff5vnnnqVl06O46IJ+vD7zVYZfcRkbNm4gPz8fgFWrvqBBgwYJTyqpmvjULDr3u5mTB93K+o1bWPJ5Hv1P78hTr8wD4PEZH9C+5READDizE0//34cALF25luWr1tHsyEMSm72slMUZSw/gfXf/qgyOlXFuGHUTiz77nAWLl/LApMl07dad+yY+RNfju/HUE9MAmPzQJH5xRs+EJ5VU1atdHYDD69em5wmtmfriHL7M20iXdk0A6NahKUs+j86+V65ZT7cOzQA4uE4OTY88hGWr1iYzeBkqi3ssfSnhMuhAduOosQy8oB8jR1zHsbm5XHDRr5IeSVI0ZdyvqVOrGjvzCxg6diobNm1j8MjJ3DLsXLKyKrBjRz5DRkXf8mPvfZF7bvgl7029CjO4+ranWbdhS8LPIP0snTeSzKwqsBI4yt037mn7tu3a++tvv5u2eSQ59Tr9IekRJA12LJrKrq1f/8cNwrSesbj7VuB
H6TyGiGQevaVfRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIJTWEQkOIVFRIIzd096hu+ZWR6wIuk5ykhdYG3SQ0hwB9qf6xHuXm/3hRkVlgOJmc1x9/ZJzyFh6c81okshEQlOYRGR4BSW5NyT9ACSFvpzRfdYRCQNdMYiIsEpLCISnMIiIsFlJT3AgcLMOgDu7u+Z2THAacBCd38+4dFEgtPN2zJgZtcDPYhCPgPoCMwETgJecvfRyU0n+8rM/gA86e4rk54l0ygsZcDMPgJygcrAGuAwd//WzLKBd9z92EQHlH1iZhuBLcBnwBTgMXfPS3aqzKB7LGUj390L3H0r8Jm7fwvg7tuAXcmOJj/AUuAwYCTQDvjYzF40swvNLCfZ0ZKlsJSN78ysavx5u8KFZlYThaU8c3ff5e4vu/sg4FDgLqL7Z0uTHS1ZuhQqA2ZW2d13FLO8LtDA3T9KYCz5gczsA3dvU8K67PiM9ICksIjsIzNr6u6Lk54jEyksIhKc7rGISHAKi4gEp7AIAGZWYGbzzOyfZvZYkVex9mVf3cxsevz5mWY2vJRta5nZ7/bhGCPM7Ip9nVHSS2GRQtvcPdfdWwHfAZcWXWmRvf5+cfdn3H1sKZvUAvY6LJLZFBYpzhvA0WZ2pJl9YmZ3Ae8Dh5vZKWY2y8zej89sqgOY2WlmttDM3gTOLtyRmV1kZnfEnx9iZk+a2YfxR2dgLNA4Plu6Jd5umJm9Z2bzzeyGIvu62swWmdn/As3K7N+G7DWFRf6NmWUR/VxT4XtrmgGT4vdrbAGuAU5y97bAHOByM6sC3AucAXQB6pew+9uB19y9NdAWWAAMJ3o3cq67DzOzU4AmQAeiH4NoZ2Zdzawd0AdoQxSunwZ+6hKQfrpZCmWb2bz48zeA+4jeSbrC3WfHyzsBxwBvmRlAJWAW0BxY5u6fApjZQ8AlxRzjBOACAHcvADaaWe3dtjkl/vggflydKDQ5RD/wtzU+xjM/6NlKWiksUmibu+cWXRDHY0vRRcAMd++723a5QKg3RBkwxt3v3u0YQwMeQ9JMl0KyN2YDPzOzowHMrKqZNQUWAj82s8bxdn1L+PpXgN/GX1vRzGoAm4jORgq9BPyqyL2bhmZ2MPA6cJaZZcc/4HdG4OcmASkskrL4VwJcBEwxs/lEoWnu7tuJLn2ei2/elvR/s/wj0D3+NRJzgZbuvo7o0uqfZnaLu78MTAZmxdtNA3Lc/X3gUWAe8DjR5ZpkKL2lX0SC0xmLiASnsIhIcAqLiASnsIhIcAqLiASnsIhIcAqLiAT3/9uZgtN2PkZ9AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "interp.plot_confusion_matrix(slice_size=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('7', '3', 40), ('3', '7', 21)]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "interp.most_confused(slice_size=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Additional callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll show examples below using our MNIST sample. As usual, the `on_something` methods are called directly by the fastai library; there is no need to call them yourself." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class ShowGraph

\n", "\n", "> ShowGraph(**`learn`**) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "
\n", "\n", "Update a graph of learner stats and metrics after each epoch. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ShowGraph, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=ShowGraph)\n", "learn.fit(3)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![Training graph](imgs/train_graph.gif)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_end

\n", "\n", "> on_epoch_end(**`n_epochs`**:`int`, **`last_metrics`**:`MetricsList`, **\\*\\*`kwargs`**) → `bool`\n", "\n", "
\n", "\n", "If we have `last_metrics` plot them in our pbar graph " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ShowGraph.on_epoch_end)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
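As a rough illustration of the bookkeeping behind the `ShowGraph` plot, here is a plain-Python sketch. It is not the fastai source. Training losses are recorded once per batch and validation losses once per epoch, so each validation point is placed at the cumulative number of batches seen by the end of its epoch. The x-axis is then stretched to the projected end of training. The helper name `graph_series` and its arguments are hypothetical, and the sketch only computes the series and axis bounds without drawing anything.

```python
from itertools import accumulate
from typing import Sequence


def graph_series(train_losses: Sequence[float], val_losses: Sequence[float],
                 batches_per_epoch: Sequence[int], n_epochs: int):
    # x positions for the per-batch training losses: 0, 1, 2, ...
    train_x = list(range(len(train_losses)))
    # One validation loss per finished epoch, plotted at the cumulative
    # batch count reached at the end of that epoch.
    val_x = list(accumulate(batches_per_epoch))
    # Stretch the x-axis to the end of training, assuming each remaining
    # epoch has as many batches as the last finished one.
    remaining = n_epochs - len(batches_per_epoch)
    x_bounds = (0, remaining * batches_per_epoch[-1] + len(train_losses))
    y_bounds = (0, max(max(train_losses), max(val_losses)))
    return (train_x, list(train_losses)), (val_x, list(val_losses)), x_bounds, y_bounds


# Two of three epochs done, two batches per epoch (toy numbers).
train, val, xb, yb = graph_series([0.9, 0.7, 0.5, 0.4], [0.6, 0.45], [2, 2], n_epochs=3)
print(val, xb, yb)  # ([2, 4], [0.6, 0.45]) (0, 6) (0, 0.9)
```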

class GradientClipping

\n", "\n", "> GradientClipping(**`learn`**:[`Learner`](/basic_train.html#Learner), **`clip`**:`float`=***`0.0`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "
\n", "\n", "Gradient clipping during training. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GradientClipping)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.162001 | 0.100777 | 0.971050 | 00:07 |
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = cnn_learner(data, models.resnet18, metrics=accuracy,\n", " callback_fns=partial(GradientClipping, clip=0.1))\n", "learn.fit(1)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_backward_end

\n", "\n", "> on_backward_end(**\\*\\*`kwargs`**)\n", "\n", "
\n", "\n", "Clip the gradient before the optimizer step. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GradientClipping.on_backward_end)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
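To show where the `GradientClipping` hook sits in a training step, here is a minimal plain-PyTorch sketch of one step with gradient clipping. It assumes norm-based clipping with `torch.nn.utils.clip_grad_norm_`, with `clip` playing the role of the callback's `clip` argument. The helper `clipped_step`, the toy linear model, and the random data are illustrative placeholders, not fastai code.

```python
import torch
from torch import nn


def clipped_step(model, opt, loss_func, xb, yb, clip=0.1):
    # Forward and backward pass as usual.
    loss = loss_func(model(xb), yb)
    loss.backward()
    # The on_backward_end point: rescale all gradients so that their
    # combined L2 norm is at most `clip` (skipped when clip is 0).
    if clip:
        nn.utils.clip_grad_norm_(model.parameters(), clip)
    # Global gradient norm actually used by the optimizer step.
    grad_norm = torch.stack([p.grad.norm() for p in model.parameters()]).norm().item()
    opt.step()
    opt.zero_grad()
    return loss.item(), grad_norm


torch.manual_seed(0)
model = nn.Linear(10, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
xb, yb = torch.randn(32, 10), torch.randint(0, 2, (32,))
loss, grad_norm = clipped_step(model, opt, nn.CrossEntropyLoss(), xb, yb, clip=0.1)
print(f'loss={loss:.3f}, clipped grad norm={grad_norm:.3f}')
```

The clipping call has to come after `backward()`, which creates the gradients, and before `step()`, which consumes them.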

class BnFreeze

\n", "\n", "> BnFreeze(**`learn`**) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "
\n", "\n", "Freeze moving average statistics in all non-trainable batchnorm layers. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(BnFreeze)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For batchnorm layers where `requires_grad==False`, you generally don't want to update their moving average statistics, in order to avoid the model's statistics getting out of sync with its pre-trained weights. You can add this callback to automate this freezing of statistics (internally, it calls `eval` on these layers)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.163550 | 0.094137 | 0.971541 | 00:06 |
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=BnFreeze)\n", "learn.fit(1)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_begin

\n", "\n", "> on_epoch_begin(**\\*\\*`kwargs`**:`Any`)\n", "\n", "
\n", "\n", "Put bn layers in eval mode just after `model.train()`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(BnFreeze.on_epoch_begin)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
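A minimal plain-PyTorch sketch of the `BnFreeze` idea follows. The helper name `freeze_bn_stats` is hypothetical and is not fastai's function. It puts every batchnorm layer whose parameters are frozen into eval mode. It has to run after `model.train()`, because `train()` switches every submodule back to training mode.

```python
import torch
from torch import nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_bn_stats(model: nn.Module) -> None:
    # Eval mode stops a batchnorm layer from updating running_mean/running_var.
    for m in model.modules():
        if isinstance(m, BN_TYPES):
            params = list(m.parameters())
            # Only freeze layers whose weights are not being trained.
            if params and not params[0].requires_grad:
                m.eval()


model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
for p in model[1].parameters():
    p.requires_grad_(False)      # simulate a frozen, pre-trained batchnorm layer

model.train()                    # start of an epoch: every layer in training mode
freeze_bn_stats(model)           # then put the frozen batchnorm layer back in eval mode
before = model[1].running_mean.clone()
model(torch.randn(4, 3, 16, 16))
print(model[1].training, torch.equal(before, model[1].running_mean))  # False True
```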

class AccumulateScheduler

\n", "\n", "> AccumulateScheduler(**`learn`**:[`Learner`](/basic_train.html#Learner), **`n_step`**:`int`=***`1`***, **`drop_last`**:`bool`=***`False`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "
\n", "\n", "Does an accumulated step every nth step by accumulating gradients " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(AccumulateScheduler)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's force `batch_size=2` to mimic a scenario where a larger batch doesn't fit in memory. We can then set `n_step` as desired to get an effective batch size of `effective_batch_size=batch_size*n_step`.\n", "\n", "It is also important to use a loss function with `reduction='sum'`, so that the accumulated gradients can be turned into an exact average over all accumulated samples.\n", "\n", "Another important note is that `batchnorm` is not yet adapted to accumulated gradients, so use this callback at your own risk until a hero fixes it :)\n", "\n", "Here we demonstrate this callback with a VGG-11 body, which has no `batchnorm` layers; alternatively, you can use `nn.InstanceNorm` or [`nn.GroupNorm`](https://pytorch.org/docs/stable/nn.html#torch.nn.GroupNorm) instead of batchnorm.\n", "\n", "```python\n", "from torchvision.models import vgg11\n", "\n", "data = ImageDataBunch.from_folder(path, bs=2)\n", "\n", "learn = cnn_learner(data, vgg11, metrics=accuracy, loss_func=CrossEntropyFlat(reduction='sum'),\n", " callback_fns=partial(AccumulateScheduler, n_step=16))\n", "learn.fit(1)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## New Methods - Please document or move to the undocumented section" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

_cl_int_plot_top_losses

\n", "\n", "> _cl_int_plot_top_losses(**`k`**, **`largest`**=***`True`***, **`figsize`**=***`(12, 12)`***, **`heatmap`**:`bool`=***`False`***, **`heatmap_thresh`**:`int`=***`16`***, **`alpha`**:`float`=***`0.6`***, **`cmap`**:`str`=***`'magma'`***, **`show_text`**:`bool`=***`True`***, **`return_fig`**:`bool`=***`None`***) → `Optional`\\[`Figure`\\]\n", "\n", "
\n", "\n", "Show images in `top_losses` along with their prediction, actual, loss, and probability of actual class. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.plot_top_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": 25, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

_cl_int_from_learner

\n", "\n", "> _cl_int_from_learner(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***, **`activ`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)=***`None`***, **`tta`**=***`False`***)\n", "\n", "
\n", "\n", "Create an instance of [`ClassificationInterpretation`](/train.html#ClassificationInterpretation). `tta` indicates if we want to use Test Time Augmentation. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.from_learner)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": 26, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

top_losses

\n", "\n", "> top_losses(**`k`**:`int`=***`None`***, **`largest`**=***`True`***)\n", "\n", "
\n", "\n", "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.top_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": 27, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
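The following standalone plain-PyTorch sketch illustrates the `top_losses` behaviour. It assumes per-item cross-entropy losses, computed with `reduction='none'` so that there is one loss per item. It then uses `Tensor.topk` to pick the `k` largest or smallest losses along with their indexes. The helper name `top_losses_sketch` and the toy logits are hypothetical. Inside fastai, the losses come from the learner's own loss function.

```python
import torch
import torch.nn.functional as F


def top_losses_sketch(logits, targets, k=None, largest=True):
    # One loss per item instead of the usual batch mean.
    losses = F.cross_entropy(logits, targets, reduction='none')
    # Default to all items, sorted from largest (or smallest) loss.
    k = len(losses) if k is None else k
    return losses.topk(k, largest=largest)


logits = torch.tensor([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
targets = torch.tensor([0, 0, 1])
vals, idxs = top_losses_sketch(logits, targets)
print(idxs.tolist())  # [1, 2, 0]: item 1 is confidently wrong, item 0 is right
```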

confusion_matrix

\n", "\n", "> confusion_matrix(**`slice_size`**:`int`=***`1`***)\n", "\n", "
\n", "\n", "Confusion matrix as an `np.ndarray`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.confusion_matrix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": 28, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
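Here is a small NumPy sketch of the matrix that `confusion_matrix` returns. It assumes integer class indexes for the actual and predicted labels. Rows are actual classes and columns are predicted classes, which matches the actual-then-predicted order that `most_confused` reports. The helper name is hypothetical, and the sketch has no equivalent of `slice_size`.

```python
import numpy as np


def confusion_matrix_sketch(y_true, y_pred, n_classes):
    # cm[i, j] counts items whose actual class is i and predicted class is j.
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    # np.add.at accumulates correctly when the same (i, j) pair repeats.
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


cm = confusion_matrix_sketch([0, 0, 1, 1, 1], [0, 1, 1, 1, 0], n_classes=2)
print(cm.tolist())  # [[1, 1], [1, 2]]
```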

most_confused

\n", "\n", "> most_confused(**`min_val`**:`int`=***`1`***, **`slice_size`**:`int`=***`1`***) → `Collection`\\[`Tuple`\\[`str`, `str`, `int`\\]\\]\n", "\n", "
\n", "\n", "Sorted descending list of largest non-diagonal entries of confusion matrix, presented as actual, predicted, number of occurrences. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.most_confused)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": 29, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
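The `most_confused` ordering can be sketched in a few lines of Python. The input is a confusion matrix with actual classes as rows and predicted classes as columns, plus the list of class names. The helper name and the toy matrix are hypothetical. Only the output shape follows the description: `(actual, predicted, count)` tuples in descending order of count.

```python
import numpy as np


def most_confused_sketch(cm, classes, min_val=1):
    # Every off-diagonal cell (a wrong prediction) with at least min_val items,
    # as (actual, predicted, count), largest counts first.
    n = len(classes)
    pairs = [(classes[i], classes[j], int(cm[i, j]))
             for i in range(n) for j in range(n)
             if i != j and cm[i, j] >= min_val]
    return sorted(pairs, key=lambda t: t[2], reverse=True)


cm = np.array([[5, 2], [4, 9]])   # toy counts, rows = actual, columns = predicted
print(most_confused_sketch(cm, ['3', '7']))  # [('7', '3', 4), ('3', '7', 2)]
```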

plot_confusion_matrix

\n", "\n", "> plot_confusion_matrix(**`normalize`**:`bool`=***`False`***, **`title`**:`str`=***`'Confusion matrix'`***, **`cmap`**:`Any`=***`'Blues'`***, **`slice_size`**:`int`=***`1`***, **`norm_dec`**:`int`=***`2`***, **`plot_txt`**:`bool`=***`True`***, **`return_fig`**:`bool`=***`None`***, **\\*\\*`kwargs`**) → `Optional`\\[`Figure`\\]\n", "\n", "
\n", "\n", "Plot the confusion matrix, with `title` and using `cmap`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.plot_confusion_matrix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": 30, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
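The following NumPy sketch illustrates the `normalize` option of `plot_confusion_matrix`. It assumes that normalizing means dividing each row (each actual class) by its total, so that every row sums to 1. It also assumes that `norm_dec` is the number of decimals shown in each cell. The helper name is hypothetical. A class with no items would give a row of NaNs here.

```python
import numpy as np


def normalize_confusion(cm, norm_dec=2):
    # Fraction of each actual class (row) assigned to each predicted class.
    cm = np.asarray(cm, dtype=float)
    cm = cm / cm.sum(axis=1, keepdims=True)
    # Format each cell with norm_dec decimals, as it might appear on the plot.
    labels = [[f'{v:.{norm_dec}f}' for v in row] for row in cm]
    return cm, labels


cm, labels = normalize_confusion([[1, 1], [1, 3]])
print(labels)  # [['0.50', '0.50'], ['0.25', '0.75']]
```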

_cl_int_plot_multi_top_losses

\n", "\n", "> _cl_int_plot_multi_top_losses(**`samples`**:`int`=***`3`***, **`figsize`**:`Tuple`\\[`int`, `int`\\]=***`(8, 8)`***, **`save_misclassified`**:`bool`=***`False`***)\n", "\n", "
\n", "\n", "Show images in `top_losses` along with their prediction, actual, loss, and probability of predicted class in a multilabeled dataset. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.plot_multi_top_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Open This Notebook\n", "\n", "" ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Extensions to Learner that easily implement Callback", "title": "train" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 2 }