{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":2},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython2","version":"2.7.6"},"colab":{"name":"tutorial_sensivitivty_parameterisation.ipynb","provenance":[{"file_id":"1nMiFlKaXP_F5rbeFwc35BL4SwO2ktnEs","timestamp":1627483591348},{"file_id":"14H0YjeULfWNdvzhVJbiScA0sUcLhGmY0","timestamp":1627472772689},{"file_id":"1tlXprNNO0PAcdH4Wh5ThFWgknpSQTkmd","timestamp":1626956510892}],"collapsed_sections":[],"toc_visible":true},"accelerator":"GPU","widgets":{"application/vnd.jupyter.widget-state+json":{"80e9f6d45d5349e68f8c1dea5d30044e":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_view_name":"HBoxView","_dom_classes":[],"_model_name":"HBoxModel","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.5.0","box_style":"","layout":"IPY_MODEL_bf620515a66c45649c422cd7126b730f","_model_module":"@jupyter-widgets/controls","children":["IPY_MODEL_1db49d8b301f41eda1d14a00a6fb6f9d","IPY_MODEL_05ee309130bb4801bf8dd1168d4b58ec","IPY_MODEL_bdb6932e534c4a4d93475c2c565bd73a"]}},"bf620515a66c45649c422cd7126b730f":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","jus
tify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"1db49d8b301f41eda1d14a00a6fb6f9d":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_d814e2faef0641e09ab90548a81eaeef","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":"","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_b22474e0097840b7803ae87ae05c4a00"}},"05ee309130bb4801bf8dd1168d4b58ec":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_view_name":"ProgressView","style":"IPY_MODEL_c0f8e7898bac4d42b621fc090999d6c7","_dom_classes":[],"description":"","_model_name":"FloatProgressModel","bar_style":"success","max":170498071,"_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":170498071,"_view_count":null,"_view_module_version":"1.5.0","orientation":"horizontal","min":0,"description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_fe15ea416eb04db2938ccbc05b595acd"}},"bdb6932e534c4a4d93475c2c565bd73a":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_e9036661f86e4d4bac3635d6f5174cb4","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":" 170499072/? 
[00:12<00:00, 18669248.59it/s]","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_939c5fa81c134e309178d90f15c37564"}},"d814e2faef0641e09ab90548a81eaeef":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"b22474e0097840b7803ae87ae05c4a00":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"c0f8e7898bac4d42b621fc090999d6c7":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"ProgressStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_vers
ion":"1.2.0","bar_color":null,"_model_module":"@jupyter-widgets/controls"}},"fe15ea416eb04db2938ccbc05b595acd":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"e9036661f86e4d4bac3635d6f5174cb4":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"939c5fa81c134e309178d90f15c37564":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":n
ull,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}}}}},"cells":[{"cell_type":"markdown","metadata":{"id":"bukochgPFg7s"},"source":["## Tutorial - Measuring sensitivity of hyperparameter choice\n","\n","This tutorial demonstrates how one can use the library to measure to what extent the outcome of evaluation is sensitive to the choice of hyperparameters e.g., choice of baseline value to mask an image with, patch sizes or number of runs. We use a LeNet model and CIFAR-10 dataset to showcase the library's functionality and test the Faithfulness Correlation by Bhatt et al., 2020..\n","\n","- Make sure to have GPUs enabled to speed up computation.\n","- Skip running the first cell if you do not use Google Colab.\n","\n"]},{"cell_type":"code","metadata":{"id":"52Q6cyeSS1ZG"},"source":["# Mount Google Drive.\n","from google.colab import drive\n","drive.mount('/content/drive', force_remount=True)\n","\n","# Install packages.\n","from IPython.display import clear_output\n","!pip install captum opencv-python xmltodict\n","!pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html\n","clear_output()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"MuA8v9jLp-pf"},"source":["# Imports general.\n","import sys\n","import gc\n","import warnings\n","import pathlib\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import torch\n","import torchvision\n","from 
torchvision import transforms\n","import captum\n","from captum.attr import *\n","import random\n","import os\n","import cv2\n","\n","# Import package.\n","path = \"/content/drive/MyDrive/Projects\"\n","sys.path.append(f'{path}/quantus')\n","import quantus\n","\n","# Collect garbage.\n","gc.collect()\n","torch.cuda.empty_cache()\n","\n","# Notebook settings.\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\") \n","warnings.filterwarnings(\"ignore\", category=UserWarning)\n","%load_ext autoreload\n","%autoreload 2\n","clear_output()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rvHznz40r2lw"},"source":["## 1. Preliminaries"]},{"cell_type":"markdown","metadata":{"id":"mB2QuiaDlu7w"},"source":["### 1.1 Load CIFAR10 dataset"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":115,"referenced_widgets":["80e9f6d45d5349e68f8c1dea5d30044e","bf620515a66c45649c422cd7126b730f","1db49d8b301f41eda1d14a00a6fb6f9d","05ee309130bb4801bf8dd1168d4b58ec","bdb6932e534c4a4d93475c2c565bd73a","d814e2faef0641e09ab90548a81eaeef","b22474e0097840b7803ae87ae05c4a00","c0f8e7898bac4d42b621fc090999d6c7","fe15ea416eb04db2938ccbc05b595acd","e9036661f86e4d4bac3635d6f5174cb4","939c5fa81c134e309178d90f15c37564"]},"id":"PZ6VyL7x26Ue","executionInfo":{"status":"ok","timestamp":1637240855177,"user_tz":-60,"elapsed":21153,"user":{"displayName":"Anna Hedström","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhbfsluHeZ1mzN6Bsf-1zU62lYHcz183jYjeS63=s64","userId":"05540180366077551505"}},"outputId":"2277b955-ab95-49d0-f58a-4a31ffc7bdfb"},"source":["# Load datasets and make loaders.\n","test_samples = 200\n","transformer = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n","train_set = torchvision.datasets.CIFAR10(root='./sample_data', train=True, transform=transformer, download=True)\n","test_set = torchvision.datasets.CIFAR10(root='./sample_data', 
train=False, transform=transformer, download=True)\n","train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, pin_memory=True) # num_workers=4,\n","test_loader = torch.utils.data.DataLoader(test_set, batch_size=200, pin_memory=True)\n","\n","# Specify class labels.\n","classes = {0: 'plane', 1: 'car', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}\n","\n","# Load a batch of inputs and outputs to use for evaluation.\n","x_batch, y_batch = iter(test_loader).next()\n","x_batch, y_batch = x_batch.to(device), y_batch.to(device)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./sample_data/cifar-10-python.tar.gz\n","Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./sample_data/cifar-10-python.tar.gz\n"]},{"output_type":"display_data","data":{"application/vnd.jupyter.widget-view+json":{"model_id":"80e9f6d45d5349e68f8c1dea5d30044e","version_minor":0,"version_major":2},"text/plain":[" 0%| | 0/170498071 [00:00 None:\n"," \"\"\"Plot some images.\"\"\"\n"," fig = plt.figure(figsize=(10, 10))\n"," img = images / 2 + 0.5 \n"," plt.imshow(np.transpose(img.cpu().numpy(), (1, 2, 0)))\n"," plt.axis(\"off\")\n"," plt.show()\n","\n","# Plot image examples!\n","plot_images(torchvision.utils.make_grid(x_batch[:6, :, :, :]))"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"YKwN9uWs29sn"},"source":["### 1.2 Train a LeNet model\n","\n","(or any other model of choice). Network architecture and training procedure is partly copied from: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ypxtHPfolQPD","executionInfo":{"status":"ok","timestamp":1637240855180,"user_tz":-60,"elapsed":12,"user":{"displayName":"Anna Hedström","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhbfsluHeZ1mzN6Bsf-1zU62lYHcz183jYjeS63=s64","userId":"05540180366077551505"}},"outputId":"d350e6a7-186c-4ad3-bcb9-37adf3c790b0"},"source":["class Net(torch.nn.Module):\n"," def __init__(self):\n"," super(Net, self).__init__()\n"," self.conv_1 = torch.nn.Conv2d(3, 6, 5)\n"," self.pool_1 = torch.nn.MaxPool2d(2, 2)\n"," self.pool_2 = torch.nn.MaxPool2d(2, 2)\n"," self.conv_2 = torch.nn.Conv2d(6, 16, 5)\n"," self.fc_1 = torch.nn.Linear(16 * 5 * 5, 120)\n"," self.fc_2 = torch.nn.Linear(120, 84)\n"," self.fc_3 = torch.nn.Linear(84, 10)\n"," self.relu_1 = torch.nn.ReLU()\n"," self.relu_2 = torch.nn.ReLU()\n"," self.relu_3 = torch.nn.ReLU()\n"," self.relu_4 = torch.nn.ReLU()\n","\n"," def forward(self, x):\n"," x = self.pool_1(self.relu_1(self.conv_1(x)))\n"," x = self.pool_2(self.relu_2(self.conv_2(x)))\n"," x = x.view(-1, 16 * 5 * 5)\n"," x = self.relu_3(self.fc_1(x))\n"," x = self.relu_4(self.fc_2(x))\n"," x = self.fc_3(x)\n"," return x\n","\n","\n","# Load model architecture.\n","model = Net()\n","print(f\"\\n Model architecture: {model.eval()}\\n\")"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["\n"," Model architecture: Net(\n"," (conv_1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))\n"," (pool_1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n"," (pool_2): 
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n"," (conv_2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n"," (fc_1): Linear(in_features=400, out_features=120, bias=True)\n"," (fc_2): Linear(in_features=120, out_features=84, bias=True)\n"," (fc_3): Linear(in_features=84, out_features=10, bias=True)\n"," (relu_1): ReLU()\n"," (relu_2): ReLU()\n"," (relu_3): ReLU()\n"," (relu_4): ReLU()\n",")\n","\n"]}]},{"cell_type":"code","metadata":{"id":"8mP55MfxuSZh"},"source":["def train_model(model, \n"," train_data: torchvision.datasets,\n"," test_data: torchvision.datasets, \n"," device: torch.device, \n"," epochs: int = 20,\n"," criterion: torch.nn = torch.nn.CrossEntropyLoss(), \n"," optimizer: torch.optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9), \n"," evaluate: bool = False):\n"," \"\"\"Train torch model.\"\"\"\n"," \n"," model.train()\n"," \n"," for epoch in range(epochs):\n","\n"," for images, labels in train_data:\n"," images, labels = images.to(device), labels.to(device)\n"," \n"," optimizer.zero_grad()\n"," \n"," logits = model(images)\n"," loss = criterion(logits, labels)\n"," loss.backward()\n"," optimizer.step()\n","\n"," # Evaluate model!\n"," if evaluate:\n"," predictions, labels = evaluate_model(model, test_data, device)\n"," test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())\n"," \n"," print(f\"Epoch {epoch+1}/{epochs} - test accuracy: {(100 * test_acc):.2f}% and CE loss {loss.item():.2f}\")\n","\n"," return model\n","\n","def evaluate_model(model, data, device):\n"," \"\"\"Evaluate torch model.\"\"\"\n"," model.eval()\n"," logits = torch.Tensor().to(device)\n"," targets = torch.LongTensor().to(device)\n","\n"," with torch.no_grad():\n"," for images, labels in data:\n"," images, labels = images.to(device), labels.to(device)\n"," logits = torch.cat([logits, model(images)])\n"," targets = torch.cat([targets, labels])\n"," \n"," return torch.nn.functional.softmax(logits, 
dim=1), targets"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"lbfAkSEtmGym","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1637240857985,"user_tz":-60,"elapsed":2814,"user":{"displayName":"Anna Hedström","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhbfsluHeZ1mzN6Bsf-1zU62lYHcz183jYjeS63=s64","userId":"05540180366077551505"}},"outputId":"5e783e7d-a132-46d0-a1fa-85149b863bc2"},"source":["path_model_weights = path + \"/quantus/tutorials/assets/cifar10\"\n","\n","if pathlib.Path(path_model_weights).is_file():\n"," model.load_state_dict(torch.load(path_model_weights))\n"," \n","else:\n","\n"," # Train and evaluate model.\n"," model = train_model(model=model.to(device),\n"," train_data=train_loader,\n"," test_data=test_loader,\n"," device=device,\n"," epochs=20,\n"," criterion=torch.nn.CrossEntropyLoss().to(device),\n"," optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),\n"," evaluate=True)\n","\n"," # Save model.\n"," torch.save(model.state_dict(), path_model_weights)\n","\n","# Model to GPU and eval mode.\n","model.to(device)\n","model.eval()\n","\n","# Check test set performance.\n","predictions, labels = evaluate_model(model, test_loader, device)\n","test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy()) \n","print(f\"Model test accuracy: {(100 * test_acc):.2f}%\")"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Model test accuracy: 59.51%\n"]}]},{"cell_type":"markdown","metadata":{"id":"XqKzag4VFjHT"},"source":["### 1.3 Load gradient-based attributions"]},{"cell_type":"code","metadata":{"id":"uSUkm-d6-p20","colab":{"base_uri":"https://localhost:8080/","height":505},"executionInfo":{"status":"ok","timestamp":1637240858622,"user_tz":-60,"elapsed":645,"user":{"displayName":"Anna 
Hedström","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhbfsluHeZ1mzN6Bsf-1zU62lYHcz183jYjeS63=s64","userId":"05540180366077551505"}},"outputId":"0e31429d-8410-4f89-c57c-84fdd140d528"},"source":["# Load some attributions and plot them. \n","a_batch = quantus.explain(model, \n"," x_batch, \n"," y_batch, \n"," method=\"IntegratedGradients\",\n"," **{\"normalize\": True})\n","\n","# Plot examplary inputs!\n","nr_images = 3\n","fig, axes = plt.subplots(nrows=nr_images, ncols=2, figsize=(nr_images*1.5, int(nr_images*3)))\n","for i in range(nr_images):\n"," axes[i, 0].imshow(np.moveaxis(np.clip(x_batch[i].cpu().numpy(), 0, 1), 0, -1), \n"," vmin=0.0, vmax=1.0)\n"," axes[i, 0].title.set_text(f\"CIFAR-10 - {classes[y_batch[i].item()]}\")\n"," axes[i, 0].axis(\"off\")\n"," axes[i, 1].imshow(a_batch[i], cmap=\"seismic\")\n"," axes[i, 1].title.set_text(f\"IG_norm=[0, 1]\")\n"," axes[i, 1].axis(\"off\")\n","plt.show()"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"3zPsqB-OsAjf"},"source":["## 2. Quantitative evaluation using Quantus"]},{"cell_type":"markdown","metadata":{"id":"Vrl4GiojK6r-"},"source":["### 2.1 Measure sensitivity of hyperparameter choice\n","\n","We want to understand how sensitive the evaluation outcome of Faithfulness Correlation (Bhatt et al., 2020) is to its hyperparameters."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"sD5STAScsp_3","executionInfo":{"status":"ok","timestamp":1637240860152,"user_tz":-60,"elapsed":1536,"user":{"displayName":"Anna Hedström","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhbfsluHeZ1mzN6Bsf-1zU62lYHcz183jYjeS63=s64","userId":"05540180366077551505"}},"outputId":"7a44e681-4289-4063-ecfd-00c3209ae4a6"},"source":["# Let's list the default parameters of the metric.\n","metric = quantus.FaithfulnessCorrelation()"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["WARNINGS.\n","\n","The Faithfulness correlation metric is likely to be sensitive to the choice of baseline value 'perturb_baseline', size of subset |S| 'subset_size' and the number of runs (for each input and explanation pair) 'nr_runs'. \n","Go over and select each hyperparameter of the metric carefully to avoid misinterpretation of scores. \n","To view all relevant hyperparameters call .get_params of the metric instance. \n","For further reading: Bhatt, Umang, Adrian Weller, and José MF Moura. 'Evaluating and aggregating feature-based model explanations.'\u0020
arXiv preprint arXiv:2005.00631 (2020).\u001b[0m\n","Normalising attributions may destroy or skew information in the explanation and as a result, affect the overall evaluation outcome.\n","\n"]}]},
{"cell_type":"code","metadata":{"id":"YCiGXKSVuP3L"},"source":[
"# Compute Saliency, Occlusion, Integrated Gradients and GradientShap explanations, each summed over axis 1 (the colour channels).\n",
"# NOTE: the last line converts x_batch/y_batch to numpy, so re-run the data-loading cell before re-running this one.\n",
"a_batch = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1).cpu().numpy()\n",
"a_batch_occ = Occlusion(model).attribute(inputs=x_batch, target=y_batch, sliding_window_shapes=(1, 4, 4)).sum(axis=1).cpu().numpy()\n",
"a_batch_ig = IntegratedGradients(model.to(device)).attribute(\n",
"    inputs=x_batch,\n",
"    target=y_batch,\n",
"    baselines=torch.zeros_like(x_batch),\n",
"    n_steps=10,\n",
"    method=\"riemann_trapezoid\").sum(axis=1).cpu().numpy()\n",
"a_batch_gh = GradientShap(model).attribute(\n",
"    inputs=x_batch,\n",
"    target=y_batch,\n",
"    baselines=torch.zeros_like(x_batch)).sum(axis=1).cpu().data.numpy()\n",
"\n",
"# Metric class expects numpy arrays.\n",
"x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()"
],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"id":"JSJ-n7xcVDcd"},"source":[
"# Define some parameter settings to evaluate.\n",
"baseline_strategies = [\"mean\", \"random\", \"uniform\", \"black\", \"white\"]\n",
"subset_sizes = np.array([2, 52, 102])\n",
"iterations = [100, 200]\n",
"absolutes = [True, False]\n",
"normalisations = [True, False]\n",
"sim_funcs = {\"pearson\": quantus.correlation_pearson, \"spearman\": quantus.correlation_spearman}"
],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"id":"pFEineha6yIU"},"source":[
"result = {\n",
"    \"Normalise\": [],\n",
"    \"Absolute\": [],\n",
"    \"Similarity function\": [],\n",
"    \"Faithfulness score\": [],\n",
"    \"Baseline strategy\": [],\n",
"    \"Subset size\": [],\n",
"    \"Method\": [],\n",
"    \"Iterations\": [],\n",
"}\n",
"methods = {\"Saliency\": a_batch, \"Occlusion\": a_batch_occ, \"Integrated Gradients\": a_batch_ig, \"GradShap\":\u0020
a_batch_gh}"],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"id":"gI3ahcaJNrUQ"},"source":["#!ls drive/MyDrive/Projects/quantus/tutorials/assets/data"],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"VcnLEflMNH9T","executionInfo":{"status":"ok","timestamp":1637241914247,"user_tz":-60,"elapsed":430,"user":{"displayName":"Anna Hedström","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhbfsluHeZ1mzN6Bsf-1zU62lYHcz183jYjeS63=s64","userId":"05540180366077551505"}},"outputId":"92312ceb-2bb1-40a9-bd54-ccd464e05715"},"source":[
"path_sensitivity_results = \"sensitivity_results_200_extra_extra.csv\" #\"drive/MyDrive/Projects/quantus/tutorials/assets/data/sensitivity_results.csv\"\n",
"\n",
"if pathlib.Path(path_sensitivity_results).is_file():\n",
"    df = pd.read_csv(path_sensitivity_results)\n",
"len(df)"
],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["2160"]},"metadata":{},"execution_count":39}]},
{"cell_type":"code","metadata":{"id":"UhQWT0TuJKSo"},"source":[
"path_sensitivity_results = path + \"/quantus/tutorials/assets/data/sensitivity_results.csv\"\n",
"\n",
"if pathlib.Path(path_sensitivity_results).is_file():\n",
"    df = pd.read_csv(path_sensitivity_results)\n",
"\n",
"else:\n",
"\n",
"    # Score explanations for every combination of hyperparameters!\n",
"    for b in baseline_strategies:\n",
"        for s in subset_sizes:\n",
"            for nr in iterations:\n",
"                for normalise in normalisations:\n",
"                    for absolute in absolutes:\n",
"                        for method, attr in methods.items():\n",
"                            for sim, sim_func in sim_funcs.items():\n",
"\n",
"                                # Hyperparameters are passed as keyword arguments (NOTE(review): confirm against the installed quantus version).\n",
"                                metric_params = {\n",
"                                    'abs': absolute,\n",
"                                    'normalize': normalise,\n",
"                                    'normalize_func': quantus.normalize_by_max,\n",
"                                    'nr_runs': nr,\n",
"                                    'perturb_baseline': b,\n",
"                                    'perturb_func': quantus.baseline_replacement_by_indices,\n",
"                                    'similarity_func': sim_func,\n",
"                                    'subset_size': s,\n",
"                                }\n",
"                                faithfulness_metric = quantus.FaithfulnessCorrelation(**metric_params)\n",
"                                # Use the configured device rather than .cuda() so the cell also runs on CPU.\n",
"                                scores = faithfulness_metric(\n",
"                                    model=model.to(device), x_batch=x_batch, y_batch=y_batch, a_batch=attr, **{\"device\": device})\n",
"\n",
"                                # One row per setting: every list in `result` must grow in lockstep for pd.DataFrame to accept it.\n",
"                                result[\"Normalise\"].append(normalise)\n",
"                                result[\"Absolute\"].append(absolute)\n",
"                                result[\"Method\"].append(method)\n",
"                                result[\"Baseline strategy\"].append(b.capitalize())\n",
"                                result[\"Subset size\"].append(s)\n",
"                                result[\"Iterations\"].append(nr)\n",
"                                result[\"Faithfulness score\"].append(np.mean(scores))\n",
"                                result[\"Similarity function\"].append(sim)\n",
"\n",
"    # Rank methods against each other within each identical hyperparameter setting.\n",
"    df = pd.DataFrame(result)\n",
"    df[\"Rank\"] = df.groupby(['Normalise', 'Absolute', 'Baseline strategy', 'Subset size', 'Iterations', 'Similarity function'])[\"Faithfulness score\"].rank()\n",
"    df.to_csv(path_sensitivity_results)\n",
"\n",
"# Smaller adjustments.\n",
"df = df.loc[:, ~df.columns.str.contains('^Unnamed')]\n",
"df.replace(to_replace=\"Integrated Gradients\", value=\"Integrated\\nGradients\", inplace=True)\n",
"# NOTE(review): as written this maps \"GS\" -> \"GradShap\"; swap the arguments if the intent is to abbreviate \"GradShap\" to \"GS\".\n",
"df.replace(value=\"GradShap\", to_replace=\"GS\", inplace=True)\n",
"df.columns = map(lambda x: str(x).capitalize(), df.columns)"
],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":412},"id":"CKUC7vWENMgW","executionInfo":{"status":"ok","timestamp":1637241979258,"user_tz":-60,"elapsed":323,"user":{"displayName":"Anna Hedström","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhbfsluHeZ1mzN6Bsf-1zU62lYHcz183jYjeS63=s64","userId":"05540180366077551505"}},"outputId":"f4b40284-ed03-490a-b877-d9979a1d5cde"},"source":[
"path_sensitivity_results = \"sensitivity_results_200_extra_extra.csv\" #\"drive/MyDrive/Projects/quantus/tutorials/assets/data/sensitivity_results.csv\"\n",
"\n",
"if pathlib.Path(path_sensitivity_results).is_file():\n",
"    df = pd.read_csv(path_sensitivity_results)\n",
"\n",
"\n",
"# Smaller adjustments.\n",
"df = df.loc[:, ~df.columns.str.contains('^Unnamed')]\n",
"df.replace(to_replace=\"Integrated Gradients\", value=\"Integrated\\nGradients\", inplace=True)\n",
"df.replace(value=\"GradShap\", to_replace=\"GS\", inplace=True)\n",
"df.columns = map(lambda x: str(x).capitalize(), df.columns)\n",
"\n",
"df[\"Rank_2\"] = df.groupby(['Baseline strategy', 'Subset size', 'Iterations', 'Similarity\u0020
function'])[\"Faithfulness score\"].rank()\n","df"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
NormaliseAbsoluteFaithfulness scoreBaseline strategySubset sizeMethodIterationsSimilarity functionRankRank_2
0TrueTrue0.029557Mean2Saliency100pearson3.07.0
1TrueTrue0.018670Mean2Integrated\\nGradients100pearson2.03.0
2TrueTrue0.024692Mean2Saliency100spearman3.08.0
3TrueFalse0.021572Mean2Integrated\\nGradients100pearson3.05.0
4TrueFalse-0.063219Mean2Saliency100spearman1.01.0
.................................
2155FalseFalse0.048248White142Occlusion500pearson3.09.0
2156FalseFalse0.023836White142Integrated\\nGradients500pearson2.04.0
2157FalseFalse-0.019863White142Saliency500spearman1.01.0
2158FalseFalse-0.016180White142Occlusion500spearman2.02.0
2159FalseFalse0.041108White142Integrated\\nGradients500spearman3.09.0
\n","

2160 rows × 10 columns

\n","
"],"text/plain":[" Normalise Absolute Faithfulness score ... Similarity function Rank Rank_2\n","0 True True 0.029557 ... pearson 3.0 7.0\n","1 True True 0.018670 ... pearson 2.0 3.0\n","2 True True 0.024692 ... spearman 3.0 8.0\n","3 True False 0.021572 ... pearson 3.0 5.0\n","4 True False -0.063219 ... spearman 1.0 1.0\n","... ... ... ... ... ... ... ...\n","2155 False False 0.048248 ... pearson 3.0 9.0\n","2156 False False 0.023836 ... pearson 2.0 4.0\n","2157 False False -0.019863 ... spearman 1.0 1.0\n","2158 False False -0.016180 ... spearman 2.0 2.0\n","2159 False False 0.041108 ... spearman 3.0 9.0\n","\n","[2160 rows x 10 columns]"]},"metadata":{},"execution_count":44}]},{"cell_type":"code","metadata":{"id":"QXdhXGTzUfZf","colab":{"base_uri":"https://localhost:8080/","height":442},"executionInfo":{"status":"ok","timestamp":1637241996406,"user_tz":-60,"elapsed":597,"user":{"displayName":"Anna Hedström","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhbfsluHeZ1mzN6Bsf-1zU62lYHcz183jYjeS63=s64","userId":"05540180366077551505"}},"outputId":"a88843ba-f9db-4156-b771-646661346b3c"},"source":["df.groupby(['Baseline strategy', 'Subset size', 'Iterations', 'Similarity function']).mean()"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
NormaliseAbsoluteFaithfulness scoreRankRank_2
Baseline strategySubset sizeIterationsSimilarity function
Black2100pearson0.50.50.0248052.06.5
spearman0.50.50.0349202.06.5
250pearson0.50.50.0318582.06.5
spearman0.50.50.0239862.06.5
500pearson0.50.50.0440352.06.5
...........................
White142100spearman0.50.50.0189242.06.5
250pearson0.50.50.0353142.06.5
spearman0.50.50.0224082.06.5
500pearson0.50.50.0365992.06.5
spearman0.50.50.0305582.06.5
\n","

180 rows × 5 columns

\n","
"],"text/plain":[" Normalise ... Rank_2\n","Baseline strategy Subset size Iterations Similarity function ... \n","Black 2 100 pearson 0.5 ... 6.5\n"," spearman 0.5 ... 6.5\n"," 250 pearson 0.5 ... 6.5\n"," spearman 0.5 ... 6.5\n"," 500 pearson 0.5 ... 6.5\n","... ... ... ...\n","White 142 100 spearman 0.5 ... 6.5\n"," 250 pearson 0.5 ... 6.5\n"," spearman 0.5 ... 6.5\n"," 500 pearson 0.5 ... 6.5\n"," spearman 0.5 ... 6.5\n","\n","[180 rows x 5 columns]"]},"metadata":{},"execution_count":45}]},{"cell_type":"code","metadata":{"id":"wbaRUinzfLcj"},"source":["#df[\"Rank\"] = df.groupby(['Baseline strategy', 'Subset size', 'Iterations', 'Similarity function'])[\"Faithfulness score\"].rank()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Etuxl5ytg1b3"},"source":["# Convert to datafame and rank. \n","df = pd.DataFrame(result)\n","df.to_csv(path + \"/quantus/tutorials/assets/data/sensitivity_results.csv\")\n","df[\"Rank\"] = df.groupby(['Baseline strategy', 'Subset size', 'Iterations', 'Similarity function']).rank()\n","\n","# Write to disk and re-open.\n","df.to_csv(path + \"/quantus/tutorials/assets/data/sensitivity_results.csv\")\n","df = pd.read_csv(path + \"/quantus/tutorials/assets/data/sensitivity_results.csv\")\n","\n","# Smaller adjustments.\n","df = df.loc[:, ~df.columns.str.contains('^Unnamed')]\n","df.replace(to_replace=\"Integrated Gradients\", value=\"Integrated\\nGradients\", inplace=True)\n","df.replace(value=\"GradShap\", to_replace=\"GS\", inplace=True)\n","df.columns = map(lambda x: str(x).capitalize(), df.columns)\n","df"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":495},"id":"CtDwzwzpfPCh","executionInfo":{"status":"ok","timestamp":1637229590011,"user_tz":-60,"elapsed":366,"user":{"displayName":"Anna 
Hedström","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhbfsluHeZ1mzN6Bsf-1zU62lYHcz183jYjeS63=s64","userId":"05540180366077551505"}},"outputId":"c568be87-01c9-43a2-a559-378b3420aec1"},"source":[""],"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.7/dist-packages/pandas/core/frame.py:4389: SettingWithCopyWarning: \n","A value is trying to be set on a copy of a slice from a DataFrame\n","\n","See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n"," method=method,\n"]},{"output_type":"execute_result","data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
NormaliseAbsoluteFaithfulness scoreBaseline strategySubset sizeMethodIterationsSimilarity functionRank
0TrueTrue0.029557Mean2Saliency100pearson7.0
1TrueTrue0.018670Mean2Integrated\\nGradients100pearson3.0
2TrueTrue0.024692Mean2Saliency100spearman8.0
3TrueFalse0.021572Mean2Integrated\\nGradients100pearson5.0
4TrueFalse-0.063219Mean2Saliency100spearman1.0
..............................
2155FalseFalse0.048248White142Occlusion500pearson9.0
2156FalseFalse0.023836White142Integrated\\nGradients500pearson4.0
2157FalseFalse-0.019863White142Saliency500spearman1.0
2158FalseFalse-0.016180White142Occlusion500spearman2.0
2159FalseFalse0.041108White142Integrated\\nGradients500spearman9.0
\n","

2160 rows × 9 columns

\n","
"],"text/plain":[" Normalise Absolute ... Similarity function Rank\n","0 True True ... pearson 7.0\n","1 True True ... pearson 3.0\n","2 True True ... spearman 8.0\n","3 True False ... pearson 5.0\n","4 True False ... spearman 1.0\n","... ... ... ... ... ...\n","2155 False False ... pearson 9.0\n","2156 False False ... pearson 4.0\n","2157 False False ... spearman 1.0\n","2158 False False ... spearman 2.0\n","2159 False False ... spearman 9.0\n","\n","[2160 rows x 9 columns]"]},"metadata":{},"execution_count":16}]},{"cell_type":"code","metadata":{"id":"Z_nKiL1Ra8W-"},"source":["# Group by rank\n","df_view = df.groupby([\"Method\"])[\"Rank\"].value_counts(normalize=True).mul(100).reset_index(name='Percentage').round(2)\n","df_view = df_view.append({'Method': 'Method A', 'Rank': 1.0, 'Percentage': 100}, ignore_index=True)\n","df_view = df_view.append({'Method': 'Method B', 'Rank': 2.0, 'Percentage': 100}, ignore_index=True)\n","df_view = df_view.append({'Method': 'Method C', 'Rank': 3.0, 'Percentage': 100}, ignore_index=True)\n","#df_view = df_view.append({'Method': 'Method D', 'Rank': 4.0, 'Percentage': 100}, ignore_index=True)\n","\n","# Reorder the methods for plotting purporses.\n","df_view_ordered = pd.DataFrame(columns=[\"Method\", \"Rank\", \"Percentage\"])\n","df_view_ordered = df_view_ordered.append({'Method': 'Method A', 'Rank': 1.0, 'Percentage': 100}, ignore_index=True)\n","df_view_ordered = df_view_ordered.append({'Method': 'Method B', 'Rank': 2.0, 'Percentage': 100}, ignore_index=True)\n","df_view_ordered = df_view_ordered.append({'Method': 'Method C', 'Rank': 3.0, 'Percentage': 100}, ignore_index=True)\n","df_view_ordered = df_view_ordered.append({'Method': 'Method D', 'Rank': 4.0, 'Percentage': 100}, ignore_index=True)\n","df_view_ordered = df_view_ordered.append([df_view.loc[df_view[\"Method\"] == 'Saliency']], ignore_index=True)\n","df_view_ordered = df_view_ordered.append([df_view.loc[df_view[\"Method\"] == 'Occlusion']], 
ignore_index=True)\n","df_view_ordered = df_view_ordered.append([df_view.loc[df_view[\"Method\"] == 'Integrated\\nGradients']], ignore_index=True)\n","df_view_ordered = df_view_ordered.append([df_view.loc[df_view[\"Method\"] == 'GradShap']], ignore_index=True)\n","df_view_ordered"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sl-xLLHInca-"},"source":["### 2.2 Plot results!"]},{"cell_type":"code","metadata":{"id":"f0-S_5iRnevU"},"source":["plt.style.use('seaborn-white')\n","sns.set(font_scale=1.5)\n","\n","path = \"drive/MyDrive/Projects/\""],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"F3kR9dNzhNi7"},"source":["# Plot 1!\n","\n","fig, ax = plt.subplots(figsize=(6.5,5))\n","ax = sns.histplot(x='Method', hue='Rank', weights='Percentage', multiple='stack', data=df_view_ordered, shrink=0.6, palette=\"colorblind\", legend=False)\n","ax.spines[\"right\"].set_visible(False)\n","ax.spines['top'].set_visible(False)\n","ax.tick_params(axis='both', which='major', labelsize=16)\n","ax.set_ylabel('Frequency of rank')\n","ax.set_xlabel('')\n","ax.set_xticklabels([\"A\", \"B\", \"C\", \"D\", \"SAL\", \"OCC\", \"IG\", \"GD\"])\n","ax.yaxis.set_major_formatter(mtick.PercentFormatter())\n","plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=4, fancybox=True, shadow=False, labels=['1st', \"2nd\", \"3rd\", \"4th\"])\n","plt.axvline(x=2.5, ymax=0.95, color='black', linestyle='-')\n","plt.tight_layout()\n","plt.savefig(f'{path}sensitivity_analysis_1.png', dpi = 400)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"qDYcEaQstVax"},"source":["# Plot 2!\n","\n","ax = sns.catplot(x=\"Baseline strategy\", y=\"Rank\", hue=\"Method\", kind=\"bar\", estimator=np.mean, hue_order=['SAL', 'OCC', 'IG', \"GD\"],\n"," data=df, palette=sns.color_palette(\"husl\", 3), legend=False, height=5, aspect=7.5/5.8)\n","plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=4, fancybox=True, 
shadow=False)\n","ax.set_ylabels('Mean Rank')\n","ax.set_xlabels('')\n","plt.tight_layout()\n","plt.savefig(f'{path}sensitivity_analysis_2.png', dpi = 400)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"mc8Rc6fyelnL"},"source":["# Plot 3!\n","df_subset = df.loc[(df[\"Iterations\"] == 100) & (df[\"Subset size\"] == 102) & (df[\"Similarity function\"] == \"spearman\")]\n","ax = sns.catplot(x=\"Baseline strategy\", y=\"Rank\", hue=\"Method\", kind=\"bar\", estimator=np.mean, hue_order=['Saliency', 'Occlusion', 'Integrated\\nGradients'],\n"," data=df_subset, palette=sns.color_palette(\"husl\", 3), legend=False) \n","ax.set_xlabels('')\n","ax.set_ylabels('')\n","plt.legend(handles=[\"\", \"\", \"\"], labels=[\"\", \"\", \"\"], loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=3, fancybox=False, shadow=False)\n","plt.tight_layout()\n","plt.show()"],"execution_count":null,"outputs":[]}]}