{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from nb_005 import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model interpretation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic data aug" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "PATH = Path('data/dogscats')\n", "arch = tvm.resnet34" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "size,lr = 224,3e-3\n", "\n", "data_norm,data_denorm = normalize_funcs(*imagenet_stats)\n", "tfms = get_transforms(do_flip=True, max_rotate=10, max_zoom=1.2, max_lighting=0.3, max_warp=0.15)\n", "data = data_from_imagefolder(PATH, bs=64, ds_tfms=tfms, num_workers=8, tfms=data_norm, size=size)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Save activations" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "\n", "HookFunc = Callable[[Model, Tensors, Tensors], Any]\n", "\n", "class Hook():\n", " \"Creates a hook\"\n", " def __init__(self, m:Model, hook_func:HookFunc, is_forward:bool=True):\n", " self.hook_func,self.stored = hook_func,None\n", " f = m.register_forward_hook if is_forward else m.register_backward_hook\n", " self.hook = f(self.hook_fn)\n", " self.removed = False\n", "\n", " def hook_fn(self, module:Model, input:Tensors, output:Tensors):\n", " input = (o.detach() for o in input ) if is_listy(input ) else input.detach()\n", " output = (o.detach() for o in output) if is_listy(output) else output.detach()\n", " self.stored = self.hook_func(module, input, output)\n", "\n", " def remove(self):\n", " if not self.removed:\n", " self.hook.remove()\n", " self.removed=True\n", "\n", "class Hooks():\n", " \"Creates several hooks\"\n", " def __init__(self, ms:Collection[Model], hook_func:HookFunc, is_forward:bool=True):\n", " self.hooks = [Hook(m, hook_func, is_forward) for m in ms]\n", " \n", " def __getitem__(self,i:int) -> Hook: return self.hooks[i]\n", " def __len__(self) -> int: return len(self.hooks)\n", " def __iter__(self): return iter(self.hooks)\n", " @property\n", " def stored(self): return [o.stored for o in self]\n", " \n", " def remove(self):\n", " for h in self.hooks: h.remove()\n", "\n", "def hook_output (module:Model) -> Hook: return Hook (module, lambda m,i,o: o)\n", "def hook_outputs(modules:Collection[Model]) -> Hooks: return Hooks(modules, lambda m,i,o: o)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class HookCallback(LearnerCallback):\n", " \"Callback that registers given hooks\"\n", " def __init__(self, learn:Learner, modules:Sequence[Model]=None, do_remove:bool=True):\n", " super().__init__(learn)\n", " self.modules,self.do_remove = modules,do_remove\n", "\n", " def on_train_begin(self, **kwargs):\n", " if not self.modules:\n", " self.modules = [m for m in flatten_model(self.learn.model)\n", " if hasattr(m, 'weight')]\n", " self.hooks = Hooks(self.modules, self.hook)\n", "\n", " def on_train_end(self, **kwargs):\n", " if self.do_remove: self.remove()\n", "\n", " def remove(self): self.hooks.remove\n", " def __del__(self): self.remove()\n", "\n", "class ActivationStats(HookCallback):\n", " \"Callback that record the activations\"\n", " def on_train_begin(self, **kwargs):\n", " super().on_train_begin(**kwargs)\n", " self.stats = []\n", " \n", " def hook(self, m:Model, i:Tensors, o:Tensors) -> Tuple[Rank0Tensor,Rank0Tensor]: \n", " return o.mean().item(),o.std().item()\n", " def on_batch_end(self, **kwargs): self.stats.append(self.hooks.stored)\n", " def on_train_end(self, **kwargs): self.stats = tensor(self.stats).permute(2,1,0)\n", "\n", "def idx_dict(a): return {v:k for k,v in enumerate(a)}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner(data, arch, wd=1e-2, metrics=accuracy, path=PATH,\n", " callback_fns=ActivationStats)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit_one_cycle(1, lr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ms = learn.activation_stats.modules\n", "d = idx_dict(ms)\n", "ln = d[learn.model[1][8]]; ln" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.plot(learn.activation_stats.stats[1][ln].numpy());" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save('e1')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Best/worst" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner(data, arch, wd=1e-2, metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.load('1')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs=64\n", "classes = data.valid_ds.classes" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "preds,y = learn.TTA()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "preds,y = learn.get_preds()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def calc_loss(y_pred:Tensor, y_true:Tensor, loss_class:type=nn.CrossEntropyLoss):\n", " \"Calculate loss between `y_pred` and `y_true` using `loss_class`\"\n", " loss_dl = DataLoader(TensorDataset(tensor(y_pred),tensor(y_true)), bs)\n", " with torch.no_grad():\n", " return torch.cat([loss_class(reduction='none')(*b) for b in loss_dl])\n", "\n", "class ClassificationInterpretation():\n", " \"Interpretation methods for classification models\"\n", " def __init__(self, data:DataBunch, y_pred:Tensor, y_true:Tensor,\n", " loss_class:type=nn.CrossEntropyLoss, sigmoid:bool=True):\n", " self.data,self.y_pred,self.y_true,self.loss_class = data,y_pred,y_true,loss_class\n", " self.losses = calc_loss(y_pred, y_true, loss_class=loss_class)\n", " self.probs = preds.sigmoid() if sigmoid else preds\n", " self.pred_class = self.probs.argmax(dim=1)\n", "\n", " def top_losses(self, k, largest=True):\n", " \"`k` largest(/smallest) losses\"\n", " return self.losses.topk(k, largest=largest)\n", "\n", " def plot_top_losses(self, k, largest=True, figsize=(12,12)):\n", " \"Show images in `top_losses` along with their loss, label, and prediction\"\n", " tl = self.top_losses(k,largest)\n", " classes = self.data.classes\n", " rows = math.ceil(math.sqrt(k))\n", " fig,axes = plt.subplots(rows,rows,figsize=figsize)\n", " for i,idx in enumerate(self.top_losses(k, largest=largest)[1]):\n", " t=data.valid_ds[idx]\n", " t[0].show(ax=axes.flat[i], title=\n", " f'{classes[self.pred_class[idx]]}/{classes[t[1]]} / {self.losses[idx]:.2f} / {self.probs[idx][0]:.2f}')\n", "\n", " def confusion_matrix(self):\n", " \"Confusion matrix as an `np.ndarray`\"\n", " x=torch.arange(0,data.c)\n", " cm = ((self.pred_class==x[:,None]) & (self.y_true==x[:,None,None])).sum(2)\n", " return cm.cpu().numpy()\n", "\n", " def plot_confusion_matrix(self, normalize:bool=False, title:str='Confusion matrix', cmap:Any=\"Blues\", figsize:tuple=None):\n", " \"Plot the confusion matrix\"\n", " # This function is copied from the scikit docs\n", " cm = self.confusion_matrix()\n", " plt.figure(figsize=figsize)\n", " plt.imshow(cm, interpolation='nearest', cmap=cmap)\n", " plt.title(title)\n", " plt.colorbar()\n", " tick_marks = np.arange(len(classes))\n", " plt.xticks(tick_marks, self.data.classes, rotation=45)\n", " plt.yticks(tick_marks, self.data.classes)\n", "\n", " if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n", " thresh = cm.max() / 2.\n", " for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n", " plt.text(j, i, cm[i, j], horizontalalignment=\"center\", color=\"white\" if cm[i, j] > thresh else \"black\")\n", "\n", " plt.tight_layout()\n", " plt.ylabel('True label')\n", " plt.xlabel('Predicted label')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "interp = ClassificationInterpretation(data, preds, y, loss_class=nn.CrossEntropyLoss)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "interp.top_losses(9)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "interp.plot_top_losses(9)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "interp.confusion_matrix()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "interp.plot_confusion_matrix()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fin" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }