{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp interpret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from local.test import *\n",
    "from local.data.all import *\n",
    "from local.optimizer import *\n",
    "from local.learner import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from local.test_utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interpretation\n",
    "\n",
    "> Classes to build objects to better interpret predictions of a model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def plot_top_losses(x, y, *args, **kwargs):\n",
    "    raise Exception(f\"plot_top_losses is not implemented for {type(x)},{type(y)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "_all_ = [\"plot_top_losses\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Interpretation():\n",
    "    \"Interpretation base class, can be inherited for task specific Interpretation classes\"\n",
    "    def __init__(self, dl, inputs, preds, targs, decoded, losses):\n",
    "        store_attr(self, \"dl,inputs,preds,targs,decoded,losses\")\n",
    "\n",
    "    @classmethod\n",
    "    def from_learner(cls, learn, ds_idx=1, dl=None, act=None):\n",
    "        \"Construct interpretation object from a learner\"\n",
    "        if dl is None: dl = learn.dbunch.dls[ds_idx]\n",
    "        # Forward the caller's `act` to `get_preds` (it used to be hard-coded to `None`, so the argument was ignored)\n",
    "        return cls(dl, *learn.get_preds(dl=dl, with_input=True, with_loss=True, with_decoded=True, act=act))\n",
    "\n",
    "    def top_losses(self, k=None, largest=True):\n",
    "        \"`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`).\"\n",
    "        return self.losses.topk(ifnone(k, len(self.losses)), largest=largest)\n",
    "\n",
    "    def plot_top_losses(self, k, largest=True, **kwargs):\n",
    "        \"Show the `k` items with the largest(/smallest) losses through the type-dispatched `plot_top_losses`.\"\n",
    "        losses,idx = self.top_losses(k, largest)\n",
    "        # Tensor inputs are indexed directly; anything else is rebuilt into a batch by `self.dl`\n",
    "        if isinstance(self.inputs[0], Tensor): inps = tuple(o[idx] for o in self.inputs)\n",
    "        else: inps = self.dl.create_batch(self.dl.before_batch([tuple(o[i] for o in self.inputs) for i in idx]))\n",
    "        b = inps + tuple(o[idx] for o in (self.targs if is_listy(self.targs) else (self.targs,)))\n",
    "        x,y,its = self.dl._pre_show_batch(b, max_n=k)\n",
    "        b_out = inps + tuple(o[idx] for o in (self.decoded if is_listy(self.decoded) else (self.decoded,)))\n",
    "        x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)\n",
    "        if its is not None:\n",
    "            # Only the entries of `outs` that come after the inputs are passed on\n",
    "            plot_top_losses(x, y, its, outs.itemgot(slice(len(self.inputs), None)), self.preds[idx], losses, **kwargs)\n",
    "        #TODO: figure out if this is needed\n",
    "        #its None means that a batch knows how to show itself as a whole, so we pass x, x1\n",
    "        #else: show_results(x, x1, its, ctxs=ctxs, max_n=max_n, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn = synth_learner()\n",
    "interp = Interpretation.from_learner(learn)\n",
    "x,y = learn.dbunch.valid_ds.tensors\n",
    "test_eq(*interp.inputs, x)\n",
    "test_eq(interp.targs, y)\n",
    "out = learn.model.a * x + learn.model.b\n",
    "test_eq(interp.preds, out)\n",
    "test_eq(interp.losses, (out-y)[:,0]**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class ClassificationInterpretation(Interpretation):\n",
    "    \"Interpretation methods for classification models.\"\n",
    "\n",
    "    def __init__(self, dl, inputs, preds, targs, decoded, losses):\n",
    "        super().__init__(dl, inputs, preds, targs, decoded, losses)\n",
    "        self.vocab = self.dl.vocab\n",
    "        # When `dl.vocab` is a list of vocabs, keep the last one\n",
    "        if is_listy(self.vocab): self.vocab = self.vocab[-1]\n",
    "\n",
    "    def confusion_matrix(self):\n",
    "        \"Confusion matrix as an `np.ndarray` (rows index `targs`, columns index `decoded`).\"\n",
    "        x = torch.arange(0, len(self.vocab))\n",
    "        # Broadcast to (class, class, item) and count matches over the last axis\n",
    "        # NOTE(review): assumes `decoded` and `targs` are 1-d tensors of class indices - confirm\n",
    "        cm = ((self.decoded==x[:,None]) & (self.targs==x[:,None,None])).sum(2)\n",
    "        return to_np(cm)\n",
    "\n",
    "    def plot_confusion_matrix(self, normalize=False, title='Confusion matrix', cmap=\"Blues\", norm_dec=2,\n",
    "                              plot_txt=True, **kwargs):\n",
    "        \"Plot the confusion matrix, with `title` and using `cmap`.\"\n",
    "        # This function is mainly copied from the sklearn docs\n",
    "        cm = self.confusion_matrix()\n",
    "        if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
    "        fig = plt.figure(**kwargs)\n",
    "        plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
    "        plt.title(title)\n",
    "        tick_marks = np.arange(len(self.vocab))\n",
    "        plt.xticks(tick_marks, self.vocab, rotation=90)\n",
    "        plt.yticks(tick_marks, self.vocab, rotation=0)\n",
    "\n",
    "        if plot_txt:\n",
    "            # Cells above half of the max value get white text, the others black text\n",
    "            thresh = cm.max() / 2.\n",
    "            for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
    "                coeff = f'{cm[i, j]:.{norm_dec}f}' if normalize else f'{cm[i, j]}'\n",
    "                plt.text(j, i, coeff, horizontalalignment=\"center\", verticalalignment=\"center\", color=\"white\" if cm[i, j] > thresh else \"black\")\n",
    "\n",
    "        ax = fig.gca()\n",
    "        ax.set_ylim(len(self.vocab)-.5,-.5)\n",
    "\n",
    "        plt.ylabel('Actual')\n",
    "        plt.xlabel('Predicted')\n",
    "        plt.grid(False)\n",
    "        # Called last so that the axis labels set above are part of the layout computation\n",
    "        plt.tight_layout()\n",
    "\n",
    "    def most_confused(self, min_val=1):\n",
    "        \"Sorted descending list of largest non-diagonal entries of confusion matrix, presented as actual, predicted, number of occurrences.\"\n",
    "        cm = self.confusion_matrix()\n",
    "        # Zero the diagonal so that correct predictions are never reported\n",
    "        np.fill_diagonal(cm, 0)\n",
    "        res = [(self.vocab[i],self.vocab[j],cm[i,j])\n",
    "                for i,j in zip(*np.where(cm>=min_val))]\n",
    "        return sorted(res, key=itemgetter(2), reverse=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_test.ipynb.\n",
      "Converted 01_core.ipynb.\n",
      "Converted 01a_utils.ipynb.\n",
      "Converted 01b_dispatch.ipynb.\n",
      "Converted 01c_transform.ipynb.\n",
      "Converted 02_script.ipynb.\n",
      "Converted 03_torch_core.ipynb.\n",
      "Converted 03a_layers.ipynb.\n",
      "Converted 04_dataloader.ipynb.\n",
      "Converted 05_data_core.ipynb.\n",
      "Converted 06_data_transforms.ipynb.\n",
      "Converted 07_data_block.ipynb.\n",
      "Converted 08_vision_core.ipynb.\n",
      "Converted 09_vision_augment.ipynb.\n",
      "Converted 10_pets_tutorial.ipynb.\n",
      "Converted 11_vision_models_xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_learner.ipynb.\n",
      "Converted 13a_metrics.ipynb.\n",
      "Converted 14_callback_schedule.ipynb.\n",
      "Converted 14a_callback_data.ipynb.\n",
      "Converted 15_callback_hook.ipynb.\n",
      "Converted 15a_vision_models_unet.ipynb.\n",
      "Converted 16_callback_progress.ipynb.\n",
      "Converted 17_callback_tracker.ipynb.\n",
      "Converted 18_callback_fp16.ipynb.\n",
      "Converted 19_callback_mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 21_vision_learner.ipynb.\n",
      "Converted 22_tutorial_imagenette.ipynb.\n",
      "Converted 23_tutorial_transfer_learning.ipynb.\n",
      "Converted 30_text_core.ipynb.\n",
      "Converted 31_text_data.ipynb.\n",
      "Converted 32_text_models_awdlstm.ipynb.\n",
      "Converted 33_text_models_core.ipynb.\n",
      "Converted 34_callback_rnn.ipynb.\n",
      "Converted 35_tutorial_wikitext.ipynb.\n",
      "Converted 36_text_models_qrnn.ipynb.\n",
      "Converted 37_text_learner.ipynb.\n",
      "Converted 38_tutorial_ulmfit.ipynb.\n",
      "Converted 40_tabular_core.ipynb.\n",
      "Converted 41_tabular_model.ipynb.\n",
      "Converted 42_tabular_rapids.ipynb.\n",
      "Converted 50_data_block_examples.ipynb.\n",
      "Converted 60_medical_imaging.ipynb.\n",
      "Converted 65_medical_text.ipynb.\n",
      "Converted 90_notebook_core.ipynb.\n",
      "Converted 91_notebook_export.ipynb.\n",
      "Converted 92_notebook_showdoc.ipynb.\n",
      "Converted 93_notebook_export2html.ipynb.\n",
      "Converted 94_notebook_test.ipynb.\n",
      "Converted 95_index.ipynb.\n",
      "Converted 96_data_external.ipynb.\n",
      "Converted 97_utils_test.ipynb.\n",
      "Converted notebook2jekyll.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from local.notebook.export import notebook2script\n",
    "notebook2script(all_fs=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}