{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":2},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython2","version":"2.7.6"},"colab":{"name":"tutorial_getting_started.ipynb","provenance":[],"collapsed_sections":[],"toc_visible":true},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"bukochgPFg7s"},"source":["# Getting started\n","\n","This notebook shows how to get started with Quantus, using a very simple example. For this purpose, we use a LeNet model and the MNIST dataset.\n","\n","- Make sure GPU acceleration is enabled to speed up computation.\n","- Skip running the first code cell if you are not using Google Colab."]},{"cell_type":"code","metadata":{"id":"A8-GsXIAcigw"},"source":["# Mount Google Drive.\n","from google.colab import drive\n","drive.mount('/content/drive', force_remount=True)\n","\n","# Install packages (%pip installs into the environment of the running kernel).\n","from IPython.display import clear_output\n","%pip install captum opencv-python\n","%pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html\n","clear_output()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"MuA8v9jLp-pf"},"source":["# Imports general.\n","import sys\n","import gc\n","import warnings\n","import pathlib\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import torch\n","import torchvision\n","from torchvision import transforms\n","import captum\n","from captum.attr import IntegratedGradients, Saliency\n","import random\n","import os\n","import cv2\n","# Imported here too, so this cell still works when the Colab-only cell above is skipped.\n","from IPython.display import clear_output\n","\n","# Import package.\n","path = \"/content/drive/MyDrive/Projects\"\n","sys.path.append(f'{path}/quantus')\n","import quantus\n","\n","# Collect garbage.\n","gc.collect()\n","torch.cuda.empty_cache()\n","\n","# Notebook settings.\n","device = torch.device(\"cuda:0\" if 
torch.cuda.is_available() else \"cpu\") \n","warnings.filterwarnings(\"ignore\", category=UserWarning)\n","%load_ext autoreload\n","%autoreload 2\n","clear_output()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mGhP4bTuoWYF"},"source":["## 1. Preliminaries"]},{"cell_type":"markdown","metadata":{"id":"XqKzag4VFjHT"},"source":["### 1.1 Load datasets\n","\n","We load the MNIST datasets and take a batch of input-output pairs from the test set, for which we will first generate explanations and then evaluate them."]},{"cell_type":"code","metadata":{"id":"TmsZxFhuc0mm"},"source":["# Load datasets and make loaders.\n","transformer = transforms.Compose([transforms.ToTensor()])\n","train_set = torchvision.datasets.MNIST(root='./sample_data', train=True, transform=transformer, download=True)\n","test_set = torchvision.datasets.MNIST(root='./sample_data', train=False, transform=transformer, download=True)\n","train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, pin_memory=True) # num_workers=4,\n","test_loader = torch.utils.data.DataLoader(test_set, batch_size=200, pin_memory=True)\n","\n","# Load a batch of inputs and outputs to use for evaluation.\n","# Use the builtin next(): DataLoader iterators no longer provide a .next() method in recent PyTorch.\n","x_batch, y_batch = next(iter(test_loader))\n","x_batch, y_batch = x_batch.to(device), y_batch.to(device)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":191},"id":"aAR67-cnOS67","executionInfo":{"status":"ok","timestamp":1640102275943,"user_tz":-60,"elapsed":501,"user":{"displayName":"Anna Hedström","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhbfsluHeZ1mzN6Bsf-1zU62lYHcz183jYjeS63=s64","userId":"05540180366077551505"}},"outputId":"694e08b4-89fd-4be9-d19a-161e9d41dcd5"},"source":["# Plot some inputs!\n","nr_images = 5\n","fig, axes = plt.subplots(nrows=1, ncols=nr_images, figsize=(nr_images*3, int(nr_images*2/3)))\n","for i in range(nr_images):\n"," axes[i].imshow((np.reshape(x_batch[i].cpu().numpy(), (28, 
28)) * 255).astype(np.uint8), vmin=0.0, vmax=1.0, cmap=\"gray\")\n"," axes[i].title.set_text(f\"MNIST class - {y_batch[i].item()}\")\n"," axes[i].axis(\"off\")\n","plt.show()"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"iVBORw0KGgoAAAANSUhEUgAAA1MAAACuCAYAAADTXFfGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAMqElEQVR4nO3dXaxsZ1kH8P9TEEHbomKCSArE1kQtFBNRjCSgtoRATTQQ/ADtIdpESlQSE0LSaK1WueDCGIMJogalVaBFTUyU1E9IudAIKjWN5aIUQdsTmxZ6WvqF7ePFmhN3m9nn7P3O7Pnav19yLmZmnXXetfez1qz/vPM+p7o7AAAAHM456x4AAADANhKmAAAABghTAAAAA4QpAACAAcIUAADAAGEKAABggDC1RFXVVXXRuscBB6Vm2TZqlm2jZtk2avZwNipMVdXnquqxqvrGpzz/r7Nf7Itmj/9w9vh79mxzUVX1nscfq6or9zy+uqrurKoHq+q/qurDs+dvmz33YFU9XlWP7Hl89VEf81GpqjfvOY4Hq+qh2c/su9Y9tl2iZpenqr63qv6mqu6rqnuq6qaqet66x7Vr1OzyVNUzquojs59pV9X3r3tMu0jNLldVXVpVt8/uC/6hql647jHtGjV7NKrqmtnP67J1j2WvjQpTM3cm+YnTD6rqJUm+Zs529yX59YPssKpOJPmpJJd197lJXpbk75Kkuy/u7nNnz9+S5OdOP+7udy12KOvT3X+85zjOTfK2JJ9N8i9rHtouUrPL8fVJ3pfkRUlemOSBJO9f54B2mJpdnk8k+ckkJ9c9kB2nZpdgdnP/Z0l+Ock3JPlkkg+vdVC7S80uUVVdmOSNSe5e91ieahPD1PVJrtjz+ESSD8zZ7o+SXFJVrzrAPr87yc3dfUeSdPfJ7n7fyOCq6mmzTwXuqKoHqupTVXXBnO0un30CcaqqvlBV1+557ZlVdUNV3VtVX6qqf66q585ee0tVfXa27zur6s0j45zjRJIPdHefdUsOS80uoWa7+6PdfVN3n+ruh5K8J8krRvbFWanZ5dTsY939W939iSSPj+yDA1Ozy7k3eH2S22bX2keSXJvkpVX1bYP7Y39qdrn3s7+T5J1JHltwP0u3iWHqH5OcX1XfXlVPS/LjSW6Ys91DSd6V5DcOuM8rquodVfWy2X5H/WKmTxpel+T8JD89G8tTfTnTSfR1SS5PclVV/cjstRNJnp3kgiTPSfLWJA9X1dcm+e0kr+3u85J8X5J/W2CsSZKapvBfmfknMYtTs0uu2ZlXJrltSfviydTs0dQsR0fNLqdmL07y6dMPuvvLSe6YPc9yqdklXWer6o1JHu3uvxrdx1HaxDCV/H+af3WS/0jy3/ts97tJXlBVrz3Tzrr7hiQ/n+Q1ST6e5H+q6p2DY7syyS9192d68unuvnfOv/mx7v737n6iu29N8sEkpz91+Eqmoruoux/v7k9196nZa08keXFVPau77+7uZdxMXpHklu6+cwn7Yj41u8SarapLklyT5B2L7ot9qdnlXmc5emp28Zo9N8n9T3nu/iTnDe6PM1OzC9ZsVZ2XKWy+feTvr8Imh6k3JXlLzjCb0t2PJrlu9ueMelpDdFmmZP3WJNdV1WsGxnZBpk9xzqiqXl7Tws57qur+2b95eiHi9UluTvKhqrqrqt5dVV81+4Tox2bb3l1Vf7nf1Hs9ubnEC84ynCsyTS
NzdNTskmq2pg5CH03y9u6+5VBHymGo2eVeZzl6anbxmn0w0yzEXudnWqPK8qnZxWv22iTXd/fnDn2EK7KRYaq7/zPTwr3XZVooeSbvz1RQrz/gvr/S3TcluTXJiweG94UkFx5guz9J8hdJLujuZyd5b5LaM4Zf7e7vyDT1+UOZfa+2u2/u7lcneV6S25P83j7Hce6eP5/fbxBV9Yok35zkIwc9QA5PzS6nZmdfSf3bJNd19/WHOUgOR80u7zrLaqjZpdTsbUleevrB7OtYF8ZXqo+Eml1KzV6a5Beq6mRVncwUAm9cYEZu6TYyTM38TJIfnKXbfXX3/yb5lUyL0uaqaRHc5VV1XlWdM5tGvTjJPw2M6/czfQrwrTW5pKqeM2e785Lc192P1NTy8k17xvMDVfWS2XddT2WaJn2iqp5bVT88u7g9mukTpCcGxrjXiSR/2t0+dTp6anaBmq2q5yf5+yTv6e73juyDQ1OzC15nq+qrq+qZs4fPqGlBdo3uj7NSs4vV7J9n+urVG2Z1e02SW7v79sH9cXZqdrGavTRTWPzO2Z+7kvxspoYUG2Fjw1R339Hdnzzg5h/MmVslnkpydZLPJ/lSkncnuaqnDkyH9ZtJbkzy17P9/kGSZ83Z7m1Jfq2qHsh0sbpxz2vflGmm6FSm79B+PNNU6TmZFgTelalV5quSXDUwxiRTl5UkPxpf8VsJNbtwzV6Z5FuSXLt32n9wXxyAml38OpvkM0keTvL8TF93eThTa3+OgJpdrGa7+54kb8jU7OCLSV6eqTECR0TNLlyz9/bUtfBkd5/M1Dn1i929MfcH1TplAwAAHNrGzkwBAABsMmEKAABggDAFAAAwQJgCAAAY8PQzvVhVulOwkO5eaYtgNcuiVl2zibplca61bBs1y7bZr2bNTAEAAAwQpgAAAAYIUwAAAAOEKQAAgAHCFAAAwABhCgAAYIAwBQAAMECYAgAAGCBMAQAADBCmAAAABghTAAAAA4QpAACAAcIUAADAAGEKAABggDAFAAAw4OnrHgCwv+5e9xCepKrWPQQAgI1hZgoAAGCAMAUAADBAmAIAABggTAEAAAzQgAI2xKY1m5hn3hg1pWCTLHIeqWXOZlXXabUI28PMFAAAwABhCgAAYIAwBQAAMECYAgAAGKABBazYNjSaAAA210HvJTatmclh7oE2bez7MTMFAAAwQJgCAAAYIEwBAAAMEKYAAAAGaEABR2idzSbmLdzU/AIm27KwmfVxvQQOwswUAADAAGEKAABggDAFAAAwQJgCAAAYsLENKDZt4T5sCvUJEw0CWJZNq6WDjsf7wfGwafW5KvOOexNr3swUAADAAGEKAABggDAFAAAwQJgCAAAYsLENKNZp1xb6beJiveNi3s/+MPXldwfAfvZ7P/HeAatjZgoAAGCAMAUAADBAmAIAABggTAEAAAwQpgAAAAbo5gcrtqouS7vWlRJgGVwb2VRqczuZmQIAABggTAEAAAwQpgAAAAYIUwAAAAM2tgGFRfpwcKuq41Wdl7DXovWtbo+n4/z+Pu/YnQe7b9N+x8flHDQzBQAAMECYAgAAGCBMAQAADBCmAAAABmxsA4pVsVgP5tu0cwNgP9vy3nnQ6+pRHM9B9+navxoa6+wOM1MAAAADhCkAAIABwhQAAMAAYQoAAGDAsW9AsU5HscDUgsTdty0LrQGOg6N43523T9d+Ntlxvqc1MwUAADBAmAIAABggTAEAAAwQpgAAAAZoQAEbbFULjrdlkSe7zyJ7NplrJSMWva4dh7rb5mM0MwUAADBAmAIAABggTAEAAAwQpgAAAAZoQLEix/l/hmZzqBl2nRpnhLoBRpmZAgAAGCBMAQAADBCmAAAABghTAAAAA4QpAACAAbr5wYY4io6PAMB2m3d/sM4OlO5XnszMFAAAwABhCgAAYIAwBQAAMECYAgAAGKABxRFY9sK8dS4yZPlWtXBT3bDpFj0X1Di7wGL+3TfvWrXo713dbA
4zUwAAAAOEKQAAgAHCFAAAwABhCgAAYIAGFBvGgmoAYBXcc8DizEwBAAAMEKYAAAAGCFMAAAADhCkAAIABGlAswP8+zaawiBhg863zvsH7xGY5zO9jFXWzqvHsYh2amQIAABggTAEAAAwQpgAAAAYIUwAAAAM0oDigo1j8t4uL8FiPVS1qnlez29yIxTkIJNt9HZvHtW237NLvc79zbZuP0cwUAADAAGEKAABggDAFAAAwQJgCAAAYIEwBAAAM0M1vjl3r6gPL4txglNoBYJu79u3HzBQAAMAAYQoAAGCAMAUAADBAmAIAABigAcWK7OKCOwAAOM7MTAEAAAwQpgAAAAYIUwAAAAOEKQAAgAHHvgFFd697CBwz+zUjOa61qDnL7ll2LasRzmZejWzzNVXNw/YwMwUAADBAmAIAABggTAEAAAwQpgAAAAYc+wYUR8HCUUaoG4DlOcw1dZubVQDrZWYKAABggDAFAAAwQJgCAAAYIEwBAAAMOFYNKCwwBTh6mqmwbdQsLN9xOa/MTAEAAAwQpgAAAAYIUwAAAAOEKQAAgAHHqgEFAAAw7rg0ljgoM1MAAAADhCkAAIABwhQAAMAAYQoAAGCAMAUAADBAN78F6GYCAADHl5kpAACAAcIUAADAAGEKAABggDAFAAAw4Fg1oNAwAgAAWBYzUwAAAAOEKQAAgAHCFAAAwABhCgAAYEB197rHAAAAsHXMTAEAAAwQpgAAAAYIUwAAAAOEKQAAgAHCFAAAwABhCgAAYMD/Ac/SlKuddICZAAAAAElFTkSuQmCC\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"vmccxpA0n6MY"},"source":["### 1.2 Train a LeNet model\n","\n","Here we use a LeNet model, but any other model of your choice works as well.\n","The network architecture is taken from: https://github.com/ChawDoe/LeNet5-MNIST-PyTorch."]},{"cell_type":"code","metadata":{"id":"CUghaOhXddLU","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1640102275944,"user_tz":-60,"elapsed":8,"user":{"displayName":"Anna Hedström","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhbfsluHeZ1mzN6Bsf-1zU62lYHcz183jYjeS63=s64","userId":"05540180366077551505"}},"outputId":"cb625cea-1bc4-477e-a549-4bd1e1a8e158"},"source":["class LeNet(torch.nn.Module):\n","    \"\"\"Network architecture from: https://github.com/ChawDoe/LeNet5-MNIST-PyTorch.\"\"\"\n","\n","    def __init__(self):\n","        super().__init__()\n","        self.conv_1 = torch.nn.Conv2d(1, 6, 5)\n","        self.pool_1 = torch.nn.MaxPool2d(2, 2)\n","        self.relu_1 = torch.nn.ReLU()\n","        self.conv_2 = torch.nn.Conv2d(6, 16, 5)\n","        self.pool_2 = torch.nn.MaxPool2d(2, 2)\n","        self.relu_2 = torch.nn.ReLU()\n","        # 256 = 16 channels * 4 * 4: a 28x28 input shrinks to 24 -> 12 -> 8 -> 4 through the conv/pool stages.\n","        self.fc_1 = torch.nn.Linear(256, 120)\n","        self.relu_3 = torch.nn.ReLU()\n","        self.fc_2 = torch.nn.Linear(120, 84)\n","        self.relu_4 = torch.nn.ReLU()\n","        self.fc_3 = torch.nn.Linear(84, 10)\n","\n","    def forward(self, x):\n","        x = self.pool_1(self.relu_1(self.conv_1(x)))\n","        x = self.pool_2(self.relu_2(self.conv_2(x)))\n","        x = x.view(x.shape[0], -1)\n","        x = self.relu_3(self.fc_1(x))\n","        x = self.relu_4(self.fc_2(x))\n","        x = self.fc_3(x)\n","        return x\n","\n","# Load model architecture (eval() returns the model itself, so the printed summary is unchanged).\n","model = LeNet()\n","model.eval()\n","print(f\"\\n Model architecture: {model}\\n\")"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["\n"," Model architecture: LeNet(\n"," (conv_1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n"," (pool_1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n"," (relu_1): ReLU()\n"," (conv_2): Conv2d(6, 
16, kernel_size=(5, 5), stride=(1, 1))\n"," (pool_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n"," (relu_2): ReLU()\n"," (fc_1): Linear(in_features=256, out_features=120, bias=True)\n"," (relu_3): ReLU()\n"," (fc_2): Linear(in_features=120, out_features=84, bias=True)\n"," (relu_4): ReLU()\n"," (fc_3): Linear(in_features=84, out_features=10, bias=True)\n",")\n","\n"]}]},{"cell_type":"code","metadata":{"id":"olAfyOHzevne"},"source":["def train_model(model,\n","                train_data: torch.utils.data.DataLoader,\n","                test_data: torch.utils.data.DataLoader,\n","                device: torch.device,\n","                epochs: int = 20,\n","                criterion: torch.nn.Module = None,\n","                optimizer: torch.optim.Optimizer = None,\n","                evaluate: bool = False):\n","    \"\"\"Train a torch model in place and return it.\n","\n","    Args:\n","        model: the torch.nn.Module to train.\n","        train_data: iterable of (images, labels) batches used for the weight updates.\n","        test_data: iterable of (images, labels) batches, only used when `evaluate` is True.\n","        device: device every batch is moved to.\n","        epochs: number of passes over `train_data`.\n","        criterion: loss function; a fresh torch.nn.CrossEntropyLoss() when None.\n","        optimizer: optimizer over the model parameters; a fresh SGD(lr=0.001, momentum=0.9) on `model` when None.\n","        evaluate: if True, print the test accuracy and the last batch loss after every epoch.\n","    \"\"\"\n","    # Build defaults per call: defaults written in the signature are evaluated once, at definition\n","    # time, which would bind the optimizer to whatever global `model` existed back then.\n","    if criterion is None:\n","        criterion = torch.nn.CrossEntropyLoss()\n","    if optimizer is None:\n","        optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n","\n","    for epoch in range(epochs):\n","        # evaluate_model() switches the model to eval mode, so re-enable training mode every epoch.\n","        model.train()\n","\n","        for images, labels in train_data:\n","            images, labels = images.to(device), labels.to(device)\n","\n","            optimizer.zero_grad()\n","            logits = model(images)\n","            loss = criterion(logits, labels)\n","            loss.backward()\n","            optimizer.step()\n","\n","        # Evaluate model!\n","        if evaluate:\n","            predictions, labels = evaluate_model(model, test_data, device)\n","            test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())\n","            print(f\"Epoch {epoch+1}/{epochs} - test accuracy: {(100 * test_acc):.2f}% and CE loss {loss.item():.2f}\")\n","\n","    return model\n","\n","def evaluate_model(model, data, device):\n","    \"\"\"Evaluate a torch model on every batch of `data`.\n","\n","    Puts the model in eval mode and, without tracking gradients, returns a tuple of\n","    (softmax class probabilities, targets), each concatenated over all batches on `device`.\n","    \"\"\"\n","    model.eval()\n","    logits = torch.Tensor().to(device)\n","    targets = torch.LongTensor().to(device)\n","\n","    with torch.no_grad():\n","        for images, labels in data:\n","            images, labels = images.to(device), labels.to(device)\n","            logits = torch.cat([logits, model(images)])\n","            targets = torch.cat([targets, labels])\n","\n","    return torch.nn.functional.softmax(logits, dim=1), 
targets"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"t6_qEwhee1WH","executionInfo":{"status":"ok","timestamp":1640102277771,"user_tz":-60,"elapsed":1832,"user":{"displayName":"Anna Hedström","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhbfsluHeZ1mzN6Bsf-1zU62lYHcz183jYjeS63=s64","userId":"05540180366077551505"}},"outputId":"fabc06df-34b7-49fc-ba4f-583707c862bb"},"source":["# Build the weights path from the notebook-level `path`, the same base used for the other tutorial assets.\n","path_model_weights = f\"{path}/quantus/tutorials/assets/mnist\"\n","\n","if pathlib.Path(path_model_weights).is_file():\n","    # map_location lets weights saved on a GPU be loaded on a CPU-only runtime (and vice versa).\n","    model.load_state_dict(torch.load(path_model_weights, map_location=device))\n","else:\n","    # Train and evaluate model.\n","    model = train_model(model=model.to(device),\n","                        train_data=train_loader,\n","                        test_data=test_loader,\n","                        device=device,\n","                        epochs=20,\n","                        criterion=torch.nn.CrossEntropyLoss().to(device),\n","                        optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),\n","                        evaluate=True)\n","\n","    # Save model, creating the target directory first if it does not exist yet.\n","    pathlib.Path(path_model_weights).parent.mkdir(parents=True, exist_ok=True)\n","    torch.save(model.state_dict(), path_model_weights)\n","\n","# Move the model to `device` and switch it to eval mode.\n","model.to(device)\n","model.eval()\n","\n","# Check test set performance.\n","predictions, labels = evaluate_model(model, test_loader, device)\n","test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())\n","print(f\"Model test accuracy: {(100 * test_acc):.2f}%\")"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Model test accuracy: 99.01%\n"]}]},{"cell_type":"markdown","metadata":{"id":"4vY9mZQanaxr"},"source":["### 1.3 Generate explanations\n","\n","There exist multiple ways to generate explanations for neural network models, e.g., using the `captum` or `innvestigate` libraries. 
In this example, we compute the explanations directly with `captum` and later pass `quantus.explain` (a simple wrapper around `captum`) to the metrics as `explain_func`; however, you can use whatever approach or library you'd like to create your explanations.\n","\n","**Requirements.**\n","\n","* **Data type.** Similar to the x-y pairs, the attributions should also be of type `np.ndarray`.\n","* **Shape.** The attributions should share all dimensions with the input, except for nr_channels, which for explanations is equal to 1. For example, if x_batch is of size (128, 3, 224, 224), then the attributions should be of size (128, 1, 224, 224). (Here, the single MNIST channel is summed out with `.sum(axis=1)`, so the attributions are of size (batch_size, 28, 28).)"]},{"cell_type":"code","metadata":{"id":"gNxAtc2Co1pL"},"source":["# Generate normalised Saliency and Integrated Gradients attributions of the first batch of the test set.\n","# torch.as_tensor accepts both the tensors from the data cell and the numpy arrays this cell produces,\n","# so the cell can be re-run without reloading the batch.\n","x_batch_t, y_batch_t = torch.as_tensor(x_batch, device=device), torch.as_tensor(y_batch, device=device)\n","a_batch_saliency = quantus.normalise_by_negative(Saliency(model).attribute(inputs=x_batch_t, target=y_batch_t, abs=True).sum(axis=1).cpu().numpy())\n","a_batch_intgrad = quantus.normalise_by_negative(IntegratedGradients(model).attribute(inputs=x_batch_t, target=y_batch_t, baselines=torch.zeros_like(x_batch_t)).sum(axis=1).cpu().numpy())\n","\n","# Save x_batch and y_batch as numpy arrays that will be used to call metric instances.\n","x_batch, y_batch = x_batch_t.cpu().numpy(), y_batch_t.cpu().numpy()\n","\n","# Quick sanity checks on the types and shapes of the arrays passed to Quantus.\n","assert all(isinstance(obj, np.ndarray) for obj in [x_batch, y_batch, a_batch_saliency, a_batch_intgrad])\n","assert a_batch_saliency.shape == a_batch_intgrad.shape == (x_batch.shape[0],) + x_batch.shape[2:]"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iRDwzUUp8bR2"},"source":["Visualise the attributions for a few input-output pairs."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":628},"id":"82WWNmyoilXo","executionInfo":{"status":"ok","timestamp":1640102281314,"user_tz":-60,"elapsed":3146,"user":{"displayName":"Anna Hedström","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhbfsluHeZ1mzN6Bsf-1zU62lYHcz183jYjeS63=s64","userId":"05540180366077551505"}},"outputId":"46910bcf-1acd-46a0-f479-8512f967143c"},"source":["# Plot explanations!\n","nr_images = 
3\n","fig, axes = plt.subplots(nrows=nr_images, ncols=3, figsize=(nr_images*2.5, int(nr_images*3)))\n","for i in range(nr_images):\n","    # Inputs come from ToTensor() and already lie in [0, 1], matching vmin/vmax; rescaling them to\n","    # uint8 [0, 255] would clip every non-zero pixel to white.\n","    axes[i, 0].imshow(np.reshape(x_batch[i], (28, 28)), vmin=0.0, vmax=1.0, cmap=\"gray\")\n","    axes[i, 0].title.set_text(f\"MNIST digit {y_batch[i].item()}\")\n","    axes[i, 0].axis(\"off\")\n","    axes[i, 1].imshow(a_batch_saliency[i], cmap=\"seismic\")\n","    axes[i, 1].title.set_text(\"Saliency\")\n","    axes[i, 1].axis(\"off\")\n","    axes[i, 2].imshow(a_batch_intgrad[i], cmap=\"seismic\")\n","    axes[i, 2].title.set_text(\"Integrated Gradients\")\n","    axes[i, 2].axis(\"off\")\n","# Apply the layout before saving so the exported PNG matches the displayed figure.\n","plt.tight_layout()\n","plt.savefig(f'{path}/quantus/tutorials/assets/mnist_example.png', dpi=400)\n","plt.show()"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"tuBkEBv3mihR"},"source":["## 2. Quantitative evaluation using Quantus\n","\n","We can evaluate our explanations on a variety of quantitative criteria, but as a motivating example we test the Max-Sensitivity (Yeh et al., 2019) of the explanations. This metric measures how much the explanations maximally change when the inputs are subject to slight perturbations."]},{"cell_type":"code","source":["# Define params for evaluation.\n","params_eval = {\n"," \"nr_samples\": 10,\n"," \"perturb_radius\": 0.1,\n"," \"norm_numerator\": quantus.fro_norm,\n"," \"norm_denominator\": quantus.fro_norm,\n"," \"perturb_func\": quantus.uniform_sampling,\n"," \"similarity_func\": quantus.difference,\n"," \"disable_warnings\": True,\n","}"],"metadata":{"id":"aLjrKsT6mS9X"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Return max sensitivity scores in a one-liner by calling the metric instance.\n","scores_saliency = quantus.MaxSensitivity(**params_eval)(model=model, \n"," x_batch=x_batch,\n"," y_batch=y_batch,\n"," a_batch=a_batch_saliency,\n"," **{\"explain_func\": quantus.explain, \"method\": \"Saliency\", \"device\": device, \"img_size\": 28, \"normalise\": False, \"abs\": False})"],"metadata":{"id":"NlV_43TAmJll"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Return max sensitivity scores in a one-liner by calling the metric instance.\n","scores_intgrad = quantus.MaxSensitivity(**params_eval)(model=model, \n"," x_batch=x_batch,\n"," y_batch=y_batch,\n"," a_batch=a_batch_intgrad,\n"," **{\"explain_func\": quantus.explain, \"method\": \"IntegratedGradients\", \"device\": device, \"img_size\": 28, \"normalise\": False, \"abs\": False})"],"metadata":{"id":"iq7qqDfSmIdj"},"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"3kBrG51Lpuq9"},"source":["print(f\"max-Sensitivity scores by Yeh et al., 2019\\n\" \\\n"," f\"\\n • Saliency = 
{np.mean(scores_saliency):.2f} ({np.std(scores_saliency):.2f}).\" \\\n"," f\"\\n • Integrated Gradients = {np.mean(scores_intgrad):.2f} ({np.std(scores_intgrad):.2f}).\"\n"," )"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"uGYHGu1Esya7"},"source":["# Alternatively, run the metrics for several explanation methods at once with quantus.evaluate.\n","metrics = {\"max-Sensitivity\": quantus.MaxSensitivity(**params_eval)}\n","\n","xai_methods = {\"Saliency\": a_batch_saliency,\n","               \"IntegratedGradients\": a_batch_intgrad}\n","\n","results = quantus.evaluate(evaluation_metrics=metrics,\n","                           explanation_methods=xai_methods,\n","                           model=model,\n","                           x_batch=x_batch,\n","                           y_batch=y_batch,\n","                           agg_func=np.mean,\n","                           **{\"explain_func\": quantus.explain, \"device\": device, \"img_size\": 28, \"normalise\": False, \"abs\": False})\n","\n","results_df = pd.DataFrame(results)\n","results_df"],"execution_count":null,"outputs":[]}]}