{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp interpret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.data.all import *\n",
    "from fastai.optimizer import *\n",
    "from fastai.learner import *\n",
    "from fastai.tabular.core import *\n",
    "import sklearn.metrics as skm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from fastai.test_utils import *\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interpretation of Predictions\n",
    "\n",
    "> Classes to build objects to better interpret predictions of a model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from fastai.vision.all import *\n",
    "mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock), \n",
    "                  get_items=get_image_files, \n",
    "                  splitter=RandomSubsetSplitter(.1,.1, seed=42),\n",
    "                  get_y=parent_label)\n",
    "test_dls = mnist.dataloaders(untar_data(URLs.MNIST_SAMPLE), bs=8)\n",
    "test_learner = vision_learner(test_dls, resnet18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# Fallback for the type-dispatched plotting function: reached when no\n",
    "# implementation is registered for the types of `x` and `y`.\n",
    "@typedispatch\n",
    "def plot_top_losses(x, y, *args, **kwargs):\n",
    "    raise Exception(f\"plot_top_losses is not implemented for {type(x)},{type(y)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "_all_ = [\"plot_top_losses\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Interpretation():\n",
    "    \"Interpretation base class, can be inherited for task specific Interpretation classes\"\n",
    "    def __init__(self, learn, dl, losses, act=None): \n",
    "        # Keep every argument on the instance; the methods below read them\n",
    "        # back as `self.learn`, `self.dl`, `self.losses` and `self.act`.\n",
    "        store_attr()\n",
    "\n",
    "    def __getitem__(self, idxs):\n",
    "        \"Return inputs, preds, targs, decoded outputs, and losses at `idxs`\"\n",
    "        # `idxs` may be a tensor, a single index or a list; normalise it to a list.\n",
    "        if isinstance(idxs, Tensor): idxs = idxs.tolist()\n",
    "        if not is_listy(idxs): idxs = [idxs]\n",
    "        # Index through `.iloc` when the items expose it, otherwise through an `L`.\n",
    "        items = getattr(self.dl.items, 'iloc', L(self.dl.items))[idxs]\n",
    "        # Predictions are recomputed on a temporary labelled dl that holds only the requested items.\n",
    "        # NOTE(review): `process` is turned off for `TabDataLoader`, presumably because its\n",
    "        # items are already processed - confirm.\n",
    "        tmp_dl = self.learn.dls.test_dl(items, with_labels=True, process=not isinstance(self.dl, TabDataLoader))\n",
    "        inps,preds,targs,decoded = self.learn.get_preds(dl=tmp_dl, with_input=True, with_loss=False, \n",
    "                                                        with_decoded=True, act=self.act, reorder=False)\n",
    "        # The losses are not recomputed: they come from the ones stored at construction time.\n",
    "        return inps, preds, targs, decoded, self.losses[idxs]\n",
    "\n",
    "    @classmethod\n",
    "    def from_learner(cls, learn, ds_idx=1, dl=None, act=None):\n",
    "        \"Construct interpretation object from a learner\"\n",
    "        # Without an explicit `dl`, use a copy of `learn.dls[ds_idx]` with shuffling and\n",
    "        # `drop_last` disabled.\n",
    "        if dl is None: dl = learn.dls[ds_idx].new(shuffle=False, drop_last=False)\n",
    "        # Only the losses are kept here; inputs, predictions and targets are\n",
    "        # fetched on demand by `__getitem__`.\n",
    "        _,_,losses = learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=False,\n",
    "                                     with_preds=False, with_targs=False, act=act)\n",
    "        return cls(learn, dl, losses, act)\n",
    "\n",
    "    def top_losses(self, k=None, largest=True, items=False):\n",
    "        \"`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`). Optionally include items.\"\n",
    "        # Returns `(losses, idx)`, or `(losses, idx, items)` when `items` is true.\n",
    "        losses, idx = self.losses.topk(ifnone(k, len(self.losses)), largest=largest)\n",
    "        if items: return losses, idx, getattr(self.dl.items, 'iloc', L(self.dl.items))[idx]\n",
    "        else: return losses, idx\n",
    "\n",
    "    def plot_top_losses(self, k, largest=True, **kwargs):\n",
    "        \"Show `k` largest(/smallest) preds and losses. `k` may be int, list, or `range` of desired results.\"\n",
    "        if is_listy(k) or isinstance(k, range):\n",
    "            # Rank every loss first, then pick the requested positions within that ranking.\n",
    "            losses, idx = (o[k] for o in self.top_losses(None, largest))\n",
    "        else: \n",
    "            losses, idx = self.top_losses(k, largest)\n",
    "        inps, preds, targs, decoded, _ = self[idx]\n",
    "        inps, targs, decoded = tuplify(inps), tuplify(targs), tuplify(decoded)\n",
    "        # Both calls get `max_n=len(idx)` so that the target samples and the decoded\n",
    "        # predictions cover the same number of items; leaving the first call on the\n",
    "        # default `max_n` truncates the targets when more samples are requested.\n",
    "        x, y, its = self.dl._pre_show_batch(inps+targs, max_n=len(idx))\n",
    "        _, _, outs = self.dl._pre_show_batch(inps+decoded, max_n=len(idx))\n",
    "        if its is not None:\n",
    "            # Only the part of each output after the inputs (the decoded predictions) is forwarded.\n",
    "            plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), preds, losses, **kwargs)\n",
    "        #TODO: figure out if this is needed\n",
    "        #its None means that a batch knows how to show itself as a whole, so we pass x, x1\n",
    "        #else: show_results(x, x1, its, ctxs=ctxs, max_n=max_n, **kwargs)\n",
    "\n",
    "    def show_results(self, idxs, **kwargs):\n",
    "        \"Show predictions and targets of `idxs`\"\n",
    "        # Same index normalisation as `__getitem__`, needed here for `len(idxs)` below.\n",
    "        if isinstance(idxs, Tensor): idxs = idxs.tolist()\n",
    "        if not is_listy(idxs): idxs = [idxs]\n",
    "        inps, _, targs, decoded, _ = self[idxs]\n",
    "        b = tuplify(inps)+tuplify(targs)\n",
    "        self.dl.show_results(b, tuplify(decoded), max_n=len(idxs), **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "
class
Interpretation
[source]Interpretation
(**`learn`**, **`dl`**, **`losses`**, **`act`**=*`None`*)\n",
"\n",
"Interpretation base class, can be inherited for task specific Interpretation classes"
],
"text/plain": [
"Interpretation.from_learner
[source]Interpretation.from_learner
(**`learn`**, **`ds_idx`**=*`1`*, **`dl`**=*`None`*, **`act`**=*`None`*)\n",
"\n",
"Construct interpretation object from a learner"
],
"text/plain": [
"Interpretation.top_losses
[source]Interpretation.top_losses
(**`k`**=*`None`*, **`largest`**=*`True`*, **`items`**=*`False`*)\n",
"\n",
"`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`). Optionally include items."
],
"text/plain": [
"Interpretation.plot_top_losses
[source]Interpretation.plot_top_losses
(**`k`**, **`largest`**=*`True`*, **\\*\\*`kwargs`**)\n",
"\n",
"Show `k` largest(/smallest) preds and losses. `k` may be int, list, or `range` of desired results."
],
"text/plain": [
"Interpretation.show_results
[source]Interpretation.show_results
(**`idxs`**, **\\*\\*`kwargs`**)\n",
"\n",
"Show predictions and targets of `idxs`"
],
"text/plain": [
"