{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#skip\n", "! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# default_exp interpret" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai.data.all import *\n", "from fastai.optimizer import *\n", "from fastai.learner import *\n", "from fastai.tabular.core import *\n", "import sklearn.metrics as skm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from fastai.test_utils import *\n", "from nbdev.showdoc import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Interpretation of Predictions\n", "\n", "> Classes to build objects to better interpret predictions of a model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from fastai.vision.all import *\n", "mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock), \n", " get_items=get_image_files, \n", " splitter=RandomSubsetSplitter(.1,.1, seed=42),\n", " get_y=parent_label)\n", "test_dls = mnist.dataloaders(untar_data(URLs.MNIST_SAMPLE), bs=8)\n", "test_learner = vision_learner(test_dls, resnet18)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@typedispatch\n", "def plot_top_losses(x, y, *args, **kwargs):\n", " raise Exception(f\"plot_top_losses is not implemented for {type(x)},{type(y)}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_all_ = [\"plot_top_losses\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Interpretation():\n", " \"Interpretation base class, can be inherited for task specific Interpretation classes\"\n", " def __init__(self, learn, dl, losses, act=None): \n", " store_attr()\n", "\n", " def __getitem__(self, idxs):\n", " \"Return inputs, preds, targs, decoded outputs, and losses at `idxs`\"\n", " if isinstance(idxs, Tensor): idxs = idxs.tolist()\n", " if not is_listy(idxs): idxs = [idxs]\n", " items = getattr(self.dl.items, 'iloc', L(self.dl.items))[idxs]\n", " tmp_dl = self.learn.dls.test_dl(items, with_labels=True, process=not isinstance(self.dl, TabDataLoader))\n", " inps,preds,targs,decoded = self.learn.get_preds(dl=tmp_dl, with_input=True, with_loss=False, \n", " with_decoded=True, act=self.act, reorder=False)\n", " return inps, preds, targs, decoded, self.losses[idxs]\n", "\n", " @classmethod\n", " def from_learner(cls, learn, ds_idx=1, dl=None, act=None):\n", " \"Construct interpretation object from a learner\"\n", " if dl is None: dl = learn.dls[ds_idx].new(shuffle=False, drop_last=False)\n", " _,_,losses = learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=False,\n", " with_preds=False, with_targs=False, act=act)\n", " return cls(learn, dl, losses, act)\n", "\n", " def top_losses(self, k=None, largest=True, items=False):\n", " \"`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`). 
Optionally include items.\"\n", " losses, idx = self.losses.topk(ifnone(k, len(self.losses)), largest=largest)\n", " if items: return losses, idx, getattr(self.dl.items, 'iloc', L(self.dl.items))[idx]\n", " else: return losses, idx\n", "\n", " def plot_top_losses(self, k, largest=True, **kwargs):\n", " \"Show `k` largest(/smallest) preds and losses. `k` may be int, list, or `range` of desired results.\"\n", " if is_listy(k) or isinstance(k, range):\n", " losses, idx = (o[k] for o in self.top_losses(None, largest))\n", " else: \n", " losses, idx = self.top_losses(k, largest)\n", " inps, preds, targs, decoded, _ = self[idx]\n", " inps, targs, decoded = tuplify(inps), tuplify(targs), tuplify(decoded)\n", " x, y, its = self.dl._pre_show_batch(inps+targs)\n", " x1, y1, outs = self.dl._pre_show_batch(inps+decoded, max_n=len(idx))\n", " if its is not None:\n", " plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), preds, losses, **kwargs)\n", " #TODO: figure out if this is needed\n", " #its None means that a batch knows how to show itself as a whole, so we pass x, x1\n", " #else: show_results(x, x1, its, ctxs=ctxs, max_n=max_n, **kwargs)\n", "\n", " def show_results(self, idxs, **kwargs):\n", " \"Show predictions and targets of `idxs`\"\n", " if isinstance(idxs, Tensor): idxs = idxs.tolist()\n", " if not is_listy(idxs): idxs = [idxs]\n", " inps, _, targs, decoded, _ = self[idxs]\n", " b = tuplify(inps)+tuplify(targs)\n", " self.dl.show_results(b, tuplify(decoded), max_n=len(idxs), **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
### class Interpretation
\n", "\n", "> Interpretation(**`learn`**, **`dl`**, **`losses`**, **`act`**=*`None`*)\n", "\n", "Interpretation base class, can be inherited for task specific Interpretation classes" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Interpretation, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Interpretation` is a helper base class for exploring predictions from trained models. It can be inherited for task specific interpretation classes, such as `ClassificationInterpretation`. `Interpretation` is memory efficient and should be able to process any sized dataset, provided the hardware could train the same model.\n", "\n", "> Note: `Interpretation` is memory efficient due to generating inputs, predictions, targets, decoded outputs, and losses for each item on the fly, using batch processing where possible." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "### Interpretation.from_learner
\n", "\n", "> Interpretation.from_learner(**`learn`**, **`ds_idx`**=*`1`*, **`dl`**=*`None`*, **`act`**=*`None`*)\n", "\n", "Construct interpretation object from a learner" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Interpretation.from_learner, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
### Interpretation.top_losses
\n", "\n", "> Interpretation.top_losses(**`k`**=*`None`*, **`largest`**=*`True`*, **`items`**=*`False`*)\n", "\n", "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`). Optionally include items." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Interpretation.top_losses, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With the default of `k=None`, `top_losses` will return the entire dataset's losses. `top_losses` can optionally include the input items for each loss, which is usually a file path or Pandas `DataFrame`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
### Interpretation.plot_top_losses
\n", "\n", "> Interpretation.plot_top_losses(**`k`**, **`largest`**=*`True`*, **\\*\\*`kwargs`**)\n", "\n", "Show `k` largest(/smallest) preds and losses. `k` may be int, list, or `range` of desired results." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Interpretation.plot_top_losses, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To plot the first 9 top losses:\n", "```python\n", "interp = Interpretation.from_learner(learn)\n", "interp.plot_top_losses(9)\n", "```\n", "Then to plot the 7th through 16th top losses:\n", "```python\n", "interp.plot_top_losses(range(7,16))\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
### Interpretation.show_results
\n", "\n", "> Interpretation.show_results(**`idxs`**, **\\*\\*`kwargs`**)\n", "\n", "Show predictions and targets of `idxs`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Interpretation.show_results, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Like `Learner.show_results`, except can pass desired index or indicies for item(s) to show results from." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide\n", "interp = Interpretation.from_learner(test_learner)\n", "x, y, out = [], [], []\n", "for batch in test_learner.dls.valid:\n", " x += batch[0]\n", " y += batch[1]\n", " out += test_learner.model(batch[0])\n", "x,y,out = torch.stack(x), torch.stack(y, dim=0), torch.stack(out, dim=0)\n", "inps, preds, targs, decoded, losses = interp[:]\n", "test_eq(inps, to_cpu(x))\n", "test_eq(targs, to_cpu(y))\n", "loss = torch.stack([test_learner.loss_func(p,t) for p,t in zip(out,y)], dim=0)\n", "test_close(losses, to_cpu(loss))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide\n", "# verify stored losses equal calculated losses for idx\n", "top_losses, idx = interp.top_losses(9)\n", "\n", "dl = test_learner.dls[1].new(shuffle=False, drop_last=False)\n", "items = getattr(dl.items, 'iloc', L(dl.items))[idx]\n", "tmp_dl = test_learner.dls.test_dl(items, with_labels=True, process=not isinstance(dl, TabDataLoader))\n", "_, _, _, _, losses = test_learner.get_preds(dl=tmp_dl, with_input=True, with_loss=True, \n", " with_decoded=True, act=None, reorder=False)\n", "\n", "test_close(top_losses, losses, 1e-2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide\n", "#dummy test to ensure we can run on the training set\n", "interp = Interpretation.from_learner(test_learner, ds_idx=0)\n", "x, y, out = [], [], []\n", "for batch in test_learner.dls.train.new(drop_last=False, shuffle=False):\n", " x += batch[0]\n", " y += batch[1]\n", " out += test_learner.model(batch[0])\n", "x,y,out = torch.stack(x), torch.stack(y, dim=0), torch.stack(out, dim=0)\n", "inps, preds, targs, decoded, losses = interp[:]\n", "test_eq(inps, to_cpu(x))\n", "test_eq(targs, to_cpu(y))\n", "loss = torch.stack([test_learner.loss_func(p,t) for p,t in zip(out,y)], dim=0)\n", "test_close(losses, to_cpu(loss))" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class ClassificationInterpretation(Interpretation):\n", " \"Interpretation methods for classification models.\"\n", "\n", " def __init__(self, learn, dl, losses, act=None):\n", " super().__init__(learn, dl, losses, act)\n", " self.vocab = self.dl.vocab\n", " if is_listy(self.vocab): self.vocab = self.vocab[-1]\n", "\n", " def confusion_matrix(self):\n", " \"Confusion matrix as an `np.ndarray`.\"\n", " x = torch.arange(0, len(self.vocab))\n", " _,targs,decoded = self.learn.get_preds(dl=self.dl, with_decoded=True, with_preds=True, \n", " with_targs=True, act=self.act)\n", " d,t = flatten_check(decoded, targs)\n", " cm = ((d==x[:,None]) & (t==x[:,None,None])).long().sum(2)\n", " return to_np(cm)\n", "\n", " def plot_confusion_matrix(self, normalize=False, title='Confusion matrix', cmap=\"Blues\", norm_dec=2,\n", " plot_txt=True, **kwargs):\n", " \"Plot the confusion matrix, with `title` and using `cmap`.\"\n", " # This function is mainly copied from the sklearn docs\n", " cm = self.confusion_matrix()\n", " if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n", " fig = plt.figure(**kwargs)\n", " plt.imshow(cm, interpolation='nearest', cmap=cmap)\n", " plt.title(title)\n", " tick_marks = np.arange(len(self.vocab))\n", " plt.xticks(tick_marks, self.vocab, rotation=90)\n", " plt.yticks(tick_marks, self.vocab, rotation=0)\n", "\n", " if plot_txt:\n", " thresh = cm.max() / 2.\n", " for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n", " coeff = f'{cm[i, j]:.{norm_dec}f}' if normalize else f'{cm[i, j]}'\n", " plt.text(j, i, coeff, horizontalalignment=\"center\", verticalalignment=\"center\", color=\"white\"\n", " if cm[i, j] > thresh else \"black\")\n", "\n", " ax = fig.gca()\n", " ax.set_ylim(len(self.vocab)-.5,-.5)\n", "\n", " plt.tight_layout()\n", " plt.ylabel('Actual')\n", " plt.xlabel('Predicted')\n", " plt.grid(False)\n", "\n", " def most_confused(self, min_val=1):\n", " \"Sorted descending largest non-diagonal entries of confusion matrix (actual, predicted, # occurrences\"\n", " cm = self.confusion_matrix()\n", " np.fill_diagonal(cm, 0)\n", " res = [(self.vocab[i],self.vocab[j],cm[i,j]) for i,j in zip(*np.where(cm>=min_val))]\n", " return sorted(res, key=itemgetter(2), reverse=True)\n", "\n", " def print_classification_report(self):\n", " \"Print scikit-learn classification report\"\n", " _,targs,decoded = self.learn.get_preds(dl=self.dl, with_decoded=True, with_preds=True, \n", " with_targs=True, act=self.act)\n", " d,t = flatten_check(decoded, targs)\n", " names = [str(v) for v in self.vocab]\n", " print(skm.classification_report(t, d, labels=list(self.vocab.o2i.values()), target_names=names))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide\n", "# simple test to make sure ClassificationInterpretation works\n", "interp = ClassificationInterpretation.from_learner(test_learner)\n", "cm = interp.confusion_matrix()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": 
[], "source": [ "#export\n", "class SegmentationInterpretation(Interpretation):\n", " \"Interpretation methods for segmentation models.\"\n", " pass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_torch_core.ipynb.\n", "Converted 01_layers.ipynb.\n", "Converted 01a_losses.ipynb.\n", "Converted 02_data.load.ipynb.\n", "Converted 03_data.core.ipynb.\n", "Converted 04_data.external.ipynb.\n", "Converted 05_data.transforms.ipynb.\n", "Converted 06_data.block.ipynb.\n", "Converted 07_vision.core.ipynb.\n", "Converted 08_vision.data.ipynb.\n", "Converted 09_vision.augment.ipynb.\n", "Converted 09b_vision.utils.ipynb.\n", "Converted 09c_vision.widgets.ipynb.\n", "Converted 10_tutorial.pets.ipynb.\n", "Converted 10b_tutorial.albumentations.ipynb.\n", "Converted 11_vision.models.xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_callback.core.ipynb.\n", "Converted 13a_learner.ipynb.\n", "Converted 13b_metrics.ipynb.\n", "Converted 14_callback.schedule.ipynb.\n", "Converted 14a_callback.data.ipynb.\n", "Converted 15_callback.hook.ipynb.\n", "Converted 15a_vision.models.unet.ipynb.\n", "Converted 16_callback.progress.ipynb.\n", "Converted 17_callback.tracker.ipynb.\n", "Converted 18_callback.fp16.ipynb.\n", "Converted 18a_callback.training.ipynb.\n", "Converted 18b_callback.preds.ipynb.\n", "Converted 19_callback.mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 21_vision.learner.ipynb.\n", "Converted 22_tutorial.imagenette.ipynb.\n", "Converted 23_tutorial.vision.ipynb.\n", "Converted 24_tutorial.image_sequence.ipynb.\n", "Converted 24_tutorial.siamese.ipynb.\n", "Converted 24_vision.gan.ipynb.\n", "Converted 30_text.core.ipynb.\n", "Converted 31_text.data.ipynb.\n", "Converted 32_text.models.awdlstm.ipynb.\n", "Converted 33_text.models.core.ipynb.\n", "Converted 34_callback.rnn.ipynb.\n", "Converted 35_tutorial.wikitext.ipynb.\n", "Converted 37_text.learner.ipynb.\n", "Converted 38_tutorial.text.ipynb.\n", "Converted 39_tutorial.transformers.ipynb.\n", "Converted 40_tabular.core.ipynb.\n", "Converted 41_tabular.data.ipynb.\n", "Converted 42_tabular.model.ipynb.\n", "Converted 43_tabular.learner.ipynb.\n", "Converted 44_tutorial.tabular.ipynb.\n", "Converted 45_collab.ipynb.\n", "Converted 46_tutorial.collab.ipynb.\n", "Converted 50_tutorial.datablock.ipynb.\n", "Converted 60_medical.imaging.ipynb.\n", "Converted 61_tutorial.medical_imaging.ipynb.\n", "Converted 65_medical.text.ipynb.\n", "Converted 70_callback.wandb.ipynb.\n", "Converted 71_callback.tensorboard.ipynb.\n", "Converted 72_callback.neptune.ipynb.\n", "Converted 73_callback.captum.ipynb.\n", "Converted 97_test_utils.ipynb.\n", "Converted 99_pytorch_doc.ipynb.\n", "Converted dev-setup.ipynb.\n", "Converted app_examples.ipynb.\n", "Converted camvid.ipynb.\n", "Converted migrating_catalyst.ipynb.\n", "Converted migrating_ignite.ipynb.\n", "Converted migrating_lightning.ipynb.\n", "Converted migrating_pytorch.ipynb.\n", "Converted migrating_pytorch_verbose.ipynb.\n", "Converted ulmfit.ipynb.\n", "Converted index.ipynb.\n", "Converted quick_start.ipynb.\n", "Converted tutorial.ipynb.\n" ] } ], "source": [ "#hide\n", "from nbdev.export import notebook2script\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { 
"jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }