{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Additional training functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`train`](/train.html#train) provides a number of extension methods that are added to [`Learner`](/basic_train.html#Learner) (see below for a list and details), along with four simple callbacks:\n", "\n", "- [`ShowGraph`](/train.html#ShowGraph)\n", "- [`GradientClipping`](/train.html#GradientClipping)\n", "- [`BnFreeze`](/train.html#BnFreeze)\n", "- [`AccumulateScheduler`](/train.html#AccumulateScheduler)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.train import *\n", "from fastai.vision import *\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [`Learner`](/basic_train.html#Learner) extension methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These methods are automatically added to all [`Learner`](/basic_train.html#Learner) objects created after importing this module. They provide convenient access to a number of callbacks, without requiring them to be manually created." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

fit_one_cycle[source][test]

\n", "\n", "> fit_one_cycle(**`learn`**:[`Learner`](/basic_train.html#Learner), **`cyc_len`**:`int`, **`max_lr`**:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=***`slice(None, 0.003, None)`***, **`moms`**:`Point`=***`(0.95, 0.85)`***, **`div_factor`**:`float`=***`25.0`***, **`pct_start`**:`float`=***`0.3`***, **`final_div`**:`float`=***`None`***, **`wd`**:`float`=***`None`***, **`callbacks`**:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=***`None`***, **`tot_epochs`**:`int`=***`None`***, **`start_epoch`**:`int`=***`None`***)\n", "\n", "
×

Tests found for fit_one_cycle:

Some other tests where fit_one_cycle is used:

To run tests please refer to this guide.

\n", "\n", "Fit a model following the 1cycle policy. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(fit_one_cycle)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
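The shape of the learning-rate schedule can be sketched in plain Python. This is a simplified illustration only: `one_cycle_lr` is a hypothetical helper that interpolates linearly between the phases, whereas fastai's [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler) anneals with cosine interpolation and also cycles the momentum between the two `moms` values.

```python
def one_cycle_lr(it, tot_its, max_lr=3e-3, div_factor=25.0, pct_start=0.3, final_div=None):
    # Warm up from max_lr/div_factor to max_lr over the first pct_start of
    # the iterations, then anneal down to max_lr/final_div.
    if final_div is None:
        final_div = div_factor * 1e4
    up_its = int(tot_its * pct_start)
    if it < up_its:
        frac = it / max(up_its, 1)                     # warm-up phase
        return max_lr / div_factor + frac * (max_lr - max_lr / div_factor)
    frac = (it - up_its) / max(tot_its - up_its, 1)    # annealing phase
    return max_lr + frac * (max_lr / final_div - max_lr)

lrs = [one_cycle_lr(i, 100) for i in range(100)]
```

The learning rate starts low so early updates do not destabilize the freshly initialized head, peaks at `max_lr`, then decays well below the starting value to let the model settle into a minimum.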

one_cycle_scheduler[source][test]

\n", "\n", "> one_cycle_scheduler(**`lr_max`**:`float`, **\\*\\*`kwargs`**:`Any`) → [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler)\n", "\n", "
×

No tests found for one_cycle_scheduler. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Instantiate a [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler) with `lr_max`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(one_cycle_scheduler)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler) for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

lr_find[source][test]

\n", "\n", "> lr_find(**`learn`**:[`Learner`](/basic_train.html#Learner), **`start_lr`**:`Floats`=***`1e-07`***, **`end_lr`**:`Floats`=***`10`***, **`num_it`**:`int`=***`100`***, **`stop_div`**:`bool`=***`True`***, **`wd`**:`float`=***`None`***)\n", "\n", "
×

Tests found for lr_find:

  • pytest -sv tests/test_train.py::test_lr_find [source]
  • pytest -sv tests/test_vision_train.py::test_lrfind [source]

To run tests please refer to this guide.

\n", "\n", "Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss diverges. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(lr_find)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [`LRFinder`](/callbacks.lr_finder.html#LRFinder) for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
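The sweep that `lr_find` performs can be sketched as an exponential interpolation from `start_lr` to `end_lr`. This only illustrates the spacing of the tested rates; `lr_range` is a hypothetical helper, not part of fastai.

```python
def lr_range(start_lr=1e-7, end_lr=10.0, num_it=100):
    # Exponentially spaced learning rates: each step multiplies the LR by a
    # constant ratio, so the sweep covers eight orders of magnitude here.
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_it - 1)) for i in range(num_it)]

lrs = lr_range()
```

At each of the `num_it` iterations one mini-batch is trained at the next rate in this list while the loss is recorded; with `stop_div` the sweep aborts early once the loss blows up.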

to_fp16[source][test]

\n", "\n", "> to_fp16(**`learn`**:[`Learner`](/basic_train.html#Learner), **`loss_scale`**:`float`=***`None`***, **`max_noskip`**:`int`=***`1000`***, **`dynamic`**:`bool`=***`True`***, **`clip`**:`float`=***`None`***, **`flat_master`**:`bool`=***`False`***, **`max_scale`**:`float`=***`16777216`***) → [`Learner`](/basic_train.html#Learner)\n", "\n", "
×

No tests found for to_fp16. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Put `learn` in FP16 precision mode. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(to_fp16)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [`MixedPrecision`](/callbacks.fp16.html#MixedPrecision) for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
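The role of `max_noskip` and `max_scale` comes from dynamic loss scaling, which keeps small FP16 gradients from underflowing. The idea can be sketched in plain Python; `update_loss_scale` is a hypothetical helper showing the policy, not fastai's actual implementation.

```python
def update_loss_scale(scale, overflow, steps_since_skip, max_noskip=1000, max_scale=2 ** 24):
    # Dynamic loss scaling: halve the scale (and skip the optimizer step)
    # when gradients overflow; double it after max_noskip consecutive
    # overflow-free steps, capped at max_scale. Returns (scale, counter).
    if overflow:
        return scale / 2, 0
    steps_since_skip += 1
    if steps_since_skip >= max_noskip and scale * 2 <= max_scale:
        return scale * 2, 0
    return scale, steps_since_skip
```

The loss is multiplied by `scale` before the backward pass and the gradients are divided by it again before the optimizer step, so the update itself is unchanged; only the FP16 intermediate values are pushed into a representable range.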

to_fp32[source][test]

\n", "\n", "> to_fp32(**`learn`**:[`Learner`](/basic_train.html#Learner))\n", "\n", "
×

No tests found for to_fp32. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Put `learn` back to FP32 precision mode. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(to_fp32)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

mixup[source][test]

\n", "\n", "> mixup(**`learn`**:[`Learner`](/basic_train.html#Learner), **`alpha`**:`float`=***`0.4`***, **`stack_x`**:`bool`=***`False`***, **`stack_y`**:`bool`=***`True`***) → [`Learner`](/basic_train.html#Learner)\n", "\n", "
×

No tests found for mixup. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Add mixup https://arxiv.org/abs/1710.09412 to `learn`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(mixup)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [`MixUpCallback`](/callbacks.mixup.html#MixUpCallback) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
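The idea behind mixup can be sketched for a single pair of examples in plain Python; `mixup_pair` is a hypothetical helper, and fastai's [`MixUpCallback`](/callbacks.mixup.html#MixUpCallback) applies the same combination batch-wise on tensors.

```python
import random

def mixup_pair(x1, y1, x2, y2, alpha=0.4):
    # Draw lam ~ Beta(alpha, alpha) and return the convex combination of the
    # two inputs and of their (one-hot) targets. Keeping the larger of lam
    # and 1-lam avoids mixing an example mostly into its partner.
    lam = random.betavariate(alpha, alpha)
    lam = max(lam, 1 - lam)
    x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1 - lam) * b for a, b in zip(y1, y2)]
    return x, y, lam

mixed_x, mixed_y, lam = mixup_pair([0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0])
```

Training on these interpolated inputs and soft targets regularizes the model, since it never sees the raw training points with hard labels.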

class Interpretation[source][test]

\n", "\n", "> Interpretation(**`learn`**:[`Learner`](/basic_train.html#Learner), **`preds`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***)\n", "\n", "
×

No tests found for Interpretation. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Interpretation base class; can be inherited for task-specific Interpretation classes. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Interpretation)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

from_learner[source][test]

\n", "\n", "> from_learner(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***)\n", "\n", "
×

Tests found for from_learner:

Some other tests where from_learner is used:

  • pytest -sv tests/test_tabular_train.py::test_confusion_tabular [source]
  • pytest -sv tests/test_vision_train.py::test_ClassificationInterpretation [source]
  • pytest -sv tests/test_vision_train.py::test_interp [source]

To run tests please refer to this guide.

\n", "\n", "Gets `preds`, `y_true` and `losses` from a learner to construct the base class. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Interpretation.from_learner)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

top_losses[source][test]

\n", "\n", "> top_losses(**`k`**:`int`=***`None`***, **`largest`**=***`True`***)\n", "\n", "
×

Tests found for top_losses:

Some other tests where top_losses is used:

  • pytest -sv tests/test_vision_train.py::test_interp [source]
  • pytest -sv tests/test_vision_train.py::test_interp_shortcut [source]

To run tests please refer to this guide.

\n", "\n", "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Interpretation.top_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For example, [`ClassificationInterpretation`](/train.html#ClassificationInterpretation) is implemented using an argmax on `preds` to set `self.pred_class`, whereas an optional sigmoid is used for `MultiLabelClassificationInterpretation`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
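The behaviour can be sketched in plain Python; `top_losses` here is a hypothetical stand-in for the `torch.topk` call fastai uses, returning the values together with their indices into the dataset.

```python
def top_losses(losses, k=None, largest=True):
    # The k largest (or smallest, when largest=False) losses together with
    # their indices; k=None keeps the full sorted list.
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=largest)
    if k is not None:
        order = order[:k]
    return [losses[i] for i in order], order

vals, idxs = top_losses([0.1, 2.3, 0.7, 1.5], k=2)
```

The indices are what make this useful for interpretation: they point back at the dataset items the model found hardest.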

class ClassificationInterpretation[source][test]

\n", "\n", "> ClassificationInterpretation(**`learn`**:[`Learner`](/basic_train.html#Learner), **`preds`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***) :: [`Interpretation`](/train.html#Interpretation)\n", "\n", "
×

Tests found for ClassificationInterpretation:

  • pytest -sv tests/test_vision_train.py::test_ClassificationInterpretation [source]

Some other tests where ClassificationInterpretation is used:

  • pytest -sv tests/test_tabular_train.py::test_confusion_tabular [source]
  • pytest -sv tests/test_vision_train.py::test_interp [source]

To run tests please refer to this guide.

\n", "\n", "Interpretation methods for classification models. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "learn = cnn_learner(data, models.resnet18)\n", "learn.fit(1)\n", "preds,y,losses = learn.get_preds(with_loss=True)\n", "interp = ClassificationInterpretation(learn, preds, y, losses)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

top_losses[source][test]

\n", "\n", "> top_losses(**`k`**:`int`=***`None`***, **`largest`**=***`True`***)\n", "\n", "
×

Tests found for top_losses:

Some other tests where top_losses is used:

  • pytest -sv tests/test_vision_train.py::test_interp [source]
  • pytest -sv tests/test_vision_train.py::test_interp_shortcut [source]

To run tests please refer to this guide.

\n", "\n", "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.top_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Returns a tuple of *(losses, indices)*." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/plain": [ "torch.return_types.topk(\n", "values=tensor([10.5023, 7.5183, 7.4790, 5.0515, 4.5135, 4.3364, 4.1073, 3.8588,\n", " 3.7758]),\n", "indices=tensor([ 979, 977, 226, 1019, 161, 1393, 1500, 1697, 1276]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "interp.top_losses(9)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

plot_confusion_matrix[source][test]

\n", "\n", "> plot_confusion_matrix(**`normalize`**:`bool`=***`False`***, **`title`**:`str`=***`'Confusion matrix'`***, **`cmap`**:`Any`=***`'Blues'`***, **`slice_size`**:`int`=***`1`***, **`norm_dec`**:`int`=***`2`***, **`plot_txt`**:`bool`=***`True`***, **`return_fig`**:`bool`=***`None`***, **\\*\\*`kwargs`**) → `Optional`\\[`Figure`\\]\n", "\n", "
×

No tests found for plot_confusion_matrix. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Plot the confusion matrix, with `title` and using `cmap`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.plot_confusion_matrix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If [`normalize`](/vision.data.html#normalize), plots the percentages with `norm_dec` digits. `slice_size` can be used to avoid out of memory error if your set is too big. `kwargs` are passed to `plt.figure`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAARoAAAEmCAYAAAC9C19sAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFdpJREFUeJzt3Xl0VIXZx/HvA5EYQDZBRHEpO4gaCGiLFVFxQcEFlEVKFfcqdatarb5udVd6bNXS6mktbriLqK2iIi4IKCCiVhBRUEEooiA7SXjeP+aGRkvIEHhyY/h9zskxc++de58R8nXunZlo7o6ISKQaaQ8gItWfQiMi4RQaEQmn0IhIOIVGRMIpNCISTqHZRplZnpk9a2bLzOzxLdjPYDMbuzVnS4uZHWhms9KeozoyvY+majOzk4CLgHbAcmA6cIO7v7mF+x0C/Bro5u5FWzxoFWdmDrR290/SnmVbpGc0VZiZXQTcAdwINAV2B/4MHLsVdr8H8PG2EJlsmFlO2jNUa+6uryr4BdQHVgAnbmKbXDIhWpB83QHkJut6AF8CvwH+A3wFDE3WXQusAwqTY5wGXAM8WGrfewIO5CS3TwE+JfOs6jNgcKnlb5a6XzfgHWBZ8s9updaNB34PTEj2MxZoXMZjK5n/0lLzHwccBXwMfAP8rtT2+wETgaXJtncBtZJ1ryePZWXyeAeU2v9vgYXAAyXLkvu0TI7RObm9C/A10CPtvxs/xq/UB9BXGX8wcCRQVPKDXsY21wGTgJ2AJsBbwO+TdT2S+18HbJf8gK4CGibrfxiWMkMD1AG+A9om65oBeyXfbwgN0Aj4FhiS3G9QcnvHZP14YA7QBshLbt9cxmMrmf+qZP4zgMXAw8AOwF7AGqBFsn0B8NPkuHsCHwEXlNqfA602sv9byAQ7r3Rokm3OSPZTG3gRuD3tvxc/1i+dOlVdOwJf+6ZPbQYD17n7f9x9MZlnKkNKrS9M1he6+z/J/Ne8bQXnWQ90NLM8d//K3T/cyDZHA7Pd/QF3L3L3UcBMoE+pbe5z94/dfTXwGJC/iWMWkrkeVQg8AjQG/ujuy5PjfwjsA+DuU919UnLcucBfgYOyeExXu/vaZJ7vcfd7gdnAZDJxvaKc/UkZFJqqawnQuJxrB7sA80rdnpcs27CPH4RqFVB3cwdx95VkTjfOBr4ys+fNrF0W85TMtGup2ws3Y54l7l6cfF8SgkWl1q8uub+ZtTGz58xsoZl9R+a6VuNN7BtgsbuvKWebe4GOwJ3uvracbaUMCk3VNZHMqcFxm9hmAZmLuiV2T5ZVxEoypwgldi690t1fdPfDyPyXfSaZH8Dy5imZaX4FZ9ocI8jM1drd6wG/A6yc+2zyJVczq0vmu
tffgGvMrNHWGHRbpNBUUe6+jMz1ibvN7Dgzq21m25lZLzO7NdlsFHClmTUxs8bJ9g9W8JDTge5mtruZ1QcuL1lhZk3N7BgzqwOsJXMKVryRffwTaGNmJ5lZjpkNADoAz1Vwps2xA5nrSCuSZ1u/+sH6RUCLzdznH4Gp7n468Dzwly2echul0FRh7v4HMu+huZLMhdAvgGHA6GST64EpwAzgfWBasqwix3oJeDTZ11S+H4caZF69WkDmlZiDgHM2so8lQO9k2yVkXjHq7e5fV2SmzXQxcBKZV7PuJfNYSrsGGGlmS82sf3k7M7NjyVyQPztZdBHQ2cwGb7WJtyF6w56IhNMzGhEJp9CISDiFRkTCKTQiEq5KfZDMcvLccuulPYYE6NRut7RHkADz5s3l66+/Lu/9SlUsNLn1yG0/KO0xJMCESXekPYIEOGD/Llltp1MnEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXA5aQ9QXZ07sDtDj/8ZBtw3ehJ3jXqNfdrsyp2Xn0hure0oKi7mglueYMqHn9P7oI5cdfZRrF/vFBUXc+nwp3nrvc/SfghSji+++ILTh/6SRYsWUqNGDU497UyGnXc+vzhpALNnzQJg6bKlNKjfgMlTp6c8bbrCQmNm2wOvA7nJcZ5w96ujjleVdGi5M0OP/xkH/vIPrCsqZsyfzuJfb37IDef14YZ7X2TsWx9xxAHtueG8YzjirLt49e2Pee61DwDo2KoZD958Cvkn3JTyo5Dy5OTkcPOtw+nUuTPLly+n2/4FHNrzMB58+NEN2/z2kt9Qv379FKesGiKf0awFDnH3FWa2HfCmmf3L3ScFHrNKaLdnU95+fy6r1xYC8Ma0ORx78D64Q7062wNQv24eXy1eBsDK1es23LdOXi7ulT+zbL5mzZrRrFkzAHbYYQfatWvPggXzad+hAwDuzpNPPMYLY8elOWaVEBYad3dgRXJzu+Rrm/gR+nDOQq4552ga1a/N6jWFHHlAB6Z99DmXDH+aZ+86m5vOP4YaNYyDT/3jhvsc02NvrhvWmyYN69L3gntTnF4qYt7cuUyf/i5d99t/w7IJb75B052a0qp16xQnqxpCr9GYWU1gKtAKuNvdJ0cer6qYNXcRw+9/hefu/hUrV61jxuz5FBWv58wTDuDSPzzN6HEz6NcznxH/N5Cjzx0BwJjx7zNm/Psc0KkFV53da8NyqfpWrFjBoP79uG34HdSrV2/D8sceGcWJAwelOFnVEfqqk7sXu3s+0BzYz8w6/nAbMzvTzKaY2RQvWh05TqUa+cxkuv1iOIedeSffLlvFJ58vZnDvroweNwOAJ1+eTpe99vif+01491NaNG/MjvXrVPbIUgGFhYUM6t+PAYMGc9zxfTcsLyoq4pnRT3HCiQNSnK7qqJSXt919KTAeOHIj6+5x9y7u3sVy8ipjnErRpGFdAHZr2oBjD9mHx16cxleLv+PAglYA9Ojamk++WAxAi+aNN9wvv21zam1XkyXLVlb+0LJZ3J2zzziNtu3ac/6FF31v3bhXXqZN23Y0b948pemqlshXnZoAhe6+1MzygJ7ALVHHq2pG3TqURvXrUFiUeRl76fLVnHv9I9x2cV9yatZg7boiht2QeXXi+EP35aSjulBYtJ41awsZcvnIlKeXbLw1YQIPP/QAHTvuzf4F+QBce/2NHNnrKB5/9BH6D9BpUwnzoJc4zGwfYCRQk8wzp
8fc/bpN3adGnaae215/ONXRt5PuSHsECXDA/l2YOnWKlbdd5KtOM4BOUfsXkR8PfQRBRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEi6nrBVm9izgZa1392NCJhKRaqfM0AC3V9oUIlKtlRkad3+tMgcRkeprU89oADCz1sBNQAdg+5Ll7t4icC4RqUayuRh8HzACKAIOBu4HHogcSkSql2xCk+furwDm7vPc/RrgkNixRKQ6KffUCVhjZjWA2WY2DJgP7BQ7lohUJ9k8o7kAqA2cBxQAQ4CTI4cSkeql3Gc07v5O8u0KYGjsOCJSHWXzqtOrbOSNe+6u6zQikpVsrtFcXOr77YF+ZF6BEhHJSjanTlN/sGiCmYW8ma9Tu92YMOmOiF1Lyhp2HZb2CBJg7azPs9oum1OnRqVu1iBzQXjnio0lItuibE6dppK5RmNkTpk+A06LHEpEqpdsQtPe3deUXmBmuUHziEg1lM37aN7ayLKJW3sQEam+NvX7aHYGdgXyzKwTmVMngHpk3sAnIpKVTZ06HQGcAjQHhvPf0HwH/C52LBGpTjb1+2hGAiPNrJ+7P1mJM4lINZPNNZoCM2tQcsPMGprZ9YEziUg1k01oern70pIb7v4tcFTcSCJS3WQTmpqlX842szxAL2+LSNayeR/Ng8ArZnZfcnsoMDJuJBGpbrL5rNOtZjYD6EnmlacXgD2iBxOR6iPb/4HcQmA9mU9uHwp8FDaRiFQ7m3rDXhtgIDAIWAI8Sub3Bh9cSbOJSDWxqVOnmcAbQB93/wTAzC6slKlEpFrZ1KlTPzKnTK+a2b1mdij/fXewiEjWygyNuz/t7gOAdsB44EKgqZmNMLPDK2k+EakGyr0Y7O4r3f0hd+9N5nNP04HLwicTkWoj21edAHD3b9z9r/rF5CKyOTYrNCIiFaHQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEi4n7QGquzVr1tDz4O6sW7uWouIiju97Av939bXM/ewzhgweyLfffkN+p878/R8PUKtWrbTHlXKcO6gHQ/t2w8y476kJ3PXwePZusyt3XjGQOnm5zFuwhKFXjGT5yjXk5NRgxFWDyW+3Gzk1a/DQ829z+9/Hpv0QUhH2jMbM2prZ9FJf35nZBVHHq6pyc3N54aVxvD3tPSZPmc7YF19g8qRJXPG73/Lr8y/kg49m07BBQ/7x97+lPaqUo0PLZgzt240Dh9zGfgNuolf3jrTcvQkjrjqJK//0DF3738iYV9/jwpMPBaBfz87k1sqha/8b6Tb4Fk7vdwC7N2uU8qNIR1ho3H2Wu+e7ez5QAKwCno46XlVlZtStWxeAwsJCigoLMTNee3UcffudAMDgISfz7JjRaY4pWWj3k515+/25rF5TSHHxet6Y+gnHHrwvrffYiTenfgLAuEkzOe7QfAAcp/b2tahZswZ5ubVYV1jM8pVr0nwIqamsazSHAnPcfV4lHa9KKS4uZv+CfHbfZScO6XkYLVq2pH6DBuTkZM5cd23enAUL5qc8pZTnwzkL+HnnVjSqX4e87bfjyJ/vRfOdG/LvOV/Ru8feAPQ9rDPNmzYE4KmX32XVmnV89tINfPyv67jj/lf49rtVaT6E1FRWa
AYCoza2wszONLMpZjZl8deLK2mcylWzZk0mT53OJ3O/ZMo7bzNz5kf/s41hKUwmm2PWZ4sY/o+XeG7EMMbcfS4zPp5PUVExZ13zEGf1786Ehy6lbu1c1hUWA9B1rz0pLl5Pi8OvoP3RV3P+kEPYc9cdU34U6Qi/GGxmtYBjgMs3tt7d7wHuASgo6OLR86SpQYMGdD+oB29PnsSypUspKioiJyeH+V9+SbNddkl7PMnCyNETGTl6IgDXDuvD/EVL+XjuIvqcczcArXbfiV4H7gVA/15dGPvWvykqWs/ib1cwcfqnFHTYnbnzl6Q2f1oq4xlNL2Cauy+qhGNVOYsXL2bp0qUArF69mnGvvEy7du3p3uNgnnryCQAeemAkvfscm+aYkqUmDTPX23bbuSHHHrIvj70wZcMyM+OyM47g3ifeBODLhd/Qo2tbAGpvX4v99tmTWXO3yR+DSnl5exBlnDZtCxZ+9RVnnHoyxcXFrPf19DuhP0cd3Zv27TswZPBArr36SvbN78Qpp56W9qiShVG3n06jBnUoLCrmgpsfY+ny1Zw7qAdnDegOwDPjpnP/M5MA+Mujr3PPtb9g6hNXYAYPPDOJD2YvSHP81Jh73NmKmdUGvgBauPuy8rYvKOjiEyZPCZtH0tOw67C0R5AAa2c9xvpV/yn3AmPoMxp3XwVsm1e/RGQDfQRBRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJy5e9ozbGBmi4F5ac9RSRoDX6c9hGx129qf6x7u3qS8japUaLYlZjbF3bukPYdsXfpz3TidOolIOIVGRMIpNOm5J+0BJIT+XDdC12hEJJye0YhIOIVGRMIpNCISLiftAbYVZrYf4O7+jpl1AI4EZrr7P1MeTSScLgZXAjO7GuhFJuwvAfsD44GewIvufkN600lFmdl5wNPu/kXas1R1Ck0lMLP3gXwgF1gINHf378wsD5js7vukOqBUiJktA1YCc4BRwOPuvjjdqaomXaOpHEXuXuzuq4A57v4dgLuvBtanO5psgU+B5sDvgQLg32b2gpmdbGY7pDta1aLQVI51ZlY7+b6gZKGZ1Ueh+TFzd1/v7mPd/TRgF+DPZK6/fZruaFWLTp0qgZnluvvajSxvDDRz9/dTGEu2kJm96+6dyliXlzxjFRQakQozszbu/nHac/wYKDQiEk7XaEQknEIjIuEUGgHAzIrNbLqZfWBmj5d6lawi++phZs8l3x9jZpdtYtsGZnZOBY5xjZldXNEZpXIpNFJitbvnu3tHYB1wdumVlrHZf1/cfYy737yJTRoAmx0a+XFRaGRj3gBamdmeZvaRmf0ZmAbsZmaHm9lEM5uWPPOpC2BmR5rZTDN7E+hbsiMzO8XM7kq+b2pmT5vZe8lXN+BmoGXybOq2ZLtLzOwdM5thZteW2tcVZjbLzF4G2lbavw3ZYgqNfI+Z5ZD5XFbJe3vaAvcn7xdZCVwJ9HT3zsAU4CIz2x64F+gDHAjsXMbu/wS85u77Ap2BD4HLyLxbOt/dLzGzw4HWwH5kPrZRYGbdzawAGAh0IhOyrlv5oUsgfXpbSuSZ2fTk+zeAv5F5p+s8d5+ULP8p0AGYYGYAtYCJQDvgM3efDWBmDwJnbuQYhwC/BHD3YmCZmTX8wTaHJ1/vJrfrkgnPDmQ+wLgqOcaYLXq0UqkUGimx2t3zSy9IYrKy9CLgJXcf9IPt8oGt9YYsA25y97/+4BgXb
MVjSCXTqZNsjknAAWbWCsDMaptZG2Am8BMza5lsN6iM+78C/Cq5b00zqwcsJ/NspcSLwKmlrv3samY7Aa8Dx5tZXvKBxT5b+bFJIIVGspb8CoRTgFFmNoNMeNq5+xoyp0rPJxeDy/q/jZ4PHJz82oypwF7uvoTMqdgHZnabu48FHgYmJts9Aezg7tOAR4HpwJNkTu/kR0IfQRCRcHpGIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIT7f769u9MNCdWQAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "interp.plot_confusion_matrix()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

confusion_matrix[source][test]

\n", "\n", "> confusion_matrix(**`slice_size`**:`int`=***`1`***)\n", "\n", "
×

Tests found for confusion_matrix:

  • pytest -sv tests/test_tabular_train.py::test_confusion_tabular [source]

Some other tests where confusion_matrix is used:

  • pytest -sv tests/test_vision_train.py::test_ClassificationInterpretation [source]

To run tests please refer to this guide.

\n", "\n", "Confusion matrix as an `np.ndarray`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.confusion_matrix)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[983, 27],\n", " [ 30, 998]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "interp.confusion_matrix()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
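A plain-Python sketch of how such a matrix is built, including the slice-wise aggregation used when `slice_size` is passed. These are hypothetical helpers to show the idea, not fastai's tensor implementation.

```python
def confusion_matrix(y_true, y_pred, n_classes):
    # Count (actual, predicted) pairs into an n_classes x n_classes matrix:
    # rows are actual classes, columns are predicted classes.
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm

def confusion_matrix_sliced(y_true, y_pred, n_classes, slice_size):
    # Memory-saving strategy: build the matrix for one slice of predictions
    # at a time and sum the partial matrices; the result is identical.
    cm = [[0] * n_classes for _ in range(n_classes)]
    for s in range(0, len(y_true), slice_size):
        part = confusion_matrix(y_true[s:s + slice_size], y_pred[s:s + slice_size], n_classes)
        for i in range(n_classes):
            for j in range(n_classes):
                cm[i][j] += part[i][j]
    return cm
```

Because each cell is a simple count, the sum of per-slice matrices equals the matrix computed over the whole set in one go, which is why slicing only changes peak memory, not the result.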

most_confused[source][test]

\n", "\n", "> most_confused(**`min_val`**:`int`=***`1`***, **`slice_size`**:`int`=***`1`***) → `Collection`\\[`Tuple`\\[`str`, `str`, `int`\\]\\]\n", "\n", "
×

Tests found for most_confused:

Some other tests where most_confused is used:

  • pytest -sv tests/test_vision_train.py::test_ClassificationInterpretation [source]

To run tests please refer to this guide.

\n", "\n", "Sorted descending list of largest non-diagonal entries of confusion matrix, presented as actual, predicted, number of occurrences. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.most_confused)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
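A plain-Python sketch of this, using the confusion matrix computed above; `most_confused` here is a hypothetical helper mirroring the documented behaviour.

```python
def most_confused(cm, classes, min_val=1):
    # Off-diagonal confusion-matrix entries with at least min_val
    # occurrences, sorted descending as (actual, predicted, count).
    res = [(classes[i], classes[j], cm[i][j])
           for i in range(len(cm)) for j in range(len(cm))
           if i != j and cm[i][j] >= min_val]
    return sorted(res, key=lambda t: t[2], reverse=True)

pairs = most_confused([[983, 27], [30, 998]], ['3', '7'])
```

Ignoring the diagonal (correct predictions) leaves only the error cells, so the first tuples name the class pairs the model mixes up most often.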

class MultiLabelClassificationInterpretation[source][test]

\n", "\n", "> MultiLabelClassificationInterpretation(**`learn`**:[`Learner`](/basic_train.html#Learner), **`preds`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***, **`sigmoid`**:`bool`=***`True`***, **`thresh`**:`float`=***`0.3`***) :: [`Interpretation`](/train.html#Interpretation)\n", "\n", "
×

No tests found for MultiLabelClassificationInterpretation. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Interpretation methods for classification models. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(MultiLabelClassificationInterpretation)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
Warning: MultiLabelClassificationInterpretation is not implemented yet. Feel free to implement it :)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "jekyll_warn(\"MultiLabelClassificationInterpretation is not implemented yet. Feel free to implement it :)\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Working with large datasets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When working with large datasets, memory problems can arise when computing the confusion matrix. For example, an error can look like this:\n", "\n", "    RuntimeError: $ Torch: not enough memory: you tried to allocate 64GB. Buy new RAM!\n", "\n", "In this case it is possible to force [`ClassificationInterpretation`](/train.html#ClassificationInterpretation) to compute the confusion matrix on slices of the data and then aggregate the results by specifying the `slice_size` parameter. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[984, 26],\n", "       [ 37, 991]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "interp.confusion_matrix(slice_size=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAARoAAAEmCAYAAAC9C19sAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFVtJREFUeJzt3XmUVIWVx/HvZWsa2WVVgkYREBCBRvCgKJtbZFMSER2jxDVK4hI1Rh2jqFGDzrggjjKJwQ0FDYpLVKKigqCAEpSIuDKMso8gO01z5496TVoPDUXD7Vc0v885deyq9+q92wJf3ntV1Zi7IyISqVLaA4hIxafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaPZSZpZvZs+b2SozG78L2znTzF7dnbOlxcy6m9knac9REZneR5PbzOwM4AqgNbAamA3c6u5TdnG7ZwG/Arq5++ZdHjTHmZkDh7j7Z2nPsjfSEU0OM7MrgLuBPwCNgebAKGDAbtj8AcD8vSEy2TCzKmnPUKG5u245eAPqAGuAn21nnTwyIfomud0N5CXLegD/C/wGWAosAoYmy24CNgGFyT7OBW4EHiux7QMBB6ok988BviBzVPUlcGaJx6eUeF43YAawKvlvtxLLJgM3A1OT7bwKNCjleyue/+oS8w8EfgLMB/4PuLbE+l2AacDKZN2RQLVk2VvJ97I2+X4Hl9j+b4HFwKPFjyXPOTjZR6fk/n7AcqBH2r839sRb6gPoVsovDJwIbC7+g17KOsOB6UAjoCHwDnBzsqxH8vzhQNXkD+g6oF6y/IdhKTU0wD7Ad0CrZFlToG3y9dbQAPWBb4GzkucNSe7vmyyfDHwOtATyk/u3l/K9Fc9/QzL/+cAy4AmgFtAW2AAclKxfAByZ7PdA4GPgshLbc6DFNrZ/B5lg55cMTbLO+cl2agCvAHem/ftiT73p1Cl37Qss9+2f2pwJDHf3pe6+jMyRylkllhcmywvd/SUyf5u3KuM8W4B2Zpbv7ovcfe421jkZ+NTdH3X3ze4+FpgH9CuxzsPuPt/d1wPjgA7b2WchmetRhcCTQAPgHndfnex/LtAewN1nufv0ZL9fAQ8Cx2bxPf3e3Tcm83yPu48GPgXeJRPX63awPSmFQpO7VgANdnDtYD9gQYn7C5LHtm7jB6FaB9Tc2UHcfS2Z042LgEVm9qKZtc5inuKZ9i9xf/FOzLPC3YuSr4tDsKTE8vXFzzezlmb2gpktNrPvyFzXarCdbQMsc/cNO1hnNNAOuM/dN+5gXSmFQpO7ppE5NRi4nXW+IXNRt1jz5LGyWEvmFKFYk5IL3f0Vdz+OzN/s88j8AdzRPMUzfV3GmXbGA2TmOsTdawPXAraD52z3JVczq0nmutefgBvNrP7uGHRvpNDkKHdfReb6xP1mNtDMaphZVTM7ycz+mKw2FrjezBqaWYNk/cfKuMvZwDFm1tzM6gC/K15gZo3NrL+Z7QNsJHMKVrSNbbwEtDSzM8ysipkNBtoAL5Rxpp1Ri8x1pDXJ0dYvf7B8CXDQTm7zHmCWu58HvAj81y5PuZdSaHKYu/8HmffQXE/mQuhCYBjwbLLKLcBMYA7wIfB+8lhZ9jUJeCrZ1iy+H4dKZF69+obMKzHHAhdvYxsrgL7JuivIvGLU192Xl2WmnXQlcAaZV7NGk/leSroRGGNmK83stB1tzMwGkLkgf1Hy0BVAJzM7c7dNvBfRG/ZEJJyOaEQknEIjIuEUGhEJp9CISLic+iCZVcl3q1Yr7TEkQMdDm6c9ggRYsOArli9fvqP3K+VYaKrVIq/14LTHkABT370v7REkwFFdO2e1nk6dRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQ
aEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNEEuGXIsM8f9jlnjr2XYGT0AaN9yf94ccwXTx/6WKY9dRee2B3zvOQVtmrNmxj2c0rtDChPLzlq4cCEn9OlJh8MOpdPhbRl57z1bl40aeR/t27ai0+Ftufaaq1OcMjdUidqwmVUH3gLykv087e6/j9pfLmlzcFOGntKN7j+/k02FRUwceTF/e3sut146gFsffJlX3/knJxzVhlsvHcAJF9wLQKVKxi2XDmDStI9Tnl6yVaVKFW7/41107NSJ1atX061rAb37HMfSpUt44fnnmPH+HPLy8li6dGnao6YuLDTARqCXu68xs6rAFDP7m7tPD9xnTmj948a89+FXrN9QCMDbsz5lQK/2OFC7ZnUA6tTMZ9GyVVufc/Hpx/Lsa7MpaHPAtjYpOahp06Y0bdoUgFq1atG69aF8883X/PlPo7ny6mvIy8sDoFGjRmmOmRPCTp08Y01yt2py86j95ZK5ny/i6E4tqF+nBvnVq3Li0W1p1rgeV935DH+4dACfvjSc2y4fyA0jJwKwX8M69O/ZntFPT0l5cimrBV99xezZH3BEl658Nn8+U6e8TfduXTmu17HMnDEj7fFSF3lEg5lVBmYBLYD73f3dyP3lik++XMJdf5nEC6OGsXb9RubM/5rNRVu44KdHc/Vdf+XZ1//BoOM68sANZ3LyL0cy4spBXH/vRLZs2Ss6XOGsWbOGIacNYsRdd1O7dm02F23m22+/5a2p05k5Ywb/dsZpfDz/C8ws7VFTExoady8COphZXWCCmbVz949KrmNmFwAXAFC1ZuQ45WrMc9MZ81zmLPGmYf34eslKhg/rx29GPAPAM5M+YNS/DwGgU5vmPHLbOQDsW7cmJxzdhs1FW3h+8pxUZpfsFRYWMuS0QQweciYDTzkVgP33b8bAU07FzDiiSxcqVarE8uXLadiwYcrTpqdcXnVy95XAZODEbSx7yN07u3tnq5JfHuOUi4b1MtH8UZN6DOh5OONensmi5avoXtACgB5dWvLZwmUAHNrvRlr3zdwm/H02l902TpHZA7g7F51/Lq1aH8qll1+x9fF+/Qcy+Y3XAfh0/nw2bdpEgwYN0hozJ0S+6tQQKHT3lWaWD/QB7ojaX64Ze+d51K9Tg8LNW7jsjnGsXL2eS24ey4irBlGlcmU2bixk2C1Ppj2m7IJ3pk7liccfpV27w+hakHlLwk23/IGzh/6CC8/7BQUd2lGtajX++89j9urTJgBzj7kuYGbtgTFAZTJHTuPcffj2nlOpRiPPaz04ZB5J17fv3Zf2CBLgqK6dmTVr5g4rGnZE4+5zgI5R2xeRPYfeGSwi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJV6W0BWb2POClLXf3/iETiUiFU2pogDvLbQoRqdBKDY27v1meg4hIxbW9IxoAzOwQ4DagDVC9+HF3PyhwLhGpQLK5GPww8ACwGegJPAI8GjmUiFQs2YQm391fA8zdF7j7jUCv2LFEpCLZ4akTsMHMKgGfmtkw4GugUexYIlKRZHNEcxlQA/g1UACcBZwdOZSIVCw7PKJx9xnJl2uAobHjiEhFlM2rTm+wjTfuubuu04hIVrK5RnNlia+rA4PIvAIlIpKVbE6dZv3goalmFvJmvg6HNmf
KtHsjNi0pq3fEsLRHkAAbP/mfrNbL5tSpfom7lchcEG5StrFEZG+UzanTLDLXaIzMKdOXwLmRQ4lIxZJNaA519w0lHzCzvKB5RKQCyuZ9NO9s47Fpu3sQEam4tvfzaJoA+wP5ZtaRzKkTQG0yb+ATEcnK9k6dTgDOAZoBd/Gv0HwHXBs7lohUJNv7eTRjgDFmNsjdnynHmUSkgsnmGk2BmdUtvmNm9czslsCZRKSCySY0J7n7yuI77v4t8JO4kUSkoskmNJVLvpxtZvmAXt4Wkaxl8z6ax4DXzOzh5P5QYEzcSCJS0WTzWac/mtkcoA+ZV55eBg6IHkxEKo5s/wG5xcAWMp/c7g18HDaRiFQ423vDXkvgdGAIsAJ4iszPDe5ZTrOJSAWxvVOnecDbQD93/wzAzC4vl6lEpELZ3qnTIDKnTG+Y2Wgz682/3h0sIpK1UkPj7hPcfTDQGpgMXA40NrMHzOz4cppPRCqAHV4Mdve17v64u/cl87mn2cA14ZOJSIWR7atOALj7/7n7g/rB5CKyM3YqNCIiZaHQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEq5K2gNUdBs2bOD43seyceNGijZvZuCpg7j+hps4rtcxrF69GoBly5bSuXMXnnp6QsrTyo5cMqQHQ0/thpnx8F+nMvKJyRzWcn/uu+509snPY8E3Kxh63RhWr91A/Tr78MSIcyloewCPTZzO5XeMT3v81ISFxsxaAU+VeOgg4AZ3vztqn7koLy+Pl155jZo1a1JYWEifnt05/oSTmPT6W1vXOWPwTzm5X/8Up5RstDm4KUNP7Ub3s0awqbCIifdfzN+mzOWBG87gmv+cwJRZn/HzAUdy+dm9GT7qRTZsLGT4qBdo02I/2h7cNO3xUxV26uTun7h7B3fvABQA64C97q9sM6NmzZoAFBYWUlhYiJltXb569WrenPw6/foPTGtEyVLrHzfhvQ+/Yv2GQoqKtvD2rM8Y0PNwDjmgEVNmfQbA69PnMbB3BwDWbdjEO7O/YMPGwjTHzgnldY2mN/C5uy8op/3llKKiIo48oiMHNmtMr959OKJL163LJj43gR49e1O7du0UJ5RszP38G47u1IL6dfYhv3pVTjy6Lc2a1OOfny+ib4/DADj1uE40a1wv5UlzT3mF5nRg7LYWmNkFZjbTzGYuX76snMYpX5UrV2b6jA+Y/8VCZs2cwdy5H21dNv6pJ/nZ4NNTnE6y9cmXS7jrL5N44YFhTLz/EubM/5rNm4u48MbHufC0Y5j6+NXUrJHHpsKitEfNOeEXg82sGtAf+N22lrv7Q8BDAJ0KOnv0PGmqW7cu3Y85lkmvvEzbtu1YsWIFs2a+x5Pj/5r2aJKlMc9OY8yz0wC4aVg/vl6ykvlfLaHfxfcD0KJ5I07q3jbNEXNSeRzRnAS87+5LymFfOWfZsmWsXLkSgPXr1/PG66/RqlVrACY8M54Tf9KX6tWrpzmi7ISG9TLX237UpB4Deh3OuJdnbn3MzLjm/BMY/fSUNEfMSeXx8vYQSjlt2hssXryIC849h6KiIrZs2cKgn/6Mk07uC8DT45/iiit/m+6AslPG3nke9evuQ+HmIi67fRwrV6/nkiE9uHDwMQA89/psHnlu+tb15714E7X2qU61qlXo17M9fS++n3lfLE5r/NSYe9zZipnVABYCB7n7qh2t36mgs0+ZNiNsHknPvl1/lfYIEmDjJ+PYsm6p7Wi90CMad18H7Bu5DxHJffoIgoiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDi
FRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwik0IhJOoRGRcAqNiIRTaEQknEIjIuEUGhEJp9CISDiFRkTCKTQiEk6hEZFwCo2IhFNoRCScQiMi4RQaEQmn0IhIOIVGRMIpNCISTqERkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4c/e0Z9jKzJYBC9Keo5w0AJanPYTsdnvbr+sB7t5wRyvlVGj2JmY20907pz2H7F76dd02nTqJSDiFRkTCKTTpeSjtASSEfl23QddoRCScjmhEJJxCIyLhFBoRCVcl7QH2FmbWBXB3n2FmbYATgXnu/lLKo4mE08XgcmBmvwdOIhP2SUBXYDLQB3jF3W9NbzopKzP7NTDB3RemPUuuU2jKgZl9CHQA8oDFQDN3/87M8oF33b19qgNKmZjZKmAt8DkwFhjv7svSnSo36RpN+djs7kXuvg743N2/A3D39cCWdEeTXfAF0Ay4GSgA/mlmL5vZ2WZWK93RcotCUz42mVmN5OuC4gfNrA4KzZ7M3X2Lu7/q7ucC+wGjyFx/+yLd0XKLTp3KgZnlufvGbTzeAGjq7h+mMJbsIjP7wN07lrIsPzliFRQakTIzs5buPj/tOfYECo2IhNM1GhEJp9CISDiFRgAwsyIzm21mH5nZ+BKvkpVlWz3M7IXk6/5mds121q1rZheXYR83mtmVZZ1RypdCI8XWu3sHd28HbAIuKrnQMnb694u7T3T327ezSl1gp0MjexaFRrblbaCFmR1oZh+b2SjgfeBHZna8mU0zs/eTI5+aAGZ2opnNM7MpwKnFGzKzc8xsZPJ1YzObYGb/SG7dgNuBg5OjqRHJeleZ2Qwzm2NmN5XY1nVm9omZ/R1oVW7/N2SXKTTyPWZWhcznsorf29MKeCR5v8ha4Hqgj7t3AmYCV5hZdWA00A/oDjQpZfP3Am+6++FAJ2AucA2Zd0t3cPerzOx44BCgC5mPbRSY2TFmVgCcDnQkE7IjdvO3LoH06W0plm9ms5Ov3wb+ROadrgvcfXry+JFAG2CqmQFUA6YBrYEv3f1TADN7DLhgG/voBfwcwN2LgFVmVu8H6xyf3D5I7tckE55aZD7AuC7Zx8Rd+m6lXCk0Umy9u3co+UASk7UlHwImufuQH6zXAdhdb8gy4DZ3f/AH+7hsN+5DyplOnWRnTAeOMrMWAGZWw8xaAvOAH5vZwcl6Q0p5/mvAL5PnVjaz2sBqMkcrxV4BflHi2s/+ZtYIeAs4xczykw8s9tvN35sEUmgka8mPQDgHGGtmc8iEp7W7byBzqvRicjG4tH9t9FKgZ/JjM2YBbd19BZlTsY/MbIS7vwo8AUxL1nsaqOXu7wNPAbOBZ8ic3skeQh9BEJFwOqIRkXAKjYiEU2hEJJxCIyLhFBoRCafQiEg4hUZEwv0/Pd+IZFTouI4AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "interp.plot_confusion_matrix(slice_size=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('7', '3', 37), ('3', '7', 26)]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "interp.most_confused(slice_size=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Additional callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll show examples below using our MNIST sample. As usual the `on_something` methods are directly called by the fastai library, no need to call them yourself." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class ShowGraph[source][test]

\n", "\n", "> ShowGraph(**`learn`**) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "
×

No tests found for ShowGraph. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Update a graph of learner stats and metrics after each epoch. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ShowGraph, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=ShowGraph)\n", "learn.fit(3)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![Training graph](imgs/train_graph.gif)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_end[source][test]

\n", "\n", "> on_epoch_end(**`n_epochs`**:`int`, **`last_metrics`**:`MetricsList`, **\\*\\*`kwargs`**) → `bool`\n", "\n", "
×

No tests found for on_epoch_end. To contribute a test please refer to this guide and this discussion.

\n", "\n", "If we have `last_metrics` plot them in our pbar graph " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ShowGraph.on_epoch_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class GradientClipping[source][test]

\n", "\n", "> GradientClipping(**`learn`**:[`Learner`](/basic_train.html#Learner), **`clip`**:`float`=***`0.0`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "
×

No tests found for GradientClipping. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Gradient clipping during training. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GradientClipping)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.1404710.0795710.97154100:08
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = cnn_learner(data, models.resnet18, metrics=accuracy,\n", " callback_fns=partial(GradientClipping, clip=0.1))\n", "learn.fit(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_backward_end[source][test]

\n", "\n", "> on_backward_end(**\\*\\*`kwargs`**)\n", "\n", "
×

No tests found for on_backward_end. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Clip the gradient before the optimizer step. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GradientClipping.on_backward_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class BnFreeze[source][test]

\n", "\n", "> BnFreeze(**`learn`**) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "
×

No tests found for BnFreeze. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Freeze moving average statistics in all non-trainable batchnorm layers. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(BnFreeze)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For batchnorm layers where `requires_grad==False`, you generally don't want to update their moving average statistics, in order to avoid the model's statistics getting out of sync with its pre-trained weights. You can add this callback to automate this freezing of statistics (internally, it calls `eval` on these layers)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.1474090.0813700.97203100:05
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=BnFreeze)\n", "learn.fit(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_begin[source][test]

\n", "\n", "> on_epoch_begin(**\\*\\*`kwargs`**:`Any`)\n", "\n", "
×

No tests found for on_epoch_begin. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Put bn layers in eval mode just after `model.train()`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(BnFreeze.on_epoch_begin)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class AccumulateScheduler[source][test]

\n", "\n", "> AccumulateScheduler(**`learn`**:[`Learner`](/basic_train.html#Learner), **`n_step`**:`int`=***`1`***, **`drop_last`**:`bool`=***`False`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "
×

No tests found for AccumulateScheduler. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Does accumlated step every nth step by accumulating gradients " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(AccumulateScheduler)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's force `batch_size=2` to mimic a scenario where we can't fit enough batch samples to our memory. We can then set `n_step` as desired to have an effective batch_size of `effective_batch_size=batch_size*n_step`.\n", "\n", "It is also important to use loss func with `reduce='sum'` in order to calculate exact average accumulated gradients.\n", "\n", "Another important note for users is that `batchnorm` is not yet adapted to accumulated gradients. So you should use this callback at your own risk until a hero fixes it :)\n", "\n", "Here we demonstrate this callback with a model without `batchnorm` layers, alternatively you can use `nn.InstanceNorm` or [`nn.GroupNorm`](https://pytorch.org/docs/stable/nn.html#torch.nn.GroupNorm).\n", "\n", "```\n", "from torchvision.models import vgg11\n", "\n", "data = ImageDataBunch.from_folder(path, bs=2)\n", "\n", "learn = cnn_learner(data, resnet18, metrics=accuracy, loss_func=CrossEntropyFlat(reduction='sum'),\n", " callback_fns=partial(AccumulateScheduler, n_step=16))\n", "learn.fit(1)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## New Methods - Please document or move to the undocumented section" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

_cl_int_plot_top_losses[source][test]

\n", "\n", "> _cl_int_plot_top_losses(**`k`**, **`largest`**=***`True`***, **`figsize`**=***`(12, 12)`***, **`heatmap`**:`bool`=***`None`***, **`heatmap_thresh`**:`int`=***`16`***, **`return_fig`**:`bool`=***`None`***) → `Optional`\\[`Figure`\\]\n", "\n", "
×

No tests found for _cl_int_plot_top_losses. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Show images in `top_losses` along with their prediction, actual, loss, and probability of actual class. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.plot_top_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

_cl_int_from_learner[source][test]

\n", "\n", "> _cl_int_from_learner(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***, **`tta`**=***`False`***)\n", "\n", "
×

Tests found for _cl_int_from_learner:

  • pytest -sv tests/test_vision_train.py::test_interp [source]

To run tests please refer to this guide.

\n", "\n", "Create an instance of [`ClassificationInterpretation`](/train.html#ClassificationInterpretation). `tta` indicates if we want to use Test Time Augmentation. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.from_learner)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

top_losses[source][test]

\n", "\n", "> top_losses(**`k`**:`int`=***`None`***, **`largest`**=***`True`***)\n", "\n", "
×

Tests found for top_losses:

Some other tests where top_losses is used:

  • pytest -sv tests/test_vision_train.py::test_interp [source]
  • pytest -sv tests/test_vision_train.py::test_interp_shortcut [source]

To run tests please refer to this guide.

\n", "\n", "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.top_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

confusion_matrix[source][test]

\n", "\n", "> confusion_matrix(**`slice_size`**:`int`=***`1`***)\n", "\n", "
×

Tests found for confusion_matrix:

  • pytest -sv tests/test_tabular_train.py::test_confusion_tabular [source]

Some other tests where confusion_matrix is used:

  • pytest -sv tests/test_vision_train.py::test_ClassificationInterpretation [source]

To run tests please refer to this guide.

\n", "\n", "Confusion matrix as an `np.ndarray`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.confusion_matrix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

most_confused[source][test]

\n", "\n", "> most_confused(**`min_val`**:`int`=***`1`***, **`slice_size`**:`int`=***`1`***) → `Collection`\\[`Tuple`\\[`str`, `str`, `int`\\]\\]\n", "\n", "
×

Tests found for most_confused:

Some other tests where most_confused is used:

  • pytest -sv tests/test_vision_train.py::test_ClassificationInterpretation [source]

To run tests please refer to this guide.

\n", "\n", "Sorted descending list of largest non-diagonal entries of confusion matrix, presented as actual, predicted, number of occurrences. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.most_confused)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

plot_confusion_matrix[source][test]

\n", "\n", "> plot_confusion_matrix(**`normalize`**:`bool`=***`False`***, **`title`**:`str`=***`'Confusion matrix'`***, **`cmap`**:`Any`=***`'Blues'`***, **`slice_size`**:`int`=***`1`***, **`norm_dec`**:`int`=***`2`***, **`plot_txt`**:`bool`=***`True`***, **`return_fig`**:`bool`=***`None`***, **\\*\\*`kwargs`**) → `Optional`\\[`Figure`\\]\n", "\n", "
×

No tests found for plot_confusion_matrix. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Plot the confusion matrix, with `title` and using `cmap`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.plot_confusion_matrix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

_cl_int_plot_multi_top_losses[source][test]

\n", "\n", "> _cl_int_plot_multi_top_losses(**`samples`**:`int`=***`3`***, **`figsize`**:`Tuple`\\[`int`, `int`\\]=***`(8, 8)`***, **`save_misclassified`**:`bool`=***`False`***)\n", "\n", "
×

No tests found for _cl_int_plot_multi_top_losses. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Show images in `top_losses` along with their prediction, actual, loss, and probability of predicted class in a multilabeled dataset. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.plot_multi_top_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Extensions to Learner that easily implement Callback", "title": "train" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }