{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Tks9e_G7fzRX"
   },
   "source": [
    "# Model Interpretation Methods"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "0Ioclbd4f4jg"
   },
   "source": [
    "Welcome to the final assignment of Course 3! In this assignment we will focus on interpreting machine learning and deep learning models. Using the techniques we've learned this week, we'll revisit some of the models we've built throughout the course and try to understand a little more about what they're doing.\n",
    "\n",
    "In this assignment you'll use various methods to interpret different types of machine learning models. In particular, you'll learn about the following topics:\n",
    "\n",
    "- Interpreting Deep Learning Models\n",
    "    - Understanding output using GradCAMs\n",
    "- Feature Importance in Machine Learning\n",
    "    - Permutation Method\n",
    "    - SHAP Values\n",
    "\n",
    "Let's get started."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### This assignment covers the following topics:\n",
    "\n",
    "- [1. Interpreting Deep Learning Models](#1)\n",
    "  - [1.1 GradCAM](#1-1)\n",
    "    - [1.1.1 Getting Intermediate Layers](#1-1-1)\n",
    "    - [1.1.2 Getting Gradients](#1-1-2)\n",
    "    - [1.1.3 Implementing GradCAM](#1-1-3)\n",
    "      - [Exercise 1](#ex-01)\n",
    "    - [1.1.4 Using GradCAM to Visualize Multiple Labels](#1-1-4)\n",
    "      - [Exercise 2](#ex-02)\n",
    "- [2. Feature Importance in Machine Learning](#2)\n",
    "  - [2.1 Permutation Method for Feature Importance](#2-1)\n",
    "    - [2.1.1 Implementing Permutation](#2-1-1)\n",
    "      - [Exercise 3](#ex-03)\n",
    "    - [2.1.2 Implementing Importance](#2-1-2)\n",
    "      - [Exercise 4](#ex-04)\n",
    "    - [2.1.3 Computing our Feature Importance](#2-1-3)\n",
    "  - [2.2 Shapley Values for Random Forests](#2-2)\n",
    "    - [2.2.1 Visualizing Feature Importance on Specific Individuals](#2-2-1)\n",
    "    - [2.2.2 Visualizing Feature Importance on Aggregate](#2-2-2)\n",
    "    - [2.2.3 Visualizing Interactions between Features](#2-2-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "p4RvDHBOha2Y"
   },
   "source": [
    "## Packages\n",
    "\n",
    "We'll first import the necessary packages for this assignment.\n",
    "\n",
    "- `keras`: we'll use this framework to interact with our deep learning model\n",
    "- `matplotlib`: standard plotting library\n",
    "- `pandas`: we'll use this to manipulate data\n",
    "- `numpy`: standard Python library for numerical operations\n",
    "- `cv2`: library that contains convenience functions for image processing\n",
    "- `sklearn`: standard machine learning library\n",
    "- `lifelines`: we'll use its implementation of the concordance index (c-index)\n",
    "- `shap`: library for interpreting and visualizing machine learning models using Shapley values\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 80
    },
    "colab_type": "code",
    "id": "i9OcyAhSesQc",
    "outputId": "b32cbd74-6f95-476e-e0db-c5739e886d21"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "from keras import backend as K\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import cv2\n",
    "import sklearn\n",
    "import lifelines\n",
    "import shap\n",
    "\n",
    "\n",
    "from util import *\n",
    "\n",
    "# This sets a common size for all the figures we will draw.\n",
    "plt.rcParams['figure.figsize'] = [10, 7]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RaIViDj8khSg"
   },
   "source": [
    "<a name=\"1\"></a>\n",
    "## 1 Interpreting Deep Learning Models\n",
    "\n",
    "To start, let's try understanding our X-ray diagnostic model from Course 1 Week 1. Run the next cell to load in the model (it should take a few seconds to complete)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 632
    },
    "colab_type": "code",
    "id": "vrzRJFrXhi6x",
    "outputId": "1ec745f3-a764-436a-c1cb-74ed4e5c73ae"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Got loss weights\n",
      "Loaded DenseNet\n",
      "Added layers\n",
      "Compiled Model\n",
      "Loaded Weights\n"
     ]
    }
   ],
   "source": [
    "model = load_C3M3_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "B07VP7edyb98"
   },
   "source": [
    "Let's load in an X-ray image to develop on. Run the next cell to load and show the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 432
    },
    "colab_type": "code",
    "id": "cVVXgMweyGtz",
    "outputId": "37992056-b6ea-4316-a44d-6c98b7074423",
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "IMAGE_DIR = 'nih_new/images-small/'\n",
    "df = pd.read_csv(\"nih_new/train-small.csv\")\n",
    "im_path = IMAGE_DIR + '00025288_001.png'\n",
    "x = load_image(im_path, df, preprocess=False)\n",
    "plt.imshow(x, cmap='gray')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "1N1_YANDztAo"
   },
   "source": [
    "Next, let's get our predictions. Before we plug the image into our model, we have to normalize it. Run the next cell to compute the mean and standard deviation of the images in our training set. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "s5I91kMHxUiz"
   },
   "outputs": [],
   "source": [
    "mean, std = get_mean_std_per_batch(df)"
   ]
  },
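  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The normalization step can be sketched in plain NumPy. The implementations of `get_mean_std_per_batch` and `load_image_normalize` in `util` aren't shown here, so this is an illustration under stated assumptions and not a copy of those helpers. It assumes a single scalar mean and standard deviation computed over a stack of training images, and it assumes the model expects a leading batch axis. The function names `mean_std_from_images` and `normalize_image` are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def mean_std_from_images(images):\n",
    "    # Hypothetical helper: one scalar mean and std over every pixel of every image.\n",
    "    stack = np.stack([np.asarray(im, dtype=np.float64) for im in images])\n",
    "    return stack.mean(), stack.std()\n",
    "\n",
    "\n",
    "def normalize_image(x, mean, std):\n",
    "    # Hypothetical helper: standardize the pixels, then add a batch axis so that\n",
    "    # an (H, W, C) image becomes (1, H, W, C), a single-image batch.\n",
    "    x = (np.asarray(x, dtype=np.float64) - mean) / std\n",
    "    return np.expand_dims(x, axis=0)\n",
    "\n",
    "\n",
    "# Toy example: two constant 2x2 RGB 'images' with pixel values 0 and 2.\n",
    "images = [np.zeros((2, 2, 3)), np.full((2, 2, 3), 2.0)]\n",
    "toy_mean, toy_std = mean_std_from_images(images)  # 1.0 and 1.0\n",
    "batch = normalize_image(images[1], toy_mean, toy_std)\n",
    "print(batch.shape)  # (1, 2, 2, 3)\n",
    "```\n",
    "\n",
    "The toy statistics use new names (`toy_mean`, `toy_std`) so that copying the sketch does not overwrite the `mean` and `std` computed from the real training images above."
   ]
  },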
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "IGX2R05t6ZLA"
   },
   "source": [
    "Now we are ready to normalize and run the image through our model to get predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 529
    },
    "colab_type": "code",
    "id": "-UG6DAnUzxk0",
    "outputId": "13fedfec-3aed-4830-e435-827bb1de1a9d"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "labels = ['Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration', 'Mass', 'Nodule', 'Atelectasis',\n",
    "          'Pneumothorax', 'Pleural_Thickening', 'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation']\n",
    "\n",
    "processed_image = load_image_normalize(im_path, mean, std)\n",
    "preds = model.predict(processed_image)\n",
    "pred_df = pd.DataFrame(preds, columns=labels)\n",
    "pred_df.loc[0, :].plot.bar()\n",
    "plt.title(\"Predictions\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UpNVTzl6002K"
   },
   "source": [
    "We see, for example, that the model predicts Mass (an abnormal spot or area in the lungs larger than 3 centimeters) with high probability. Indeed, this patient was diagnosed with a mass. However, the prediction alone doesn't tell us which parts of the image the model relied on to make its diagnosis. To gain more insight into what the model is looking at, we can use GradCAMs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "iKeH6_eBDCho"
   },
   "source": [
    "<a name=\"1-1\"></a>\n",
    "### 1.1 GradCAM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
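    "GradCAM (Gradient-weighted Class Activation Mapping) is a technique for visualizing how much each region of an image influences a specific output of a convolutional neural network. It builds a heatmap from the gradients of the class score we are interested in with respect to the feature maps of a convolutional layer: the gradients are averaged over the spatial dimensions to give one weight per feature map, the weighted feature maps are summed, and a ReLU keeps only the regions that increase the class score."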
   ]
  },
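  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea concrete, here is a minimal NumPy sketch of the core GradCAM computation. It assumes you already have, for a single image, the activations `A` of one convolutional layer and the gradients `G` of the class score with respect to those activations, both with shape (H, W, K). Obtaining them from the Keras model is the subject of the subsections on intermediate layers and gradients. The function name `grad_cam_from_arrays` is hypothetical and the toy arrays are made up for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def grad_cam_from_arrays(A, G):\n",
    "    # A: activations of one conv layer, shape (H, W, K)\n",
    "    # G: gradients of the class score w.r.t. A, same shape (H, W, K)\n",
    "    weights = G.mean(axis=(0, 1))  # one weight per feature map (channel)\n",
    "    cam = np.tensordot(A, weights, axes=([2], [0]))  # weighted sum of maps -> (H, W)\n",
    "    cam = np.maximum(cam, 0.0)  # ReLU: keep only regions that raise the score\n",
    "    if cam.max() > 0:\n",
    "        cam = cam / cam.max()  # scale to [0, 1] for display\n",
    "    return cam\n",
    "\n",
    "\n",
    "# Toy example: channel 0 fires at the top-left, channel 1 at the bottom-right.\n",
    "A = np.zeros((4, 4, 2))\n",
    "A[0, 0, 0] = 1.0\n",
    "A[3, 3, 1] = 1.0\n",
    "G = np.zeros((4, 4, 2))\n",
    "G[..., 0] = 0.5   # the class score increases with channel 0\n",
    "G[..., 1] = -0.5  # and decreases with channel 1\n",
    "heatmap = grad_cam_from_arrays(A, G)\n",
    "print(heatmap[0, 0], heatmap[3, 3])  # 1.0 0.0\n",
    "```\n",
    "\n",
    "The heatmap has the spatial resolution of the chosen layer, which is much coarser than the X-ray, so in practice it is resized to the image size (for example with `cv2.resize`) before being overlaid on the image."
   ]
  },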
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "pbKwnYWRHmxa"
   },
   "source": [
    "<a name=\"1-1-1\"></a>\n",
    "#### 1.1.1 Getting Intermediate Layers\n",
    "\n",
    "Perhaps the most complicated part of computing GradCAM is accessing the intermediate activations of our deep learning model and computing the gradients of the class output with respect to those activations. Now we'll go over one pattern to accomplish this, which you can use when implementing GradCAM.\n",
    "\n",
    "To understand how to access intermediate layers in a computation, let's first look at the layers that make up our model. We can list them with the Keras convenience function `model.summary()`. Do this in the next code cell."
   ]
  },
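  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before running `model.summary()` in the next code cell, it helps to see the pattern in miniature. The idea is to split the network at the layer of interest. One function maps the input image to that layer's activation, and a second function maps the activation to the class score. The gradients GradCAM needs are then the gradients of the second function. In Keras the first half is typically obtained by building a model whose output is a named layer's output, e.g. `keras.models.Model(inputs=model.input, outputs=model.get_layer(name).output)`. The sketch below shows the same split with a hypothetical two-stage NumPy model (`features`, `head`). It checks the gradient with central finite differences instead of automatic differentiation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# Hypothetical toy network, split at the 'layer of interest'.\n",
    "W_feat = rng.normal(size=(3, 4))  # per-pixel map: 3 input channels -> K=4 feature maps\n",
    "w_head = rng.normal(size=4)  # linear classifier on pooled features\n",
    "\n",
    "\n",
    "def features(x):\n",
    "    # First half: input image (H, W, 3) -> intermediate activation (H, W, K).\n",
    "    return np.maximum(x @ W_feat, 0.0)\n",
    "\n",
    "\n",
    "def head(a):\n",
    "    # Second half: global average pooling, then a linear class score.\n",
    "    return a.mean(axis=(0, 1)) @ w_head\n",
    "\n",
    "\n",
    "def numerical_grad(f, a, eps=1e-6):\n",
    "    # Central finite differences of the scalar f(a) w.r.t. every entry of a.\n",
    "    g = np.zeros_like(a)\n",
    "    for idx in np.ndindex(a.shape):\n",
    "        a_plus, a_minus = a.copy(), a.copy()\n",
    "        a_plus[idx] += eps\n",
    "        a_minus[idx] -= eps\n",
    "        g[idx] = (f(a_plus) - f(a_minus)) / (2 * eps)\n",
    "    return g\n",
    "\n",
    "\n",
    "x = rng.normal(size=(5, 5, 3))  # toy 5x5 'image'\n",
    "a = features(x)  # intermediate activation we want to inspect\n",
    "score = head(a)  # class score computed from that activation\n",
    "grads = numerical_grad(head, a)\n",
    "\n",
    "# The head averages over 5*5 positions, so every gradient entry equals w_head / 25.\n",
    "print(np.allclose(grads, w_head / 25, atol=1e-6))  # True\n",
    "```\n",
    "\n",
    "In the real model the activation would come from one of the DenseNet layers listed by `model.summary()`. The gradient would come from the framework's automatic differentiation rather than finite differences, which would be far too slow for a network of this size."
   ]
  },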
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_1 (InputLayer)            (None, None, None, 3 0                                            \n",
      "__________________________________________________________________________________________________\n",
      "zero_padding2d_1 (ZeroPadding2D (None, None, None, 3 0           input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "conv1/conv (Conv2D)             (None, None, None, 6 9408        zero_padding2d_1[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv1/bn (BatchNormalization)   (None, None, None, 6 256         conv1/conv[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "conv1/relu (Activation)         (None, None, None, 6 0           conv1/bn[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "zero_padding2d_2 (ZeroPadding2D (None, None, None, 6 0           conv1/relu[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "pool1 (MaxPooling2D)            (None, None, None, 6 0           zero_padding2d_2[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_0_bn (BatchNormali (None, None, None, 6 256         pool1[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_0_relu (Activation (None, None, None, 6 0           conv2_block1_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_1_conv (Conv2D)    (None, None, None, 1 8192        conv2_block1_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_1_bn (BatchNormali (None, None, None, 1 512         conv2_block1_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_1_relu (Activation (None, None, None, 1 0           conv2_block1_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_2_conv (Conv2D)    (None, None, None, 3 36864       conv2_block1_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_concat (Concatenat (None, None, None, 9 0           pool1[0][0]                      \n",
      "                                                                 conv2_block1_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_0_bn (BatchNormali (None, None, None, 9 384         conv2_block1_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_0_relu (Activation (None, None, None, 9 0           conv2_block2_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_1_conv (Conv2D)    (None, None, None, 1 12288       conv2_block2_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_1_bn (BatchNormali (None, None, None, 1 512         conv2_block2_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_1_relu (Activation (None, None, None, 1 0           conv2_block2_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_2_conv (Conv2D)    (None, None, None, 3 36864       conv2_block2_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_concat (Concatenat (None, None, None, 1 0           conv2_block1_concat[0][0]        \n",
      "                                                                 conv2_block2_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_0_bn (BatchNormali (None, None, None, 1 512         conv2_block2_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_0_relu (Activation (None, None, None, 1 0           conv2_block3_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_1_conv (Conv2D)    (None, None, None, 1 16384       conv2_block3_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_1_bn (BatchNormali (None, None, None, 1 512         conv2_block3_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_1_relu (Activation (None, None, None, 1 0           conv2_block3_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_2_conv (Conv2D)    (None, None, None, 3 36864       conv2_block3_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_concat (Concatenat (None, None, None, 1 0           conv2_block2_concat[0][0]        \n",
      "                                                                 conv2_block3_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block4_0_bn (BatchNormali (None, None, None, 1 640         conv2_block3_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block4_0_relu (Activation (None, None, None, 1 0           conv2_block4_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block4_1_conv (Conv2D)    (None, None, None, 1 20480       conv2_block4_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block4_1_bn (BatchNormali (None, None, None, 1 512         conv2_block4_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block4_1_relu (Activation (None, None, None, 1 0           conv2_block4_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block4_2_conv (Conv2D)    (None, None, None, 3 36864       conv2_block4_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block4_concat (Concatenat (None, None, None, 1 0           conv2_block3_concat[0][0]        \n",
      "                                                                 conv2_block4_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block5_0_bn (BatchNormali (None, None, None, 1 768         conv2_block4_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block5_0_relu (Activation (None, None, None, 1 0           conv2_block5_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block5_1_conv (Conv2D)    (None, None, None, 1 24576       conv2_block5_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block5_1_bn (BatchNormali (None, None, None, 1 512         conv2_block5_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block5_1_relu (Activation (None, None, None, 1 0           conv2_block5_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block5_2_conv (Conv2D)    (None, None, None, 3 36864       conv2_block5_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block5_concat (Concatenat (None, None, None, 2 0           conv2_block4_concat[0][0]        \n",
      "                                                                 conv2_block5_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block6_0_bn (BatchNormali (None, None, None, 2 896         conv2_block5_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block6_0_relu (Activation (None, None, None, 2 0           conv2_block6_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block6_1_conv (Conv2D)    (None, None, None, 1 28672       conv2_block6_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block6_1_bn (BatchNormali (None, None, None, 1 512         conv2_block6_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block6_1_relu (Activation (None, None, None, 1 0           conv2_block6_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block6_2_conv (Conv2D)    (None, None, None, 3 36864       conv2_block6_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block6_concat (Concatenat (None, None, None, 2 0           conv2_block5_concat[0][0]        \n",
      "                                                                 conv2_block6_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "pool2_bn (BatchNormalization)   (None, None, None, 2 1024        conv2_block6_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "pool2_relu (Activation)         (None, None, None, 2 0           pool2_bn[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "pool2_conv (Conv2D)             (None, None, None, 1 32768       pool2_relu[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "pool2_pool (AveragePooling2D)   (None, None, None, 1 0           pool2_conv[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_0_bn (BatchNormali (None, None, None, 1 512         pool2_pool[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_0_relu (Activation (None, None, None, 1 0           conv3_block1_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_1_conv (Conv2D)    (None, None, None, 1 16384       conv3_block1_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_1_bn (BatchNormali (None, None, None, 1 512         conv3_block1_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_1_relu (Activation (None, None, None, 1 0           conv3_block1_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_2_conv (Conv2D)    (None, None, None, 3 36864       conv3_block1_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_concat (Concatenat (None, None, None, 1 0           pool2_pool[0][0]                 \n",
      "                                                                 conv3_block1_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_0_bn (BatchNormali (None, None, None, 1 640         conv3_block1_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_0_relu (Activation (None, None, None, 1 0           conv3_block2_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_1_conv (Conv2D)    (None, None, None, 1 20480       conv3_block2_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_1_bn (BatchNormali (None, None, None, 1 512         conv3_block2_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_1_relu (Activation (None, None, None, 1 0           conv3_block2_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_2_conv (Conv2D)    (None, None, None, 3 36864       conv3_block2_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_concat (Concatenat (None, None, None, 1 0           conv3_block1_concat[0][0]        \n",
      "                                                                 conv3_block2_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_0_bn (BatchNormali (None, None, None, 1 768         conv3_block2_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_0_relu (Activation (None, None, None, 1 0           conv3_block3_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_1_conv (Conv2D)    (None, None, None, 1 24576       conv3_block3_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_1_bn (BatchNormali (None, None, None, 1 512         conv3_block3_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_1_relu (Activation (None, None, None, 1 0           conv3_block3_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_2_conv (Conv2D)    (None, None, None, 3 36864       conv3_block3_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_concat (Concatenat (None, None, None, 2 0           conv3_block2_concat[0][0]        \n",
      "                                                                 conv3_block3_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_0_bn (BatchNormali (None, None, None, 2 896         conv3_block3_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_0_relu (Activation (None, None, None, 2 0           conv3_block4_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_1_conv (Conv2D)    (None, None, None, 1 28672       conv3_block4_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_1_bn (BatchNormali (None, None, None, 1 512         conv3_block4_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_1_relu (Activation (None, None, None, 1 0           conv3_block4_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_2_conv (Conv2D)    (None, None, None, 3 36864       conv3_block4_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_concat (Concatenat (None, None, None, 2 0           conv3_block3_concat[0][0]        \n",
      "                                                                 conv3_block4_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block5_0_bn (BatchNormali (None, None, None, 2 1024        conv3_block4_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block5_0_relu (Activation (None, None, None, 2 0           conv3_block5_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block5_1_conv (Conv2D)    (None, None, None, 1 32768       conv3_block5_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block5_1_bn (BatchNormali (None, None, None, 1 512         conv3_block5_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block5_1_relu (Activation (None, None, None, 1 0           conv3_block5_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block5_2_conv (Conv2D)    (None, None, None, 3 36864       conv3_block5_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block5_concat (Concatenat (None, None, None, 2 0           conv3_block4_concat[0][0]        \n",
      "                                                                 conv3_block5_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block6_0_bn (BatchNormali (None, None, None, 2 1152        conv3_block5_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block6_0_relu (Activation (None, None, None, 2 0           conv3_block6_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block6_1_conv (Conv2D)    (None, None, None, 1 36864       conv3_block6_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block6_1_bn (BatchNormali (None, None, None, 1 512         conv3_block6_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block6_1_relu (Activation (None, None, None, 1 0           conv3_block6_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block6_2_conv (Conv2D)    (None, None, None, 3 36864       conv3_block6_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block6_concat (Concatenat (None, None, None, 3 0           conv3_block5_concat[0][0]        \n",
      "                                                                 conv3_block6_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block7_0_bn (BatchNormali (None, None, None, 3 1280        conv3_block6_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block7_0_relu (Activation (None, None, None, 3 0           conv3_block7_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block7_1_conv (Conv2D)    (None, None, None, 1 40960       conv3_block7_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block7_1_bn (BatchNormali (None, None, None, 1 512         conv3_block7_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block7_1_relu (Activation (None, None, None, 1 0           conv3_block7_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block7_2_conv (Conv2D)    (None, None, None, 3 36864       conv3_block7_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block7_concat (Concatenat (None, None, None, 3 0           conv3_block6_concat[0][0]        \n",
      "                                                                 conv3_block7_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block8_0_bn (BatchNormali (None, None, None, 3 1408        conv3_block7_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block8_0_relu (Activation (None, None, None, 3 0           conv3_block8_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block8_1_conv (Conv2D)    (None, None, None, 1 45056       conv3_block8_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block8_1_bn (BatchNormali (None, None, None, 1 512         conv3_block8_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block8_1_relu (Activation (None, None, None, 1 0           conv3_block8_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block8_2_conv (Conv2D)    (None, None, None, 3 36864       conv3_block8_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block8_concat (Concatenat (None, None, None, 3 0           conv3_block7_concat[0][0]        \n",
      "                                                                 conv3_block8_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block9_0_bn (BatchNormali (None, None, None, 3 1536        conv3_block8_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block9_0_relu (Activation (None, None, None, 3 0           conv3_block9_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block9_1_conv (Conv2D)    (None, None, None, 1 49152       conv3_block9_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block9_1_bn (BatchNormali (None, None, None, 1 512         conv3_block9_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block9_1_relu (Activation (None, None, None, 1 0           conv3_block9_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block9_2_conv (Conv2D)    (None, None, None, 3 36864       conv3_block9_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block9_concat (Concatenat (None, None, None, 4 0           conv3_block8_concat[0][0]        \n",
      "                                                                 conv3_block9_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block10_0_bn (BatchNormal (None, None, None, 4 1664        conv3_block9_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block10_0_relu (Activatio (None, None, None, 4 0           conv3_block10_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block10_1_conv (Conv2D)   (None, None, None, 1 53248       conv3_block10_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block10_1_bn (BatchNormal (None, None, None, 1 512         conv3_block10_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block10_1_relu (Activatio (None, None, None, 1 0           conv3_block10_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block10_2_conv (Conv2D)   (None, None, None, 3 36864       conv3_block10_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block10_concat (Concatena (None, None, None, 4 0           conv3_block9_concat[0][0]        \n",
      "                                                                 conv3_block10_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block11_0_bn (BatchNormal (None, None, None, 4 1792        conv3_block10_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block11_0_relu (Activatio (None, None, None, 4 0           conv3_block11_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block11_1_conv (Conv2D)   (None, None, None, 1 57344       conv3_block11_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block11_1_bn (BatchNormal (None, None, None, 1 512         conv3_block11_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block11_1_relu (Activatio (None, None, None, 1 0           conv3_block11_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block11_2_conv (Conv2D)   (None, None, None, 3 36864       conv3_block11_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block11_concat (Concatena (None, None, None, 4 0           conv3_block10_concat[0][0]       \n",
      "                                                                 conv3_block11_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block12_0_bn (BatchNormal (None, None, None, 4 1920        conv3_block11_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block12_0_relu (Activatio (None, None, None, 4 0           conv3_block12_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block12_1_conv (Conv2D)   (None, None, None, 1 61440       conv3_block12_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block12_1_bn (BatchNormal (None, None, None, 1 512         conv3_block12_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block12_1_relu (Activatio (None, None, None, 1 0           conv3_block12_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block12_2_conv (Conv2D)   (None, None, None, 3 36864       conv3_block12_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block12_concat (Concatena (None, None, None, 5 0           conv3_block11_concat[0][0]       \n",
      "                                                                 conv3_block12_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "pool3_bn (BatchNormalization)   (None, None, None, 5 2048        conv3_block12_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "pool3_relu (Activation)         (None, None, None, 5 0           pool3_bn[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "pool3_conv (Conv2D)             (None, None, None, 2 131072      pool3_relu[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "pool3_pool (AveragePooling2D)   (None, None, None, 2 0           pool3_conv[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_0_bn (BatchNormali (None, None, None, 2 1024        pool3_pool[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_0_relu (Activation (None, None, None, 2 0           conv4_block1_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_1_conv (Conv2D)    (None, None, None, 1 32768       conv4_block1_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_1_bn (BatchNormali (None, None, None, 1 512         conv4_block1_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_1_relu (Activation (None, None, None, 1 0           conv4_block1_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_2_conv (Conv2D)    (None, None, None, 3 36864       conv4_block1_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_concat (Concatenat (None, None, None, 2 0           pool3_pool[0][0]                 \n",
      "                                                                 conv4_block1_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_0_bn (BatchNormali (None, None, None, 2 1152        conv4_block1_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_0_relu (Activation (None, None, None, 2 0           conv4_block2_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_1_conv (Conv2D)    (None, None, None, 1 36864       conv4_block2_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_1_bn (BatchNormali (None, None, None, 1 512         conv4_block2_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_1_relu (Activation (None, None, None, 1 0           conv4_block2_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_2_conv (Conv2D)    (None, None, None, 3 36864       conv4_block2_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_concat (Concatenat (None, None, None, 3 0           conv4_block1_concat[0][0]        \n",
      "                                                                 conv4_block2_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_0_bn (BatchNormali (None, None, None, 3 1280        conv4_block2_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_0_relu (Activation (None, None, None, 3 0           conv4_block3_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_1_conv (Conv2D)    (None, None, None, 1 40960       conv4_block3_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_1_bn (BatchNormali (None, None, None, 1 512         conv4_block3_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_1_relu (Activation (None, None, None, 1 0           conv4_block3_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_2_conv (Conv2D)    (None, None, None, 3 36864       conv4_block3_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_concat (Concatenat (None, None, None, 3 0           conv4_block2_concat[0][0]        \n",
      "                                                                 conv4_block3_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_0_bn (BatchNormali (None, None, None, 3 1408        conv4_block3_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_0_relu (Activation (None, None, None, 3 0           conv4_block4_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_1_conv (Conv2D)    (None, None, None, 1 45056       conv4_block4_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_1_bn (BatchNormali (None, None, None, 1 512         conv4_block4_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_1_relu (Activation (None, None, None, 1 0           conv4_block4_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_2_conv (Conv2D)    (None, None, None, 3 36864       conv4_block4_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_concat (Concatenat (None, None, None, 3 0           conv4_block3_concat[0][0]        \n",
      "                                                                 conv4_block4_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_0_bn (BatchNormali (None, None, None, 3 1536        conv4_block4_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_0_relu (Activation (None, None, None, 3 0           conv4_block5_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_1_conv (Conv2D)    (None, None, None, 1 49152       conv4_block5_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_1_bn (BatchNormali (None, None, None, 1 512         conv4_block5_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_1_relu (Activation (None, None, None, 1 0           conv4_block5_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_2_conv (Conv2D)    (None, None, None, 3 36864       conv4_block5_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_concat (Concatenat (None, None, None, 4 0           conv4_block4_concat[0][0]        \n",
      "                                                                 conv4_block5_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_0_bn (BatchNormali (None, None, None, 4 1664        conv4_block5_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_0_relu (Activation (None, None, None, 4 0           conv4_block6_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_1_conv (Conv2D)    (None, None, None, 1 53248       conv4_block6_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_1_bn (BatchNormali (None, None, None, 1 512         conv4_block6_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_1_relu (Activation (None, None, None, 1 0           conv4_block6_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_2_conv (Conv2D)    (None, None, None, 3 36864       conv4_block6_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_concat (Concatenat (None, None, None, 4 0           conv4_block5_concat[0][0]        \n",
      "                                                                 conv4_block6_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block7_0_bn (BatchNormali (None, None, None, 4 1792        conv4_block6_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block7_0_relu (Activation (None, None, None, 4 0           conv4_block7_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block7_1_conv (Conv2D)    (None, None, None, 1 57344       conv4_block7_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block7_1_bn (BatchNormali (None, None, None, 1 512         conv4_block7_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block7_1_relu (Activation (None, None, None, 1 0           conv4_block7_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block7_2_conv (Conv2D)    (None, None, None, 3 36864       conv4_block7_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block7_concat (Concatenat (None, None, None, 4 0           conv4_block6_concat[0][0]        \n",
      "                                                                 conv4_block7_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block8_0_bn (BatchNormali (None, None, None, 4 1920        conv4_block7_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block8_0_relu (Activation (None, None, None, 4 0           conv4_block8_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block8_1_conv (Conv2D)    (None, None, None, 1 61440       conv4_block8_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block8_1_bn (BatchNormali (None, None, None, 1 512         conv4_block8_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block8_1_relu (Activation (None, None, None, 1 0           conv4_block8_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block8_2_conv (Conv2D)    (None, None, None, 3 36864       conv4_block8_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block8_concat (Concatenat (None, None, None, 5 0           conv4_block7_concat[0][0]        \n",
      "                                                                 conv4_block8_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block9_0_bn (BatchNormali (None, None, None, 5 2048        conv4_block8_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block9_0_relu (Activation (None, None, None, 5 0           conv4_block9_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block9_1_conv (Conv2D)    (None, None, None, 1 65536       conv4_block9_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block9_1_bn (BatchNormali (None, None, None, 1 512         conv4_block9_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block9_1_relu (Activation (None, None, None, 1 0           conv4_block9_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block9_2_conv (Conv2D)    (None, None, None, 3 36864       conv4_block9_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block9_concat (Concatenat (None, None, None, 5 0           conv4_block8_concat[0][0]        \n",
      "                                                                 conv4_block9_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block10_0_bn (BatchNormal (None, None, None, 5 2176        conv4_block9_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block10_0_relu (Activatio (None, None, None, 5 0           conv4_block10_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block10_1_conv (Conv2D)   (None, None, None, 1 69632       conv4_block10_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block10_1_bn (BatchNormal (None, None, None, 1 512         conv4_block10_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block10_1_relu (Activatio (None, None, None, 1 0           conv4_block10_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block10_2_conv (Conv2D)   (None, None, None, 3 36864       conv4_block10_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block10_concat (Concatena (None, None, None, 5 0           conv4_block9_concat[0][0]        \n",
      "                                                                 conv4_block10_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block11_0_bn (BatchNormal (None, None, None, 5 2304        conv4_block10_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block11_0_relu (Activatio (None, None, None, 5 0           conv4_block11_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block11_1_conv (Conv2D)   (None, None, None, 1 73728       conv4_block11_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block11_1_bn (BatchNormal (None, None, None, 1 512         conv4_block11_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block11_1_relu (Activatio (None, None, None, 1 0           conv4_block11_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block11_2_conv (Conv2D)   (None, None, None, 3 36864       conv4_block11_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block11_concat (Concatena (None, None, None, 6 0           conv4_block10_concat[0][0]       \n",
      "                                                                 conv4_block11_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block12_0_bn (BatchNormal (None, None, None, 6 2432        conv4_block11_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block12_0_relu (Activatio (None, None, None, 6 0           conv4_block12_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block12_1_conv (Conv2D)   (None, None, None, 1 77824       conv4_block12_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block12_1_bn (BatchNormal (None, None, None, 1 512         conv4_block12_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block12_1_relu (Activatio (None, None, None, 1 0           conv4_block12_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block12_2_conv (Conv2D)   (None, None, None, 3 36864       conv4_block12_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block12_concat (Concatena (None, None, None, 6 0           conv4_block11_concat[0][0]       \n",
      "                                                                 conv4_block12_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block13_0_bn (BatchNormal (None, None, None, 6 2560        conv4_block12_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block13_0_relu (Activatio (None, None, None, 6 0           conv4_block13_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block13_1_conv (Conv2D)   (None, None, None, 1 81920       conv4_block13_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block13_1_bn (BatchNormal (None, None, None, 1 512         conv4_block13_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block13_1_relu (Activatio (None, None, None, 1 0           conv4_block13_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block13_2_conv (Conv2D)   (None, None, None, 3 36864       conv4_block13_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block13_concat (Concatena (None, None, None, 6 0           conv4_block12_concat[0][0]       \n",
      "                                                                 conv4_block13_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block14_0_bn (BatchNormal (None, None, None, 6 2688        conv4_block13_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block14_0_relu (Activatio (None, None, None, 6 0           conv4_block14_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block14_1_conv (Conv2D)   (None, None, None, 1 86016       conv4_block14_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block14_1_bn (BatchNormal (None, None, None, 1 512         conv4_block14_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block14_1_relu (Activatio (None, None, None, 1 0           conv4_block14_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block14_2_conv (Conv2D)   (None, None, None, 3 36864       conv4_block14_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block14_concat (Concatena (None, None, None, 7 0           conv4_block13_concat[0][0]       \n",
      "                                                                 conv4_block14_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block15_0_bn (BatchNormal (None, None, None, 7 2816        conv4_block14_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block15_0_relu (Activatio (None, None, None, 7 0           conv4_block15_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block15_1_conv (Conv2D)   (None, None, None, 1 90112       conv4_block15_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block15_1_bn (BatchNormal (None, None, None, 1 512         conv4_block15_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block15_1_relu (Activatio (None, None, None, 1 0           conv4_block15_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block15_2_conv (Conv2D)   (None, None, None, 3 36864       conv4_block15_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block15_concat (Concatena (None, None, None, 7 0           conv4_block14_concat[0][0]       \n",
      "                                                                 conv4_block15_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block16_0_bn (BatchNormal (None, None, None, 7 2944        conv4_block15_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block16_0_relu (Activatio (None, None, None, 7 0           conv4_block16_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block16_1_conv (Conv2D)   (None, None, None, 1 94208       conv4_block16_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block16_1_bn (BatchNormal (None, None, None, 1 512         conv4_block16_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block16_1_relu (Activatio (None, None, None, 1 0           conv4_block16_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block16_2_conv (Conv2D)   (None, None, None, 3 36864       conv4_block16_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block16_concat (Concatena (None, None, None, 7 0           conv4_block15_concat[0][0]       \n",
      "                                                                 conv4_block16_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block17_0_bn (BatchNormal (None, None, None, 7 3072        conv4_block16_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block17_0_relu (Activatio (None, None, None, 7 0           conv4_block17_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block17_1_conv (Conv2D)   (None, None, None, 1 98304       conv4_block17_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block17_1_bn (BatchNormal (None, None, None, 1 512         conv4_block17_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block17_1_relu (Activatio (None, None, None, 1 0           conv4_block17_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block17_2_conv (Conv2D)   (None, None, None, 3 36864       conv4_block17_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block17_concat (Concatena (None, None, None, 8 0           conv4_block16_concat[0][0]       \n",
      "                                                                 conv4_block17_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block18_0_bn (BatchNormal (None, None, None, 8 3200        conv4_block17_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block18_0_relu (Activatio (None, None, None, 8 0           conv4_block18_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block18_1_conv (Conv2D)   (None, None, None, 1 102400      conv4_block18_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block18_1_bn (BatchNormal (None, None, None, 1 512         conv4_block18_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block18_1_relu (Activatio (None, None, None, 1 0           conv4_block18_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block18_2_conv (Conv2D)   (None, None, None, 3 36864       conv4_block18_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block18_concat (Concatena (None, None, None, 8 0           conv4_block17_concat[0][0]       \n",
      "                                                                 conv4_block18_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block19_0_bn (BatchNormal (None, None, None, 8 3328        conv4_block18_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block19_0_relu (Activatio (None, None, None, 8 0           conv4_block19_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block19_1_conv (Conv2D)   (None, None, None, 1 106496      conv4_block19_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block19_1_bn (BatchNormal (None, None, None, 1 512         conv4_block19_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block19_1_relu (Activatio (None, None, None, 1 0           conv4_block19_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block19_2_conv (Conv2D)   (None, None, None, 3 36864       conv4_block19_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block19_concat (Concatena (None, None, None, 8 0           conv4_block18_concat[0][0]       \n",
      "                                                                 conv4_block19_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block20_0_bn (BatchNormal (None, None, None, 8 3456        conv4_block19_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block20_0_relu (Activatio (None, None, None, 8 0           conv4_block20_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block20_1_conv (Conv2D)   (None, None, None, 1 110592      conv4_block20_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block20_1_bn (BatchNormal (None, None, None, 1 512         conv4_block20_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block20_1_relu (Activatio (None, None, None, 1 0           conv4_block20_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block20_2_conv (Conv2D)   (None, None, None, 3 36864       conv4_block20_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block20_concat (Concatena (None, None, None, 8 0           conv4_block19_concat[0][0]       \n",
      "                                                                 conv4_block20_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block21_0_bn (BatchNormal (None, None, None, 8 3584        conv4_block20_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block21_0_relu (Activatio (None, None, None, 8 0           conv4_block21_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block21_1_conv (Conv2D)   (None, None, None, 1 114688      conv4_block21_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block21_1_bn (BatchNormal (None, None, None, 1 512         conv4_block21_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block21_1_relu (Activatio (None, None, None, 1 0           conv4_block21_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block21_2_conv (Conv2D)   (None, None, None, 3 36864       conv4_block21_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block21_concat (Concatena (None, None, None, 9 0           conv4_block20_concat[0][0]       \n",
      "                                                                 conv4_block21_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block22_0_bn (BatchNormal (None, None, None, 9 3712        conv4_block21_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block22_0_relu (Activatio (None, None, None, 9 0           conv4_block22_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block22_1_conv (Conv2D)   (None, None, None, 1 118784      conv4_block22_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block22_1_bn (BatchNormal (None, None, None, 1 512         conv4_block22_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block22_1_relu (Activatio (None, None, None, 1 0           conv4_block22_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block22_2_conv (Conv2D)   (None, None, None, 3 36864       conv4_block22_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block22_concat (Concatena (None, None, None, 9 0           conv4_block21_concat[0][0]       \n",
      "                                                                 conv4_block22_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block23_0_bn (BatchNormal (None, None, None, 9 3840        conv4_block22_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block23_0_relu (Activatio (None, None, None, 9 0           conv4_block23_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block23_1_conv (Conv2D)   (None, None, None, 1 122880      conv4_block23_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block23_1_bn (BatchNormal (None, None, None, 1 512         conv4_block23_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block23_1_relu (Activatio (None, None, None, 1 0           conv4_block23_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block23_2_conv (Conv2D)   (None, None, None, 3 36864       conv4_block23_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block23_concat (Concatena (None, None, None, 9 0           conv4_block22_concat[0][0]       \n",
      "                                                                 conv4_block23_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block24_0_bn (BatchNormal (None, None, None, 9 3968        conv4_block23_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block24_0_relu (Activatio (None, None, None, 9 0           conv4_block24_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block24_1_conv (Conv2D)   (None, None, None, 1 126976      conv4_block24_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block24_1_bn (BatchNormal (None, None, None, 1 512         conv4_block24_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block24_1_relu (Activatio (None, None, None, 1 0           conv4_block24_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block24_2_conv (Conv2D)   (None, None, None, 3 36864       conv4_block24_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block24_concat (Concatena (None, None, None, 1 0           conv4_block23_concat[0][0]       \n",
      "                                                                 conv4_block24_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "pool4_bn (BatchNormalization)   (None, None, None, 1 4096        conv4_block24_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "pool4_relu (Activation)         (None, None, None, 1 0           pool4_bn[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "pool4_conv (Conv2D)             (None, None, None, 5 524288      pool4_relu[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "pool4_pool (AveragePooling2D)   (None, None, None, 5 0           pool4_conv[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_0_bn (BatchNormali (None, None, None, 5 2048        pool4_pool[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_0_relu (Activation (None, None, None, 5 0           conv5_block1_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_1_conv (Conv2D)    (None, None, None, 1 65536       conv5_block1_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_1_bn (BatchNormali (None, None, None, 1 512         conv5_block1_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_1_relu (Activation (None, None, None, 1 0           conv5_block1_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_2_conv (Conv2D)    (None, None, None, 3 36864       conv5_block1_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_concat (Concatenat (None, None, None, 5 0           pool4_pool[0][0]                 \n",
      "                                                                 conv5_block1_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_0_bn (BatchNormali (None, None, None, 5 2176        conv5_block1_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_0_relu (Activation (None, None, None, 5 0           conv5_block2_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_1_conv (Conv2D)    (None, None, None, 1 69632       conv5_block2_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_1_bn (BatchNormali (None, None, None, 1 512         conv5_block2_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_1_relu (Activation (None, None, None, 1 0           conv5_block2_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_2_conv (Conv2D)    (None, None, None, 3 36864       conv5_block2_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_concat (Concatenat (None, None, None, 5 0           conv5_block1_concat[0][0]        \n",
      "                                                                 conv5_block2_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_0_bn (BatchNormali (None, None, None, 5 2304        conv5_block2_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_0_relu (Activation (None, None, None, 5 0           conv5_block3_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_1_conv (Conv2D)    (None, None, None, 1 73728       conv5_block3_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_1_bn (BatchNormali (None, None, None, 1 512         conv5_block3_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_1_relu (Activation (None, None, None, 1 0           conv5_block3_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_2_conv (Conv2D)    (None, None, None, 3 36864       conv5_block3_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_concat (Concatenat (None, None, None, 6 0           conv5_block2_concat[0][0]        \n",
      "                                                                 conv5_block3_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block4_0_bn (BatchNormali (None, None, None, 6 2432        conv5_block3_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block4_0_relu (Activation (None, None, None, 6 0           conv5_block4_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block4_1_conv (Conv2D)    (None, None, None, 1 77824       conv5_block4_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block4_1_bn (BatchNormali (None, None, None, 1 512         conv5_block4_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block4_1_relu (Activation (None, None, None, 1 0           conv5_block4_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block4_2_conv (Conv2D)    (None, None, None, 3 36864       conv5_block4_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block4_concat (Concatenat (None, None, None, 6 0           conv5_block3_concat[0][0]        \n",
      "                                                                 conv5_block4_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block5_0_bn (BatchNormali (None, None, None, 6 2560        conv5_block4_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block5_0_relu (Activation (None, None, None, 6 0           conv5_block5_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block5_1_conv (Conv2D)    (None, None, None, 1 81920       conv5_block5_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block5_1_bn (BatchNormali (None, None, None, 1 512         conv5_block5_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block5_1_relu (Activation (None, None, None, 1 0           conv5_block5_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block5_2_conv (Conv2D)    (None, None, None, 3 36864       conv5_block5_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block5_concat (Concatenat (None, None, None, 6 0           conv5_block4_concat[0][0]        \n",
      "                                                                 conv5_block5_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block6_0_bn (BatchNormali (None, None, None, 6 2688        conv5_block5_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block6_0_relu (Activation (None, None, None, 6 0           conv5_block6_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block6_1_conv (Conv2D)    (None, None, None, 1 86016       conv5_block6_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block6_1_bn (BatchNormali (None, None, None, 1 512         conv5_block6_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block6_1_relu (Activation (None, None, None, 1 0           conv5_block6_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block6_2_conv (Conv2D)    (None, None, None, 3 36864       conv5_block6_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block6_concat (Concatenat (None, None, None, 7 0           conv5_block5_concat[0][0]        \n",
      "                                                                 conv5_block6_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block7_0_bn (BatchNormali (None, None, None, 7 2816        conv5_block6_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block7_0_relu (Activation (None, None, None, 7 0           conv5_block7_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block7_1_conv (Conv2D)    (None, None, None, 1 90112       conv5_block7_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block7_1_bn (BatchNormali (None, None, None, 1 512         conv5_block7_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block7_1_relu (Activation (None, None, None, 1 0           conv5_block7_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block7_2_conv (Conv2D)    (None, None, None, 3 36864       conv5_block7_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block7_concat (Concatenat (None, None, None, 7 0           conv5_block6_concat[0][0]        \n",
      "                                                                 conv5_block7_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block8_0_bn (BatchNormali (None, None, None, 7 2944        conv5_block7_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block8_0_relu (Activation (None, None, None, 7 0           conv5_block8_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block8_1_conv (Conv2D)    (None, None, None, 1 94208       conv5_block8_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block8_1_bn (BatchNormali (None, None, None, 1 512         conv5_block8_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block8_1_relu (Activation (None, None, None, 1 0           conv5_block8_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block8_2_conv (Conv2D)    (None, None, None, 3 36864       conv5_block8_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block8_concat (Concatenat (None, None, None, 7 0           conv5_block7_concat[0][0]        \n",
      "                                                                 conv5_block8_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block9_0_bn (BatchNormali (None, None, None, 7 3072        conv5_block8_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block9_0_relu (Activation (None, None, None, 7 0           conv5_block9_0_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block9_1_conv (Conv2D)    (None, None, None, 1 98304       conv5_block9_0_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block9_1_bn (BatchNormali (None, None, None, 1 512         conv5_block9_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block9_1_relu (Activation (None, None, None, 1 0           conv5_block9_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block9_2_conv (Conv2D)    (None, None, None, 3 36864       conv5_block9_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block9_concat (Concatenat (None, None, None, 8 0           conv5_block8_concat[0][0]        \n",
      "                                                                 conv5_block9_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block10_0_bn (BatchNormal (None, None, None, 8 3200        conv5_block9_concat[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block10_0_relu (Activatio (None, None, None, 8 0           conv5_block10_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block10_1_conv (Conv2D)   (None, None, None, 1 102400      conv5_block10_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block10_1_bn (BatchNormal (None, None, None, 1 512         conv5_block10_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block10_1_relu (Activatio (None, None, None, 1 0           conv5_block10_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block10_2_conv (Conv2D)   (None, None, None, 3 36864       conv5_block10_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block10_concat (Concatena (None, None, None, 8 0           conv5_block9_concat[0][0]        \n",
      "                                                                 conv5_block10_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block11_0_bn (BatchNormal (None, None, None, 8 3328        conv5_block10_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block11_0_relu (Activatio (None, None, None, 8 0           conv5_block11_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block11_1_conv (Conv2D)   (None, None, None, 1 106496      conv5_block11_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block11_1_bn (BatchNormal (None, None, None, 1 512         conv5_block11_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block11_1_relu (Activatio (None, None, None, 1 0           conv5_block11_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block11_2_conv (Conv2D)   (None, None, None, 3 36864       conv5_block11_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block11_concat (Concatena (None, None, None, 8 0           conv5_block10_concat[0][0]       \n",
      "                                                                 conv5_block11_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block12_0_bn (BatchNormal (None, None, None, 8 3456        conv5_block11_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block12_0_relu (Activatio (None, None, None, 8 0           conv5_block12_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block12_1_conv (Conv2D)   (None, None, None, 1 110592      conv5_block12_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block12_1_bn (BatchNormal (None, None, None, 1 512         conv5_block12_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block12_1_relu (Activatio (None, None, None, 1 0           conv5_block12_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block12_2_conv (Conv2D)   (None, None, None, 3 36864       conv5_block12_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block12_concat (Concatena (None, None, None, 8 0           conv5_block11_concat[0][0]       \n",
      "                                                                 conv5_block12_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block13_0_bn (BatchNormal (None, None, None, 8 3584        conv5_block12_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block13_0_relu (Activatio (None, None, None, 8 0           conv5_block13_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block13_1_conv (Conv2D)   (None, None, None, 1 114688      conv5_block13_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block13_1_bn (BatchNormal (None, None, None, 1 512         conv5_block13_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block13_1_relu (Activatio (None, None, None, 1 0           conv5_block13_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block13_2_conv (Conv2D)   (None, None, None, 3 36864       conv5_block13_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block13_concat (Concatena (None, None, None, 9 0           conv5_block12_concat[0][0]       \n",
      "                                                                 conv5_block13_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block14_0_bn (BatchNormal (None, None, None, 9 3712        conv5_block13_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block14_0_relu (Activatio (None, None, None, 9 0           conv5_block14_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block14_1_conv (Conv2D)   (None, None, None, 1 118784      conv5_block14_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block14_1_bn (BatchNormal (None, None, None, 1 512         conv5_block14_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block14_1_relu (Activatio (None, None, None, 1 0           conv5_block14_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block14_2_conv (Conv2D)   (None, None, None, 3 36864       conv5_block14_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block14_concat (Concatena (None, None, None, 9 0           conv5_block13_concat[0][0]       \n",
      "                                                                 conv5_block14_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block15_0_bn (BatchNormal (None, None, None, 9 3840        conv5_block14_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block15_0_relu (Activatio (None, None, None, 9 0           conv5_block15_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block15_1_conv (Conv2D)   (None, None, None, 1 122880      conv5_block15_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block15_1_bn (BatchNormal (None, None, None, 1 512         conv5_block15_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block15_1_relu (Activatio (None, None, None, 1 0           conv5_block15_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block15_2_conv (Conv2D)   (None, None, None, 3 36864       conv5_block15_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block15_concat (Concatena (None, None, None, 9 0           conv5_block14_concat[0][0]       \n",
      "                                                                 conv5_block15_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block16_0_bn (BatchNormal (None, None, None, 9 3968        conv5_block15_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block16_0_relu (Activatio (None, None, None, 9 0           conv5_block16_0_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block16_1_conv (Conv2D)   (None, None, None, 1 126976      conv5_block16_0_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block16_1_bn (BatchNormal (None, None, None, 1 512         conv5_block16_1_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block16_1_relu (Activatio (None, None, None, 1 0           conv5_block16_1_bn[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block16_2_conv (Conv2D)   (None, None, None, 3 36864       conv5_block16_1_relu[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block16_concat (Concatena (None, None, None, 1 0           conv5_block15_concat[0][0]       \n",
      "                                                                 conv5_block16_2_conv[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "bn (BatchNormalization)         (None, None, None, 1 4096        conv5_block16_concat[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "global_average_pooling2d_1 (Glo (None, 1024)         0           bn[0][0]                         \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 14)           14350       global_average_pooling2d_1[0][0] \n",
      "==================================================================================================\n",
      "Total params: 7,051,854\n",
      "Trainable params: 6,968,206\n",
      "Non-trainable params: 83,648\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CvyTsSRdJ2TP"
   },
   "source": [
    "There are a lot of layers, but typically we'll only be extracting one of the last few, since the last layers usually carry the most abstract information. To access a layer, we can use `model.get_layer(name).output`, where `name` is the layer name as listed in the summary above. Let's get the `conv5_block16_concat` layer, the output of the last dense block, just before the final batch normalization (`bn`) layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "icUrvQF7KUJp",
    "outputId": "611d75e5-1345-489f-a01a-bb566ae230cf"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"conv5_block16_concat/concat:0\", shape=(?, ?, ?, 1024), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "spatial_maps = model.get_layer('conv5_block16_concat').output\n",
    "print(spatial_maps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kCsIMf08KrWi"
   },
   "source": [
    "Note that this tensor is only a symbolic placeholder. It doesn't contain the actual activations for any particular image. To compute them, we will use [Keras.backend.function](https://www.tensorflow.org/api_docs/python/tf/keras/backend/function), which evaluates intermediate computations while the model processes a particular input. It takes a list of input placeholders and a list of output placeholders, and returns a function. Calling that function with concrete input values runs the model up to the requested outputs and returns their values. For example, to evaluate the layer you just retrieved (`conv5_block16_concat`), you could write the following:"
   ]
  },
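  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of evaluating a placeholder concrete without TensorFlow, here is a small pure-NumPy analogy of the `K.function` contract. It is only a sketch: the toy network, its weights, and the helper `make_function` are hypothetical and do not reflect how Keras implements `K.function`. As with `K.function`, the returned callable takes a list of inputs and returns a list of numpy arrays, one per requested output.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# Hypothetical toy network: input -> 'hidden' (ReLU) -> 'output' (sigmoid)\n",
    "W1 = rng.normal(size=(3, 4))\n",
    "W2 = rng.normal(size=(4, 2))\n",
    "\n",
    "def forward(x):\n",
    "    # Run the toy network and keep every intermediate result, keyed by name\n",
    "    hidden = np.maximum(x @ W1, 0.0)\n",
    "    output = 1.0 / (1.0 + np.exp(-(hidden @ W2)))\n",
    "    return {'hidden': hidden, 'output': output}\n",
    "\n",
    "def make_function(output_names):\n",
    "    # Mimics the K.function contract: a list of inputs in, a list of arrays out\n",
    "    def fn(inputs):\n",
    "        activations = forward(inputs[0])\n",
    "        return [activations[name] for name in output_names]\n",
    "    return fn\n",
    "\n",
    "x = rng.normal(size=(1, 3))  # a batch containing a single example\n",
    "\n",
    "get_hidden = make_function(['hidden'])\n",
    "hidden_l = get_hidden([x])  # a list of length 1\n",
    "print(len(hidden_l), hidden_l[0].shape)  # 1 (1, 4)\n",
    "\n",
    "# Requesting two outputs returns two arrays, which can be unpacked directly\n",
    "hidden_x, output_x = make_function(['hidden', 'output'])([x])\n",
    "print(hidden_x.shape, output_x.shape)  # (1, 4) (1, 2)\n",
    "```\n",
    "\n",
    "Each returned array keeps the batch dimension. This is why the cells below index into it with `[0]`."
   ]
  },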
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "U4JEE37ALl-N",
    "outputId": "1de337cf-bec5-4edf-fb6d-4bbbe672bad8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<keras.backend.tensorflow_backend.Function object at 0x7faf4fe307b8>\n"
     ]
    }
   ],
   "source": [
    "get_spatial_maps = K.function([model.input], [spatial_maps])\n",
    "print(get_spatial_maps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "B2-soRaiL8xA"
   },
   "source": [
    "We now have a `Function` object. To evaluate the intermediate output for a particular input, we pass an image to this function, wrapped in a list:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x is of type <class 'numpy.ndarray'>\n",
      "x is of shape (1, 320, 320, 3)\n"
     ]
    }
   ],
   "source": [
    "# get an image\n",
    "x = load_image_normalize(im_path, mean, std)\n",
    "print(f\"x is of type {type(x)}\")\n",
    "print(f\"x is of shape {x.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "spatial_maps_x_l is of type <class 'list'>\n",
      "spatial_maps_x_l has length 1\n"
     ]
    }
   ],
   "source": [
    "# get the spatial maps layer activations (a list of numpy arrays)\n",
    "spatial_maps_x_l = get_spatial_maps([x])\n",
    "\n",
    "print(f\"spatial_maps_x_l is of type {type(spatial_maps_x_l)}\")\n",
    "print(f\"spatial_maps_x_l has length {len(spatial_maps_x_l)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "spatial_maps_x is of type <class 'numpy.ndarray'>\n",
      "spatial_maps_x is of shape (1, 10, 10, 1024)\n"
     ]
    }
   ],
   "source": [
    "# get the 0th item in the list\n",
    "spatial_maps_x = spatial_maps_x_l[0]\n",
    "print(f\"spatial_maps_x is of type {type(spatial_maps_x)}\")\n",
    "print(f\"spatial_maps_x is of shape {spatial_maps_x.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that the shape is (1, 10, 10, 1024). The first dimension, of size 1, is the batch dimension. The remaining three are the height, width, and number of channels of the feature maps. Remove the batch dimension for later calculations by taking index 0 of `spatial_maps_x`."
   ]
  },
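  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a NumPy sketch of the shapes involved, using a random stand-in for the real activations. It assumes, as is standard for DenseNet-121, that the convolutional base halves the spatial resolution five times (a total stride of 32). That assumption is consistent with the 320 x 320 input producing the 10 x 10 maps above.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Assumption: five stride-2 downsampling stages, i.e. a total stride of 32\n",
    "input_size = 320\n",
    "total_stride = 2 ** 5\n",
    "map_size = input_size // total_stride\n",
    "print(map_size)  # 10\n",
    "\n",
    "# Random stand-in for get_spatial_maps([x])[0], which has shape (1, 10, 10, 1024)\n",
    "spatial_maps_with_batch = np.random.rand(1, map_size, map_size, 1024).astype('float32')\n",
    "\n",
    "# Three equivalent ways to drop the leading batch dimension of size 1\n",
    "a = spatial_maps_with_batch[0]\n",
    "b = spatial_maps_with_batch[0, :]\n",
    "c = np.squeeze(spatial_maps_with_batch, axis=0)\n",
    "print(a.shape)  # (10, 10, 1024)\n",
    "print(np.array_equal(a, b) and np.array_equal(a, c))  # True\n",
    "\n",
    "# Indexing once more selects the first row of the 10 x 10 grid,\n",
    "# which is what print(spatial_maps_x[0]) displays in the next cell\n",
    "print(a[0].shape)  # (10, 1024)\n",
    "```\n",
    "\n",
    "All three forms give the same array. `np.squeeze(..., axis=0)` also raises an error if the dimension being dropped does not have size 1."
   ]
  },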
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "spatial_maps_x without the batch dimension has shape (10, 10, 1024)\n",
      "Output some of the content:\n",
      "[[-0.46017444  0.20640776 -0.63506377 ...  0.1264174  -0.06400048\n",
      "   0.15870997]\n",
      " [-0.8125281  -0.29398838 -0.8967887  ...  0.21837974 -0.0994716\n",
      "   0.26966757]\n",
      " [-0.508806   -0.14127392 -0.5690727  ...  0.27967227 -0.11622357\n",
      "   0.318372  ]\n",
      " ...\n",
      " [-0.34813794 -0.3922896  -1.0565547  ...  0.17491409 -0.08235557\n",
      "   0.25179753]\n",
      " [-0.4438535  -0.32872048 -0.65662026 ...  0.21583238 -0.10991383\n",
      "   0.31518462]\n",
      " [-0.29580766  0.4920513  -0.2233113  ...  0.08722244 -0.04751847\n",
      "   0.17896183]]\n"
     ]
    }
   ],
   "source": [
    "# Get rid of the batch dimension\n",
    "spatial_maps_x = spatial_maps_x[0] # equivalent to spatial_maps_x[0,:]\n",
    "print(f\"spatial_maps_x without the batch dimension has shape {spatial_maps_x.shape}\")\n",
    "print(\"Output some of the content:\")\n",
    "print(spatial_maps_x[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aRCSdXfRO6KI"
   },
   "source": [
    "We now have the activations for that particular image, and we can use them for interpretation. The function returned by `K.function([model.input], [spatial_maps])` (saved here in the variable `get_spatial_maps`) is sometimes referred to as a \"hook\" because it lets you peek into the model's intermediate computations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "FJlxNWRyPQqV"
   },
   "source": [
    "<a name=\"1-1-2\"></a>\n",
    "#### 1.1.2 Getting Gradients"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "G12g9fOeaqyM"
   },
   "source": [
    "The other major step in computing GradCAMs is getting the gradient of the model's output for a particular class with respect to the feature maps. Luckily, Keras makes this simple with the [Keras.backend.gradients](https://www.tensorflow.org/api_docs/python/tf/keras/backend/gradients) function. Its first argument is the tensor you are taking the gradient of, and its second argument is the tensor (or list of tensors) you are taking that gradient with respect to. We illustrate below:"
   ]
  },
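  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, the gradient of a scalar `y` with respect to a tensor `A` has the same shape as `A`. Each entry measures how much `y` changes when the corresponding entry of `A` is nudged. The NumPy sketch below checks this definition numerically with central finite differences on a hypothetical toy function. It only illustrates what is being computed; Keras itself obtains gradients by automatic differentiation, not by finite differences.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "A = rng.normal(size=(2, 2, 3))  # toy feature map: height 2, width 2, 3 channels\n",
    "M = rng.normal(size=A.shape)    # hypothetical fixed weights, one per entry of A\n",
    "\n",
    "def y_of(A):\n",
    "    # Hypothetical scalar output that depends on every entry of A\n",
    "    return np.sum(np.sin(A) * M)\n",
    "\n",
    "def numerical_gradient(f, A, eps=1e-6):\n",
    "    # Central differences: nudge one entry at a time and measure the change in f\n",
    "    grad = np.zeros_like(A)\n",
    "    for idx in np.ndindex(A.shape):\n",
    "        A_plus, A_minus = A.copy(), A.copy()\n",
    "        A_plus[idx] += eps\n",
    "        A_minus[idx] -= eps\n",
    "        grad[idx] = (f(A_plus) - f(A_minus)) / (2 * eps)\n",
    "    return grad\n",
    "\n",
    "grad = numerical_gradient(y_of, A)\n",
    "analytic = np.cos(A) * M  # derivative of sin(a) * m with respect to a\n",
    "print(grad.shape)  # (2, 2, 3): one value per entry of A\n",
    "print(np.allclose(grad, analytic, atol=1e-6))  # True\n",
    "```"
   ]
  },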
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model output includes batch dimension, has shape (?, 14)\n"
     ]
    }
   ],
   "source": [
    "# get the output of the model\n",
    "output_with_batch_dim = model.output\n",
    "print(f\"Model output includes batch dimension, has shape {output_with_batch_dim.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
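    "To get the output without the batch dimension, take index 0 of the tensor. The batch dimension is `None` (unknown until the model is run), so TensorFlow will accept any integer index while building the graph. However, the index must exist in the batch you actually feed in. Since we will pass one image at a time, we use 0."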
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The output for all 14 categories of disease has shape (14,)\n"
     ]
    }
   ],
   "source": [
    "# Get the output without the batch dimension\n",
    "output_all_categories = output_with_batch_dim[0]\n",
    "print(f\"The output for all 14 categories of disease has shape {output_all_categories.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output has 14 entries, one per disease category, indexed from 0 to 13. Cardiomegaly is the category at index 0."
   ]
  },
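  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same indexing works on the concrete numpy arrays that the model produces at prediction time, for example from `model.predict`. Here is a short sketch with a random stand-in array; the values are not real predictions.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Stand-in for predictions on a batch of one image: shape (1, 14)\n",
    "preds = np.random.rand(1, 14)\n",
    "\n",
    "all_categories = preds[0]         # drop the batch dimension: shape (14,)\n",
    "cardiomegaly = all_categories[0]  # category at index 0: a 0-d scalar\n",
    "print(all_categories.shape, np.shape(cardiomegaly))  # (14,) ()\n",
    "\n",
    "# preds[0, 0] selects the same entry in a single indexing step\n",
    "print(preds[0, 0] == cardiomegaly)  # True\n",
    "```"
   ]
  },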
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The Cardiomegaly output is at index 0, and has shape ()\n"
     ]
    }
   ],
   "source": [
    "# Get the first category's output (Cardiomegaly) at index 0\n",
    "y_category_0 = output_all_categories[0]\n",
    "print(f\"The Cardiomegaly output is at index 0, and has shape {y_category_0.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "gdnX8taUbF7h",
    "outputId": "646e7c7c-0ab6-492d-f6b5-26b30f4c1125"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gradient_l is of type <class 'list'> and has length 1\n",
      "Tensor(\"gradients/AddN:0\", shape=(?, ?, ?, 1024), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# Get gradient of y_category_0 with respect to spatial_maps\n",
    "\n",
    "gradient_l = K.gradients(y_category_0, spatial_maps)\n",
    "print(f\"gradient_l is of type {type(gradient_l)} and has length {len(gradient_l)}\")\n",
    "\n",
    "# gradient_l is a list of size 1.  Get the gradient at index 0\n",
    "gradient = gradient_l[0]\n",
    "print(gradient)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fj-b8lqHbrM5"
   },
   "source": [
    "Again, this is just a placeholder. Just as for intermediate layers, we can use `K.function` to compute the value of the gradient for a particular input.\n",
    "\n",
    "`K.function()` takes in:\n",
    "- a list of inputs: in this case, one input, `model.input`\n",
    "- a list of output tensors: in this case, one tensor, `gradient`\n",
    "\n",
    "It returns a function that evaluates the list of output tensors.\n",
    "- This returned function returns a list of numpy arrays, one for each tensor that was passed into `K.function()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "keras.backend.tensorflow_backend.Function"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Create the function that gets the gradient\n",
    "get_gradient = K.function([model.input], [gradient])\n",
    "type(get_gradient)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X-ray image has shape (1, 320, 320, 3)\n"
     ]
    }
   ],
   "source": [
    "# get an input x-ray image\n",
    "x = load_image_normalize(im_path, mean, std)\n",
    "print(f\"X-ray image has shape {x.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
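    "The `get_gradient` function takes a list of inputs (here, just `[x]`) and returns a list with one numpy array per output tensor. Since we only requested `gradient`, the list has length 1, and that array still includes the batch dimension."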
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "grad_x_l is of type <class 'list'> and length 1\n",
      "grad_x_with_batch_dim is type <class 'numpy.ndarray'> and shape (1, 10, 10, 1024)\n",
      "grad_x is type <class 'numpy.ndarray'> and shape (10, 10, 1024)\n",
      "Gradient grad_x (showing some of its content):\n",
      "[[-1.4058211e-09  2.8323848e-09  3.3191864e-07 ...  9.2680755e-05\n",
      "  -6.2032734e-05  6.4634791e-05]\n",
      " [-1.4058211e-09  2.8323848e-09  3.3191864e-07 ...  9.2680755e-05\n",
      "  -6.2032734e-05  6.4634791e-05]\n",
      " [-1.4058211e-09  2.8323848e-09  3.3191864e-07 ...  9.2680755e-05\n",
      "  -6.2032734e-05  6.4634791e-05]\n",
      " ...\n",
      " [-1.4058211e-09  2.8323848e-09  3.3191864e-07 ...  9.2680755e-05\n",
      "  -6.2032734e-05  6.4634791e-05]\n",
      " [-1.4058211e-09  2.8323848e-09  3.3191864e-07 ...  9.2680755e-05\n",
      "  -6.2032734e-05  6.4634791e-05]\n",
      " [-1.4058211e-09  2.8323848e-09  3.3191864e-07 ...  9.2680755e-05\n",
      "  -6.2032734e-05  6.4634791e-05]]\n"
     ]
    }
   ],
   "source": [
    "# use the get_gradient function to get the gradient (pass in the input image inside a list)\n",
    "grad_x_l = get_gradient([x])\n",
    "print(f\"grad_x_l is of type {type(grad_x_l)} and length {len(grad_x_l)}\")\n",
    "\n",
    "# get the gradient at index 0 of the list.\n",
    "grad_x_with_batch_dim = grad_x_l[0]\n",
    "print(f\"grad_x_with_batch_dim is type {type(grad_x_with_batch_dim)} and shape {grad_x_with_batch_dim.shape}\")\n",
    "\n",
    "# To remove the batch dimension, take the value at index 0 of the batch dimension\n",
    "grad_x = grad_x_with_batch_dim[0]\n",
    "print(f\"grad_x is type {type(grad_x)} and shape {grad_x.shape}\")\n",
    "\n",
    "print(\"Gradient grad_x (showing some of its content):\")\n",
    "print(grad_x[0])"
   ]
  },
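  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You may notice that every printed row of the gradient above is identical. One plausible explanation comes from the model summary. After `conv5_block16_concat`, the model applies only `bn` (at inference time, a fixed per-channel affine map), global average pooling, and the dense output layer. The derivative of a class score with respect to a feature-map entry would therefore depend on the entry's channel but not on its spatial position. The NumPy sketch below checks this with finite differences on a hypothetical miniature of that head. All parameter values and the sigmoid output activation are assumptions.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(2)\n",
    "H, W, C = 4, 4, 5\n",
    "A = rng.normal(size=(H, W, C))  # stand-in for the spatial maps of one image\n",
    "\n",
    "# Hypothetical inference-time BatchNorm statistics and parameters (per channel)\n",
    "gamma, beta = rng.normal(size=C), rng.normal(size=C)\n",
    "mu, var, eps_bn = rng.normal(size=C), rng.uniform(0.5, 2.0, size=C), 1e-3\n",
    "w, b = rng.normal(size=C), 0.1  # hypothetical dense weights for a single class\n",
    "\n",
    "def class_score(A):\n",
    "    bn = gamma * (A - mu) / np.sqrt(var + eps_bn) + beta  # per-channel affine map\n",
    "    pooled = bn.mean(axis=(0, 1))                          # global average pooling\n",
    "    return 1.0 / (1.0 + np.exp(-(pooled @ w + b)))         # assumed sigmoid output\n",
    "\n",
    "def numeric_grad_at(idx, eps=1e-6):\n",
    "    # Central finite difference with respect to a single entry of A\n",
    "    A_plus, A_minus = A.copy(), A.copy()\n",
    "    A_plus[idx] += eps\n",
    "    A_minus[idx] -= eps\n",
    "    return (class_score(A_plus) - class_score(A_minus)) / (2 * eps)\n",
    "\n",
    "# Chain rule through sigmoid, dense layer, pooling and BatchNorm: one value per channel\n",
    "y = class_score(A)\n",
    "per_channel = y * (1 - y) * w * gamma / np.sqrt(var + eps_bn) / (H * W)\n",
    "\n",
    "g1 = numeric_grad_at((0, 0, 1))\n",
    "g2 = numeric_grad_at((3, 2, 1))  # different position, same channel\n",
    "print(np.isclose(g1, g2), np.isclose(g1, per_channel[1]))  # True True\n",
    "```"
   ]
  },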
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aw7DVK2Xc8gv"
   },
   "source": [
    "Just as we have a hook into the output of the last convolutional block, we now have a hook into the gradient! Together, these let us compute most quantities relevant to interpreting the model's output."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ssEQgnCLdXcr"
   },
   "source": [
    "We can also combine the two into a single function call that returns both the spatial maps and the gradient (this will come in handy when implementing GradCAM in the next section)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'keras.backend.tensorflow_backend.Function'>\n"
     ]
    }
   ],
   "source": [
    "# Use K.function to generate a single function\n",
    "# Notice that a list of two tensors is passed in as the second argument of K.function()\n",
    "get_spatial_maps_and_gradient = K.function([model.input], [spatial_maps, gradient])\n",
    "print(type(get_spatial_maps_and_gradient))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor_eval_l is type <class 'list'> and length 2\n"
     ]
    }
   ],
   "source": [
    "# The returned function returns a list of the evaluated tensors\n",
    "tensor_eval_l = get_spatial_maps_and_gradient([x])\n",
    "print(f\"tensor_eval_l is type {type(tensor_eval_l)} and length {len(tensor_eval_l)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "spatial_maps_x_with_batch_dim has shape (1, 10, 10, 1024)\n",
      "grad_x_with_batch_dim has shape (1, 10, 10, 1024)\n"
     ]
    }
   ],
   "source": [
    "# store the two numpy arrays from index 0 and 1 into their own variables\n",
    "spatial_maps_x_with_batch_dim, grad_x_with_batch_dim = tensor_eval_l\n",
    "print(f\"spatial_maps_x_with_batch_dim has shape {spatial_maps_x_with_batch_dim.shape}\")\n",
    "print(f\"grad_x_with_batch_dim has shape {grad_x_with_batch_dim.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "spatial_maps_x_with_batch_dim has shape (1, 10, 10, 1024)\n",
      "grad_x_with_batch_dim has shape (1, 10, 10, 1024)\n"
     ]
    }
   ],
   "source": [
    "# Note: you could also do this directly from the function call:\n",
    "spatial_maps_x_with_batch_dim, grad_x_with_batch_dim = get_spatial_maps_and_gradient([x])\n",
    "print(f\"spatial_maps_x_with_batch_dim has shape {spatial_maps_x_with_batch_dim.shape}\")\n",
    "print(f\"grad_x_with_batch_dim has shape {grad_x_with_batch_dim.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "spatial_maps_x shape (10, 10, 1024)\n",
      "grad_x shape (10, 10, 1024)\n",
      "\n",
      "Spatial maps (print some content):\n",
      "[[-0.46017444  0.20640776 -0.63506377 ...  0.1264174  -0.06400048\n",
      "   0.15870997]\n",
      " [-0.8125281  -0.29398838 -0.8967887  ...  0.21837974 -0.0994716\n",
      "   0.26966757]\n",
      " [-0.508806   -0.14127392 -0.5690727  ...  0.27967227 -0.11622357\n",
      "   0.318372  ]\n",
      " ...\n",
      " [-0.34813794 -0.3922896  -1.0565547  ...  0.17491409 -0.08235557\n",
      "   0.25179753]\n",
      " [-0.4438535  -0.32872048 -0.65662026 ...  0.21583238 -0.10991383\n",
      "   0.31518462]\n",
      " [-0.29580766  0.4920513  -0.2233113  ...  0.08722244 -0.04751847\n",
      "   0.17896183]]\n",
      "\n",
      "Gradient (print some content):\n",
      "[[-1.4058211e-09  2.8323848e-09  3.3191864e-07 ...  9.2680755e-05\n",
      "  -6.2032734e-05  6.4634791e-05]\n",
      " [-1.4058211e-09  2.8323848e-09  3.3191864e-07 ...  9.2680755e-05\n",
      "  -6.2032734e-05  6.4634791e-05]\n",
      " [-1.4058211e-09  2.8323848e-09  3.3191864e-07 ...  9.2680755e-05\n",
      "  -6.2032734e-05  6.4634791e-05]\n",
      " ...\n",
      " [-1.4058211e-09  2.8323848e-09  3.3191864e-07 ...  9.2680755e-05\n",
      "  -6.2032734e-05  6.4634791e-05]\n",
      " [-1.4058211e-09  2.8323848e-09  3.3191864e-07 ...  9.2680755e-05\n",
      "  -6.2032734e-05  6.4634791e-05]\n",
      " [-1.4058211e-09  2.8323848e-09  3.3191864e-07 ...  9.2680755e-05\n",
      "  -6.2032734e-05  6.4634791e-05]]\n"
     ]
    }
   ],
   "source": [
    "# Remove the batch dimension by taking the 0th index at the batch dimension\n",
    "spatial_maps_x = spatial_maps_x_with_batch_dim[0]\n",
    "grad_x = grad_x_with_batch_dim[0]\n",
    "print(f\"spatial_maps_x shape {spatial_maps_x.shape}\")\n",
    "print(f\"grad_x shape {grad_x.shape}\")\n",
    "\n",
    "print(\"\\nSpatial maps (print some content):\")\n",
    "print(spatial_maps_x[0])\n",
    "print(\"\\nGradient (print some content):\")\n",
    "print(grad_x[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "bdLxNcp9dD3i"
   },
   "source": [
    "<a name=\"1-1-3\"></a>\n",
    "#### 1.1.3 Implementing GradCAM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "3QHMAogf2HbD"
   },
   "source": [
    "<a name='ex-01'></a>\n",
    "### Exercise 1\n",
    "\n",
    "In the next cell, fill in the `grad_cam` function to produce a GradCAM visualization for a given model, image, and class. This is fairly involved, so it helps to break it down into these steps:\n",
    "\n",
    "1. Get the symbolic tensors for the model output of the chosen class and for the activations of the selected (last convolutional) layer.\n",
    "2. Get the gradients of the class output with respect to those layer activations.\n",
    "3. Evaluate the layer activations and the gradients for the input image.\n",
    "4. Compute one weight per channel by global average pooling of the gradients.\n",
    "5. Take the dot product of the layer activations with the weights to get a score for each spatial location.\n",
    "6. Apply ReLU, resize to the input size, and normalize (this postprocessing is already provided)."
   ]
  },
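  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of steps 4 and 5 (plus the ReLU and normalization of step 6), assume the layer activations and gradients have already been evaluated for one image with the batch dimension removed, so both have shape (H, W, C). The helper name `gradcam_from_arrays` and the random toy arrays are hypothetical stand-ins; the resize to the input resolution is left out, since the graded function handles it with `cv2.resize`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def gradcam_from_arrays(spatial_map, grads):\n",
    "    # spatial_map, grads: activations and gradients for one image, shape (H, W, C)\n",
    "    # Step 4: global average pooling of the gradients over H and W gives shape (C,)\n",
    "    weights = np.mean(grads, axis=(0, 1))\n",
    "    # Step 5: channel-weighted sum of the activation maps gives shape (H, W)\n",
    "    cam = np.dot(spatial_map, weights)\n",
    "    # Step 6 (without resizing): ReLU keeps positive evidence, then scale to [0, 1]\n",
    "    cam = np.maximum(cam, 0)\n",
    "    if cam.max() > 0:\n",
    "        cam = cam / cam.max()\n",
    "    return cam\n",
    "\n",
    "# Hypothetical toy inputs with the same shape as conv5_block16_concat above\n",
    "rng = np.random.default_rng(0)\n",
    "spatial_map = rng.normal(size=(10, 10, 1024))\n",
    "grads = rng.normal(size=(10, 10, 1024))\n",
    "cam = gradcam_from_arrays(spatial_map, grads)\n",
    "print(cam.shape)  # (10, 10)\n",
    "```\n",
    "\n",
    "Because `np.dot` contracts over the last axis when the second argument is 1-D, multiplying an (H, W, C) array by a (C,) weight vector yields an (H, W) map in which each location is the weighted sum of its C channel activations."
   ]
  },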
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<details>\n",
    "    <summary>\n",
    "    <font size=\"3\" color=\"darkgreen\"><b>Hints</b></font>\n",
    "</summary>\n",
    "    \n",
    " The following hints follow the order of the steps described above.\n",
    " 1. Remember that the output shape of our model will be [1, num_classes]. \n",
    "     1. The input in this case will always have batch_size = 1\n",
    " 2. See [K.gradients](https://www.tensorflow.org/api_docs/python/tf/keras/backend/gradients)\n",
    " 3. Follow the procedure we used in the previous two sections.\n",
    " 4. Check the axis; make sure weights have shape (C)!\n",
    " 5. See [np.dot](https://docs.scipy.org/doc/numpy/reference/generated/numpy.dot.html)\n",
    "     </details>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "nAJnlx2T7C78"
   },
   "source": [
    "To test, you will compare your output on an image to the output from a correct implementation of GradCAM. You will receive full credit if the pixel-wise mean squared error is less than 0.05."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "TYwp5kvT8ZfR"
   },
   "outputs": [],
   "source": [
    "# UNQ_C1 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)\n",
    "def grad_cam(input_model, image, category_index, layer_name):\n",
    "    \"\"\"\n",
    "    GradCAM method for visualizing input saliency.\n",
    "    \n",
    "    Args:\n",
    "        input_model (Keras.model): model to compute cam for\n",
    "        image (tensor): input to model, shape (1, H, W, 3)\n",
    "        category_index (int): index of the class to compute cam with respect to\n",
    "        layer_name (str): name of the layer whose activations are used\n",
    "    Return:\n",
    "        cam (np.ndarray): class activation map, shape (H, W), scaled to [0, 1]\n",
    "    \"\"\"\n",
    "    cam = None\n",
    "    \n",
    "    ### START CODE HERE (REPLACE INSTANCES OF 'None' with your code) ###\n",
    "\n",
    "    # 1. Get placeholders for class output and last layer\n",
    "    # Get the model's output\n",
    "    output_with_batch_dim = input_model.output\n",
    "    \n",
    "    # Remove the batch dimension\n",
    "    output_all_categories = output_with_batch_dim[0]\n",
    "    \n",
    "    # Retrieve only the disease category at the given category index\n",
    "    y_c = output_all_categories[category_index]\n",
    "    \n",
    "    # Get the input model's layer specified by layer_name, and retrieve the layer's output tensor\n",
    "    spatial_map_layer = input_model.get_layer(layer_name).output\n",
    "\n",
    "    # 2. Get gradients of last layer with respect to output\n",
    "\n",
    "    # get the gradients of y_c with respect to the spatial map layer (it's a list of length 1)\n",
    "    grads_l = K.gradients(y_c, spatial_map_layer)\n",
    "    \n",
    "    # Get the gradient at index 0 of the list\n",
    "    grads = grads_l[0]\n",
    "        \n",
    "    # 3. Get hook for the selected layer and its gradient, based on given model's input\n",
    "    # Hint: Use the variables produced by the previous two lines of code\n",
    "    spatial_map_and_gradient_function = K.function([input_model.input], [spatial_map_layer, grads])\n",
    "    \n",
    "    # Put in the image to calculate the values of the spatial_maps (selected layer) and values of the gradients\n",
    "    spatial_map_all_dims, grads_val_all_dims = spatial_map_and_gradient_function([image])\n",
    "\n",
    "    # Reshape activations and gradient to remove the batch dimension\n",
    "    # Shape goes from (B, H, W, C) to (H, W, C)\n",
    "    # B: Batch. H: Height. W: Width. C: Channel    \n",
    "    # Reshape spatial map output to remove the batch dimension\n",
    "    spatial_map_val = spatial_map_all_dims[0]\n",
    "    \n",
    "    # Reshape gradients to remove the batch dimension\n",
    "    grads_val = grads_val_all_dims[0]\n",
    "    \n",
    "    # 4. Compute weights using global average pooling on gradient \n",
    "    # grads_val has shape (Height, Width, Channels) (H,W,C)\n",
    "    # Take the mean across the height and also width, for each channel\n",
    "    # Make sure weights have shape (C)\n",
    "    weights = np.mean(grads_val, axis=(0, 1))\n",
    "    \n",
    "    # 5. Compute dot product of spatial map values with the weights\n",
    "    cam = np.dot(spatial_map_val, weights)\n",
    "\n",
    "    ### END CODE HERE ###\n",
    "    \n",
    "    # We'll take care of the postprocessing.\n",
    "    H, W = image.shape[1], image.shape[2]\n",
    "    cam = np.maximum(cam, 0) # ReLU so we only get positive importance\n",
    "    cam = cv2.resize(cam, (W, H), cv2.INTER_NEAREST)\n",
    "    cam = cam / cam.max()\n",
    "\n",
    "    return cam"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GBi4c71M7OVY"
   },
   "source": [
    "Below we generate the CAM for the image and compute the error (pixel-wise mean squared difference) from the expected values according to our reference. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "4yAC5xBo8J3L",
    "outputId": "0bb46d16-07c6-4211-8312-5b3de8ba06a7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error from reference: 0.0302, should be less than 0.05\n"
     ]
    }
   ],
   "source": [
    "im = load_image_normalize(im_path, mean, std)\n",
    "cam = grad_cam(model, im, 5, 'conv5_block16_concat') # Mass is class 5\n",
    "\n",
    "# Loads reference CAM to compare our implementation with.\n",
    "reference = np.load(\"reference_cam.npy\")\n",
    "error = np.mean((cam-reference)**2)\n",
    "\n",
    "print(f\"Error from reference: {error:.4f}, should be less than 0.05\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Tnl0D5pN8MvE"
   },
   "source": [
    "Run the next cell to visualize the CAM and the original image. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 837
    },
    "colab_type": "code",
    "id": "m1kqthIt5AOs",
    "outputId": "5d486156-ada6-4670-fe70-ae928b7baef3"
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(load_image(im_path, df, preprocess=False), cmap='gray')\n",
    "plt.title(\"Original\")\n",
    "plt.axis('off')\n",
    "\n",
    "plt.show()\n",
    "\n",
    "plt.imshow(load_image(im_path, df, preprocess=False), cmap='gray')\n",
    "plt.imshow(cam, cmap='magma', alpha=0.5)\n",
    "plt.title(\"GradCAM\")\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "1Lx0lCu-5DeF"
   },
   "source": [
    "We can see that the GradCAM focuses on the large white region in the right lung. Indeed, this is a clear case of Mass."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name=\"1-1-4\"></a>\n",
    "#### 1.1.4 Using GradCAM to Visualize Multiple Labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "LUQSkaHsBsUn"
   },
   "source": [
    "<a name='ex-02'></a>\n",
    "### Exercise 2\n",
    "\n",
    "We can also compute GradCAMs for multiple labels on the same image. Let's do this for the three labels on which our model achieved the best AUC: Cardiomegaly, Mass, and Edema."
   ]
  },
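  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop in `compute_gradcam` walks over every model output and keeps only the ones named in `selected_labels`; the position `i` of a label in `labels` is what gets passed to `grad_cam` as `category_index`. A minimal sketch of that mapping, using a hypothetical short label list (not the model's actual label order) and a hypothetical helper `selected_indices`:\n",
    "\n",
    "```python\n",
    "# Hypothetical label order for illustration only; not the model's real label list\n",
    "labels = ['Cardiomegaly', 'Effusion', 'Mass', 'Nodule', 'Edema']\n",
    "selected_labels = ['Cardiomegaly', 'Mass', 'Edema']\n",
    "\n",
    "def selected_indices(labels, selected_labels):\n",
    "    # Position i in labels is the model output column used as category_index\n",
    "    return [i for i, label in enumerate(labels) if label in selected_labels]\n",
    "\n",
    "print(selected_indices(labels, selected_labels))  # [0, 2, 4]\n",
    "```"
   ]
  },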
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# UNQ_C2 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)\n",
    "def compute_gradcam(model, img, mean, std, data_dir, df, \n",
    "                    labels, selected_labels, layer_name='conv5_block16_concat'):\n",
    "    \"\"\"\n",
    "    Compute GradCAM for many specified labels for an image. \n",
    "    This method will use the `grad_cam` function.\n",
    "    \n",
    "    Args:\n",
    "        model (Keras.model): Model to compute GradCAM for\n",
    "        img (string): Image name we want to compute GradCAM for.\n",
    "        mean (float): Mean used to normalize the image.\n",
    "        std (float): Standard deviation used to normalize the image.\n",
    "        data_dir (str): Path of the directory to load the images from.\n",
    "        df (pd.DataFrame): DataFrame with the image names and their labels.\n",
    "        labels ([str]): All output labels of the model, in output order.\n",
    "        selected_labels ([str]): Subset of labels to compute the GradCAM for.\n",
    "        layer_name (str): Intermediate layer of the model to compute the GradCAM for.\n",
    "    \"\"\"\n",
    "    img_path = data_dir + img\n",
    "    preprocessed_input = load_image_normalize(img_path, mean, std)\n",
    "    predictions = model.predict(preprocessed_input)\n",
    "    print(\"Ground Truth: \", \", \".join(np.take(labels, np.nonzero(df[df[\"Image\"] == img][labels].values[0]))[0]))\n",
    "\n",
    "    plt.figure(figsize=(15, 10))\n",
    "    plt.subplot(151)\n",
    "    plt.title(\"Original\")\n",
    "    plt.axis('off')\n",
    "    plt.imshow(load_image(img_path, df, preprocess=False), cmap='gray')\n",
    "    \n",
    "    j = 1\n",
    "    \n",
    "    ### START CODE HERE (REPLACE INSTANCES OF 'None' with your code) ###    \n",
    "    # Loop through all labels\n",
    "    for i in range(len(labels)): # complete this line\n",
    "        # Compute CAM and show plots for each selected label.\n",
    "        \n",
    "        # Check if the label is one of the selected labels\n",
    "        if labels[i] in selected_labels: # complete this line\n",
    "            \n",
    "            # Use the grad_cam function to calculate gradcam\n",
    "            gradcam = grad_cam(model, preprocessed_input, i, layer_name)\n",
    "            \n",
    "            ### END CODE HERE ###\n",
    "            \n",
    "            print(\"Generating gradcam for class %s (p=%2.2f)\" % (labels[i], round(predictions[0][i], 3)))\n",
    "            plt.subplot(151 + j)\n",
    "            plt.title(labels[i] + \": \" + str(round(predictions[0][i], 3)))\n",
    "            plt.axis('off')\n",
    "            plt.imshow(load_image(img_path, df, preprocess=False), cmap='gray')\n",
    "            plt.imshow(gradcam, cmap='magma', alpha=min(0.5, predictions[0][i]))\n",
    "            j += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MIB7qyZ6_eu4"
   },
   "source": [
    "Run the following cells to print the ground truth diagnosis for a given case and show the original x-ray as well as GradCAMs for Cardiomegaly, Mass, and Edema."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 259
    },
    "colab_type": "code",
    "id": "zCh0tNn_6Wmu",
    "outputId": "26e0692b-af47-46f0-f749-8e007e252fba"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ground Truth:  Cardiomegaly\n",
      "Generating gradcam for class Cardiomegaly (p=0.98)\n",
      "Generating gradcam for class Mass (p=0.30)\n",
      "Generating gradcam for class Edema (p=0.13)\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 1080x720 with 4 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "df = pd.read_csv(\"nih_new/train-small.csv\")\n",
    "\n",
    "image_filename = '00016650_000.png'\n",
    "labels_to_show = ['Cardiomegaly', 'Mass', 'Edema']\n",
    "compute_gradcam(model, image_filename, mean, std, IMAGE_DIR, df, labels, labels_to_show)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "cvPOHJeb6HkD"
   },
   "source": [
    "The model correctly predicts Cardiomegaly and the absence of Mass and Edema. The probability for Mass (0.30) is higher than for Edema (0.13), and its GradCAM suggests it may be influenced by the shapes in the middle of the chest cavity and around the shoulder. We'll run it on two more images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 259
    },
    "colab_type": "code",
    "id": "4eiXeIXWCCTf",
    "outputId": "90838ee8-7843-4a7f-9268-ad6b36833b94"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ground Truth:  Mass\n",
      "Generating gradcam for class Cardiomegaly (p=0.02)\n",
      "Generating gradcam for class Mass (p=0.99)\n",
      "Generating gradcam for class Edema (p=0.32)\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 1080x720 with 4 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "image_filename = '00005410_000.png'\n",
    "compute_gradcam(model, image_filename, mean, std, IMAGE_DIR, df, labels, labels_to_show)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "v8Wx5u3t6fiE"
   },
   "source": [
    "In the example above, the model correctly focuses on the mass near the center of the chest cavity. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 259
    },
    "colab_type": "code",
    "id": "V5ViYPuIqM3H",
    "outputId": "1bf4aa59-bb4a-49ad-8b8a-eb654a97dcaa"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ground Truth:  Edema\n",
      "Generating gradcam for class Cardiomegaly (p=0.71)\n",
      "Generating gradcam for class Mass (p=0.23)\n",
      "Generating gradcam for class Edema (p=0.99)\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 1080x720 with 4 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "image_name = '00004090_002.png'\n",
    "compute_gradcam(model, image_name, mean, std, IMAGE_DIR, df, labels, labels_to_show)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "FroGFoB98A26"
   },
   "source": [
    "Here the model correctly picks up the signs of edema near the bottom of the chest cavity. We can also notice that Cardiomegaly has a high score for this image, even though the ground truth doesn't include it. This kind of visualization can help with error analysis: for example, we can check whether the model is looking at the expected region when it makes a prediction."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "sN_laVv3DHVp"
   },
   "source": [
    "This concludes the section on GradCAMs. We hope you've gained an appreciation for the importance of interpretation when it comes to deep learning models in medicine. Interpretation tools like this one can be helpful for discovery of markers, error analysis, and even in deployment. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hBR1ML_8pEXe"
   },
   "source": [
    "<a name=\"2\"></a>\n",
    "## 2 Feature Importance in Machine Learning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jqq5oDOuLjrB"
   },
   "source": [
    "When developing predictive models and risk measures, it's often helpful to know which features make the most difference. This is easy to determine in simpler models such as linear models and decision trees. However, as we move to more complex models to achieve higher performance, we usually sacrifice some interpretability. In this part of the assignment we'll try to regain some of that interpretability, first with the permutation method and then with Shapley values, a technique that has gained popularity in recent years but is based on classic results in cooperative game theory.\n",
    "\n",
    "We'll revisit our random forest model from course 2 module 2 and try to analyze it more closely using Shapley values. Run the next cell to load in the data and model from that assignment and recalculate the test set c-index."
   ]
  },
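  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, here is a minimal sketch of a concordance index for a binary outcome, assuming the notebook's `cindex` helper follows the usual definition: among all pairs of patients with different outcomes, the fraction in which the positive case receives the higher risk score, with ties counted as half. The name `cindex_sketch` and the toy values are hypothetical. For binary outcomes this quantity coincides with the ROC AUC.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def cindex_sketch(y_true, scores):\n",
    "    # Compare every positive case with every negative case (the permissible pairs)\n",
    "    y_true = np.asarray(y_true)\n",
    "    scores = np.asarray(scores)\n",
    "    pos = scores[y_true == 1]\n",
    "    neg = scores[y_true == 0]\n",
    "    concordant = (pos[:, None] > neg[None, :]).sum()  # positive case ranked higher\n",
    "    ties = (pos[:, None] == neg[None, :]).sum()       # equal scores count as 0.5\n",
    "    permissible = len(pos) * len(neg)\n",
    "    return (concordant + 0.5 * ties) / permissible\n",
    "\n",
    "y = [0, 1, 1, 0]\n",
    "risk = [0.2, 0.8, 0.4, 0.4]\n",
    "print(cindex_sketch(y, risk))  # 0.875: 3 concordant pairs and 1 tie out of 4\n",
    "```"
   ]
  },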
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 122
    },
    "colab_type": "code",
    "id": "VrCZoOmJVF_U",
    "outputId": "9251d060-baed-440d-d4a2-cd206b60e3a3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model C-index on test: 0.7776169781865744\n"
     ]
    }
   ],
   "source": [
    "with open('nhanes_rf.sav', 'rb') as f:\n",
    "    rf = pickle.load(f) # Loading the model\n",
    "test_df = pd.read_csv('nhanest_test.csv')\n",
    "test_df = test_df.drop(test_df.columns[0], axis=1)\n",
    "X_test = test_df.drop('y', axis=1)\n",
    "y_test = test_df.loc[:, 'y']\n",
    "cindex_test = cindex(y_test, rf.predict_proba(X_test)[:, 1])\n",
    "\n",
    "print(\"Model C-index on test: {}\".format(cindex_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "w4-TT-8VXJxW"
   },
   "source": [
    "Run the next cell to print out the riskiest individuals according to our model. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 258
    },
    "colab_type": "code",
    "id": "y_7aDY6rWsbS",
    "outputId": "f296fd3f-37c5-4834-d0f1-19f93758ab63"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Age</th>\n",
       "      <th>Diastolic BP</th>\n",
       "      <th>Poverty index</th>\n",
       "      <th>Race</th>\n",
       "      <th>Red blood cells</th>\n",
       "      <th>Sedimentation rate</th>\n",
       "      <th>Serum Albumin</th>\n",
       "      <th>Serum Cholesterol</th>\n",
       "      <th>Serum Iron</th>\n",
       "      <th>Serum Magnesium</th>\n",
       "      <th>Serum Protein</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Systolic BP</th>\n",
       "      <th>TIBC</th>\n",
       "      <th>TS</th>\n",
       "      <th>White blood cells</th>\n",
       "      <th>BMI</th>\n",
       "      <th>Pulse pressure</th>\n",
       "      <th>risk</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>572</th>\n",
       "      <td>70.0</td>\n",
       "      <td>80.0</td>\n",
       "      <td>312.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>54.8</td>\n",
       "      <td>7.0</td>\n",
       "      <td>4.4</td>\n",
       "      <td>222.0</td>\n",
       "      <td>52.0</td>\n",
       "      <td>1.57</td>\n",
       "      <td>7.2</td>\n",
       "      <td>1.0</td>\n",
       "      <td>180.0</td>\n",
       "      <td>417.0</td>\n",
       "      <td>12.5</td>\n",
       "      <td>7.5</td>\n",
       "      <td>45.770473</td>\n",
       "      <td>100.0</td>\n",
       "      <td>0.77</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>190</th>\n",
       "      <td>69.0</td>\n",
       "      <td>100.0</td>\n",
       "      <td>316.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>77.7</td>\n",
       "      <td>26.0</td>\n",
       "      <td>4.2</td>\n",
       "      <td>197.0</td>\n",
       "      <td>65.0</td>\n",
       "      <td>1.49</td>\n",
       "      <td>7.5</td>\n",
       "      <td>1.0</td>\n",
       "      <td>165.0</td>\n",
       "      <td>298.0</td>\n",
       "      <td>21.8</td>\n",
       "      <td>8.8</td>\n",
       "      <td>22.129018</td>\n",
       "      <td>65.0</td>\n",
       "      <td>0.69</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1300</th>\n",
       "      <td>73.0</td>\n",
       "      <td>80.0</td>\n",
       "      <td>999.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>52.6</td>\n",
       "      <td>35.0</td>\n",
       "      <td>3.9</td>\n",
       "      <td>258.0</td>\n",
       "      <td>61.0</td>\n",
       "      <td>1.66</td>\n",
       "      <td>6.8</td>\n",
       "      <td>1.0</td>\n",
       "      <td>150.0</td>\n",
       "      <td>314.0</td>\n",
       "      <td>19.4</td>\n",
       "      <td>9.4</td>\n",
       "      <td>26.466850</td>\n",
       "      <td>70.0</td>\n",
       "      <td>0.69</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>634</th>\n",
       "      <td>66.0</td>\n",
       "      <td>100.0</td>\n",
       "      <td>69.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>42.9</td>\n",
       "      <td>47.0</td>\n",
       "      <td>3.8</td>\n",
       "      <td>233.0</td>\n",
       "      <td>170.0</td>\n",
       "      <td>1.42</td>\n",
       "      <td>8.6</td>\n",
       "      <td>1.0</td>\n",
       "      <td>180.0</td>\n",
       "      <td>411.0</td>\n",
       "      <td>41.4</td>\n",
       "      <td>7.2</td>\n",
       "      <td>22.129498</td>\n",
       "      <td>80.0</td>\n",
       "      <td>0.68</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1221</th>\n",
       "      <td>74.0</td>\n",
       "      <td>80.0</td>\n",
       "      <td>67.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>40.3</td>\n",
       "      <td>24.0</td>\n",
       "      <td>3.7</td>\n",
       "      <td>139.0</td>\n",
       "      <td>28.0</td>\n",
       "      <td>1.91</td>\n",
       "      <td>6.4</td>\n",
       "      <td>2.0</td>\n",
       "      <td>140.0</td>\n",
       "      <td>495.0</td>\n",
       "      <td>5.7</td>\n",
       "      <td>4.1</td>\n",
       "      <td>22.066389</td>\n",
       "      <td>60.0</td>\n",
       "      <td>0.68</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       Age  Diastolic BP  Poverty index  Race  Red blood cells  \\\n",
       "572   70.0          80.0          312.0   1.0             54.8   \n",
       "190   69.0         100.0          316.0   1.0             77.7   \n",
       "1300  73.0          80.0          999.0   1.0             52.6   \n",
       "634   66.0         100.0           69.0   2.0             42.9   \n",
       "1221  74.0          80.0           67.0   1.0             40.3   \n",
       "\n",
       "      Sedimentation rate  Serum Albumin  Serum Cholesterol  Serum Iron  \\\n",
       "572                  7.0            4.4              222.0        52.0   \n",
       "190                 26.0            4.2              197.0        65.0   \n",
       "1300                35.0            3.9              258.0        61.0   \n",
       "634                 47.0            3.8              233.0       170.0   \n",
       "1221                24.0            3.7              139.0        28.0   \n",
       "\n",
       "      Serum Magnesium  Serum Protein  Sex  Systolic BP   TIBC    TS  \\\n",
       "572              1.57            7.2  1.0        180.0  417.0  12.5   \n",
       "190              1.49            7.5  1.0        165.0  298.0  21.8   \n",
       "1300             1.66            6.8  1.0        150.0  314.0  19.4   \n",
       "634              1.42            8.6  1.0        180.0  411.0  41.4   \n",
       "1221             1.91            6.4  2.0        140.0  495.0   5.7   \n",
       "\n",
       "      White blood cells        BMI  Pulse pressure  risk  \n",
       "572                 7.5  45.770473           100.0  0.77  \n",
       "190                 8.8  22.129018            65.0  0.69  \n",
       "1300                9.4  26.466850            70.0  0.69  \n",
       "634                 7.2  22.129498            80.0  0.68  \n",
       "1221                4.1  22.066389            60.0  0.68  "
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_test_risky = X_test.copy(deep=True)\n",
    "X_test_risky.loc[:, 'risk'] = rf.predict_proba(X_test)[:, 1] # Predicting our risk.\n",
    "X_test_risky = X_test_risky.sort_values(by='risk', ascending=False) # Sorting by risk value.\n",
    "X_test_risky.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "egBagzTuLduV"
   },
   "source": [
    "<a name=\"2-1\"></a>\n",
    "### 2.1 Permutation Method for Feature Importance\n",
    "\n",
    "First we'll determine feature importance using the permutation method. In the permutation method, the importance of feature $i$ is the regular performance of the model minus its performance when the values of feature $i$ are randomly permuted across the rows of the dataset. Shuffling a column breaks its relationship with the outcome while keeping its distribution, so we can assess how well the model would do without that feature, without having to train a new model for each feature."
   ]
  },
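  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the permutation method on synthetic data, assuming accuracy as the performance metric (this notebook uses the c-index instead). The data and the `predict` stand-in model are hypothetical: the label depends only on feature 0, so shuffling feature 0 should drop performance sharply (importance around 0.5), while shuffling the noise feature 1 should leave it unchanged (importance exactly 0).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# Hypothetical synthetic data: the label depends only on feature 0; feature 1 is noise\n",
    "X = rng.normal(size=(1000, 2))\n",
    "y = (X[:, 0] > 0).astype(int)\n",
    "\n",
    "def predict(X):\n",
    "    # Stand-in for a trained model: it thresholds feature 0 and ignores feature 1\n",
    "    return (X[:, 0] > 0).astype(int)\n",
    "\n",
    "def accuracy(y_true, y_pred):\n",
    "    return np.mean(y_true == y_pred)\n",
    "\n",
    "base_perf = accuracy(y, predict(X))\n",
    "importances = {}\n",
    "for j in range(X.shape[1]):\n",
    "    X_perm = X.copy()\n",
    "    # Shuffle column j only: breaks its link with y but keeps its distribution\n",
    "    X_perm[:, j] = rng.permutation(X_perm[:, j])\n",
    "    importances[j] = base_perf - accuracy(y, predict(X_perm))\n",
    "\n",
    "print(importances)\n",
    "```"
   ]
  },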
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name=\"2-1-1\"></a>\n",
    "#### 2.1.1 Implementing Permutation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name='ex-03'></a>\n",
    "### Exercise 3\n",
    "\n",
    "Complete the implementation of the function below, which, given a feature name, returns a copy of the dataset with that feature's values randomly permuted."
   ]
  },
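  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the behavior the exercise relies on, using a hypothetical toy DataFrame: `np.random.permutation` accepts a pandas Series and returns a shuffled NumPy array with the same values, which can then be assigned back (by position) to the column of a deep copy. The original DataFrame and the other columns are left untouched.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "df = pd.DataFrame({'col1': [0, 1, 2, 3, 4], 'col2': ['A', 'B', 'C', 'D', 'E']})\n",
    "permuted_df = df.copy(deep=True)  # work on a copy so df is not modified\n",
    "\n",
    "# Returns a shuffled NumPy array containing the same values as the column\n",
    "permuted_df['col1'] = np.random.permutation(permuted_df['col1'])\n",
    "\n",
    "print(sorted(permuted_df['col1'].tolist()))  # [0, 1, 2, 3, 4]: same values, new order\n",
    "print(df['col1'].tolist())                   # [0, 1, 2, 3, 4]: original unchanged\n",
    "```"
   ]
  },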
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<details>\n",
    "    <summary>\n",
    "    <font size=\"3\" color=\"darkgreen\"><b>Hints</b></font>\n",
    "</summary>\n",
    "    <ul>\n",
    "        <li>\n",
    "            See <a href=https://numpy.org/devdocs/reference/random/generated/numpy.random.permutation.html> np.random.permutation</a>\n",
    "        </li>\n",
    "    </ul>\n",
    "</details>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 306
    },
    "colab_type": "code",
    "id": "iKTpkaRgP-dz",
    "outputId": "bc74170d-deac-444c-cd6d-7223f0545c87"
   },
   "outputs": [],
   "source": [
    "# UNQ_C3 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)\n",
    "def permute_feature(df, feature):\n",
    "    \"\"\"\n",
    "    Given dataset, returns version with the values of\n",
    "    the given feature randomly permuted. \n",
    "\n",
    "    Args:\n",
    "        df (dataframe): The dataset, shape (num subjects, num features)\n",
    "        feature (string): Name of feature to permute\n",
    "    Returns:\n",
    "        permuted_df (dataframe): Exactly the same as df except the values\n",
    "                                of the given feature are randomly permuted.\n",
    "    \"\"\"\n",
    "    permuted_df = df.copy(deep=True) # Make copy so we don't change original df\n",
    "\n",
    "    ### START CODE HERE (REPLACE INSTANCES OF 'None' with your code) ###\n",
    "\n",
    "    # Permute the values of the column 'feature'\n",
    "    permuted_features = np.random.permutation(permuted_df[feature])\n",
    "    \n",
    "    # Set the column 'feature' to its permuted values.\n",
    "    permuted_df[feature] = permuted_features\n",
    "    \n",
    "    ### END CODE HERE ###\n",
    "\n",
    "    return permuted_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Case\n",
      "Original dataframe:\n",
      "   col1 col2\n",
      "0     0    A\n",
      "1     1    B\n",
      "2     2    C\n",
      "\n",
      "\n",
      "col1 permuted:\n",
      "   col1 col2\n",
      "0     2    A\n",
      "1     1    B\n",
      "2     0    C\n",
      "\n",
      "\n",
      "Compute average values over 1000 runs to get expected values:\n",
      "Average of col1: [0.976 1.03  0.994], expected value: [0.976, 1.03, 0.994]\n"
     ]
    }
   ],
   "source": [
    "print(\"Test Case\")\n",
    "\n",
    "example_df = pd.DataFrame({'col1': [0, 1, 2], 'col2':['A', 'B', 'C']})\n",
    "print(\"Original dataframe:\")\n",
    "print(example_df)\n",
    "print(\"\\n\")\n",
    "\n",
    "print(\"col1 permuted:\")\n",
    "print(permute_feature(example_df, 'col1'))\n",
    "\n",
    "print(\"\\n\")\n",
    "print(\"Compute average values over 1000 runs to get expected values:\")\n",
    "col1_values = np.zeros((3, 1000))\n",
    "np.random.seed(0) # Adding a constant seed so we can always expect the same values and evaluate correctly. \n",
    "for i in range(1000):\n",
    "    col1_values[:, i] = permute_feature(example_df, 'col1')['col1'].values\n",
    "\n",
    "print(\"Average of col1: {}, expected value: [0.976, 1.03, 0.994]\".format(np.mean(col1_values, axis=1)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name=\"2-1-2\"></a>\n",
    "#### 2.1.2 Implementing Importance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "YNH7oo7jQVo6"
   },
   "source": [
    "<a name='ex-04'></a>\n",
    "### Exercise 4\n",
    "\n",
    "Now we will use the function we just created to compute feature importances (according to the permutation method) in the function below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<details>\n",
    "    <summary>\n",
    "    <font size=\"3\" color=\"darkgreen\"><b>Hints</b></font>\n",
    "</summary>\n",
    "\\begin{align}\n",
    "I_x  = \\left\\lvert perf - perf_x  \\right\\rvert\n",
    "\\end{align}\n",
    "\n",
    "where $I_x$ is the importance of feature $x$ and\n",
    "\\begin{align}\n",
    "perf_x  = \\frac{1}{n}\\cdot \\sum_{i=1}^{n} perf_i^{sx}\n",
    "\\end{align}\n",
    "\n",
    "where $perf_i^{sx}$ is the performance with the feature $x$ shuffled in the $i$th permutation.\n",
    "</details>"
   ]
  },
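  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked illustration of the formula in the hint (not the graded solution), the sketch below computes the importance $I_x$ from scratch. It assumes a hypothetical hand-written scorer `toy_predict` and uses accuracy as a stand-in for the C-index, so the numbers are easy to reason about: shuffling a feature the scorer relies on should cost about half the accuracy, while shuffling a feature it ignores should cost nothing.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# Hypothetical toy data: column 0 determines the label, column 1 is pure noise\n",
    "X = rng.normal(size=(500, 2))\n",
    "y = (X[:, 0] > 0).astype(int)\n",
    "\n",
    "def toy_predict(X):\n",
    "    # Hypothetical scorer: predicts 1 whenever feature 0 is positive\n",
    "    return (X[:, 0] > 0).astype(int)\n",
    "\n",
    "def accuracy(y_true, y_pred):\n",
    "    return np.mean(y_true == y_pred)\n",
    "\n",
    "def importance(X, y, col, n=50):\n",
    "    perf = accuracy(y, toy_predict(X))  # baseline performance\n",
    "    shuffled_perfs = []\n",
    "    for _ in range(n):\n",
    "        X_perm = X.copy()\n",
    "        X_perm[:, col] = rng.permutation(X_perm[:, col])  # shuffle one column\n",
    "        shuffled_perfs.append(accuracy(y, toy_predict(X_perm)))\n",
    "    perf_x = np.mean(shuffled_perfs)  # average over the n shuffles\n",
    "    return abs(perf - perf_x)  # I_x = |perf - perf_x|\n",
    "\n",
    "print(importance(X, y, 0))  # close to 0.5: shuffling breaks the feature-label link\n",
    "print(importance(X, y, 1))  # 0.0: the scorer never looks at column 1\n",
    "```\n",
    "\n",
    "The graded `permutation_importance` follows the same pattern, but calls `model.predict_proba` and the `metric` argument instead of these toy stand-ins."
   ]
  },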
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 289
    },
    "colab_type": "code",
    "id": "R_PP-DEz04hp",
    "outputId": "a8fb98da-d513-43cd-d349-b2265beb270c"
   },
   "outputs": [],
   "source": [
    "# UNQ_C4 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)\n",
    "def permutation_importance(X, y, model, metric, num_samples=100):\n",
    "    \"\"\"\n",
    "    Compute permutation importance for each feature.\n",
    "\n",
    "    Args:\n",
    "        X (dataframe): Dataframe for test data, shape (num subjects, num features)\n",
    "        y (np.array): Labels for each row of X, shape (num subjects,)\n",
    "        model (object): Model to compute importances for, guaranteed to have\n",
    "                        a 'predict_proba' method to compute probabilistic \n",
    "                        predictions given input\n",
    "        metric (function): Metric to be used for feature importance. Takes in ground\n",
    "                           truth and predictions as the only two arguments\n",
    "        num_samples (int): Number of samples to average over when computing change in\n",
    "                           performance for each feature\n",
    "    Returns:\n",
    "        importances (dataframe): Dataframe containing feature importance for each\n",
    "                                 column of X with shape (1, num_features)\n",
    "    \"\"\"\n",
    "\n",
    "    importances = pd.DataFrame(index = ['importance'], columns = X.columns)\n",
    "    \n",
    "    # Get baseline performance (note, you'll use this metric function again later)\n",
    "    baseline_performance = metric(y, model.predict_proba(X)[:, 1])\n",
    "\n",
    "    ### START CODE HERE (REPLACE INSTANCES OF 'None' with your code) ###\n",
    "\n",
    "    # Iterate over features (the columns in the importances dataframe)\n",
    "    for feature in importances.columns: # complete this line\n",
    "        \n",
    "        # Compute 'num_samples' performances by permuting that feature\n",
    "        \n",
    "        # You'll see how the model performs when the feature is permuted\n",
    "        # You'll do this num_samples number of times, and save the performance each time\n",
    "        # To store the feature performance,\n",
    "        # create a numpy array of size num_samples, initialized to all zeros\n",
    "        feature_performance_arr = np.zeros(num_samples)\n",
    "        \n",
    "        # Loop through each sample\n",
    "        for i in range(num_samples): # complete this line\n",
    "            \n",
    "            # permute the column of dataframe X\n",
    "            perm_X = permute_feature(X, feature)\n",
    "            \n",
    "            # calculate the performance with the permuted data\n",
    "            # Use the same metric function that was used earlier\n",
    "            feature_performance_arr[i] = metric(y, model.predict_proba(perm_X)[:, 1])\n",
    "    \n",
    "    \n",
    "        # Compute importance: absolute difference between \n",
    "        # the baseline performance and the average across the feature performance\n",
    "        importances.loc['importance', feature] = np.abs(baseline_performance - np.mean(feature_performance_arr))\n",
    "        \n",
    "    ### END CODE HERE ###\n",
    "\n",
    "    return importances"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Test Case**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Case\n",
      "\n",
      "\n",
      "We check our answers on a Logistic Regression on a dataset\n",
      "where y is given by a sigmoid applied to the important feature.\n",
      "The unimportant feature is random noise.\n",
      "\n",
      "\n",
      "Computed importances:\n",
      "           important  unimportant\n",
      "importance  0.496674  2.89012e-05\n",
      "\n",
      "\n",
      "Expected importances (approximate values):\n",
      "            important  unimportant\n",
      "importance        0.5          0.0\n",
      "If you round the actual values, they will be similar to the expected values\n"
     ]
    }
   ],
   "source": [
    "print(\"Test Case\")\n",
    "print(\"\\n\")\n",
    "print(\"We check our answers on a Logistic Regression on a dataset\")\n",
    "print(\"where y is given by a sigmoid applied to the important feature.\") \n",
    "print(\"The unimportant feature is random noise.\")\n",
    "print(\"\\n\")\n",
    "example_df = pd.DataFrame({'important': np.random.normal(size=(1000)), 'unimportant':np.random.normal(size=(1000))})\n",
    "example_y = np.round(1 / (1 + np.exp(-example_df.important)))\n",
    "example_model = sklearn.linear_model.LogisticRegression(fit_intercept=False).fit(example_df, example_y)\n",
    "\n",
    "example_importances = permutation_importance(example_df, example_y, example_model, cindex, num_samples=100)\n",
    "print(\"Computed importances:\")\n",
    "print(example_importances)\n",
    "print(\"\\n\")\n",
    "print(\"Expected importances (approximate values):\")\n",
    "print(pd.DataFrame({\"important\": 0.50, \"unimportant\": 0.00}, index=['importance']))\n",
    "print(\"If you round the actual values, they will be similar to the expected values\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name=\"2-1-3\"></a>\n",
    "#### 2.1.3 Computing our Feature Importance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "47iozBppkTwl"
   },
   "source": [
    "Next, we compute importances on our dataset. Since we are computing the permutation importance for all the features, it might take a few minutes to run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 134
    },
    "colab_type": "code",
    "id": "uhxor87eYEKt",
    "outputId": "0e98bf36-4691-484c-bb5e-6ceda74de971"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Age</th>\n",
       "      <th>Diastolic BP</th>\n",
       "      <th>Poverty index</th>\n",
       "      <th>Race</th>\n",
       "      <th>Red blood cells</th>\n",
       "      <th>Sedimentation rate</th>\n",
       "      <th>Serum Albumin</th>\n",
       "      <th>Serum Cholesterol</th>\n",
       "      <th>Serum Iron</th>\n",
       "      <th>Serum Magnesium</th>\n",
       "      <th>Serum Protein</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Systolic BP</th>\n",
       "      <th>TIBC</th>\n",
       "      <th>TS</th>\n",
       "      <th>White blood cells</th>\n",
       "      <th>BMI</th>\n",
       "      <th>Pulse pressure</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>importance</th>\n",
       "      <td>0.147772</td>\n",
       "      <td>0.0113034</td>\n",
       "      <td>0.0111148</td>\n",
       "      <td>0.000449158</td>\n",
       "      <td>0.000805694</td>\n",
       "      <td>0.006285</td>\n",
       "      <td>0.00527172</td>\n",
       "      <td>0.000848118</td>\n",
       "      <td>0.000203789</td>\n",
       "      <td>0.00274019</td>\n",
       "      <td>0.00154867</td>\n",
       "      <td>0.0272337</td>\n",
       "      <td>0.00618949</td>\n",
       "      <td>0.00225922</td>\n",
       "      <td>0.000425288</td>\n",
       "      <td>0.00256128</td>\n",
       "      <td>0.00304884</td>\n",
       "      <td>0.00379624</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 Age Diastolic BP Poverty index         Race Red blood cells  \\\n",
       "importance  0.147772    0.0113034     0.0111148  0.000449158     0.000805694   \n",
       "\n",
       "           Sedimentation rate Serum Albumin Serum Cholesterol   Serum Iron  \\\n",
       "importance           0.006285    0.00527172       0.000848118  0.000203789   \n",
       "\n",
       "           Serum Magnesium Serum Protein        Sex Systolic BP        TIBC  \\\n",
       "importance      0.00274019    0.00154867  0.0272337  0.00618949  0.00225922   \n",
       "\n",
       "                     TS White blood cells         BMI Pulse pressure  \n",
       "importance  0.000425288        0.00256128  0.00304884     0.00379624  "
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "importances = permutation_importance(X_test, y_test, rf, cindex, num_samples=100)\n",
    "importances"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "pQ-d-X6VlQ2T"
   },
   "source": [
    "Let's plot these in a bar chart for easier comparison."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 515
    },
    "colab_type": "code",
    "id": "YOpUIt32lQZW",
    "outputId": "fa872d44-da2f-4517-c2e1-8f5e54080522"
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "importances.T.plot.bar(legend=False)  # one bar per feature, no legend needed\n",
    "plt.ylabel(\"Importance\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "yFbnsXMTnmRI"
   },
   "source": [
    "You should see age as by far the best predictor of near-term mortality, as one might expect. Next is sex, followed by diastolic blood pressure. Interestingly, the poverty index also has a relatively large impact, comparable to that of diastolic blood pressure, despite not being directly related to an individual's health. This points to the importance of social determinants of health in our model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-wiIIOV0Xgyn"
   },
   "source": [
    "<a name=\"2-2\"></a>\n",
    "### 2.2 Shapley Values for Random Forests"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "I-MzbzsnXaGI"
   },
   "source": [
    "We'll contrast the permutation method with a more recent technique known as Shapley values (Shapley values actually date back to the mid-20th century, but have only recently been applied to machine learning)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name=\"2-2-1\"></a>\n",
    "#### 2.2.1 Visualizing Feature Importance on Specific Individuals"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use Shapley values to try to understand the model's output for specific individuals. In general, exact Shapley values take exponential time to compute, but luckily there are faster algorithms for tree ensembles such as random forests that run in polynomial time. Run the next cell to display a 'force plot' showing how each feature influences the output for the first individual in `X_test_risky`. If you want more information about 'force plots' and other decision plots, please take a look at [this notebook](https://github.com/slundberg/shap/blob/master/notebooks/plots/decision_plot.ipynb) by the `shap` library creators."
   ]
  },
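  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see where the exponential cost comes from, here is a minimal brute-force sketch of the classic Shapley formula, $\\phi_i = \\sum_{S \\subseteq N \\setminus \\{i\\}} \\frac{|S|!\\,(n-|S|-1)!}{n!}\\left[v(S \\cup \\{i\\}) - v(S)\\right]$, applied to a small hypothetical game. The function `shapley_values` and the value function `v` are illustrative assumptions, not part of the `shap` library. Each player requires a loop over all $2^{n-1}$ coalitions that exclude it, which is why exact computation quickly becomes infeasible as the number of features grows.\n",
    "\n",
    "```python\n",
    "from itertools import combinations\n",
    "from math import factorial\n",
    "\n",
    "def shapley_values(n, v):\n",
    "    # Exact Shapley values for players 0..n-1 and a value function v(frozenset)\n",
    "    phis = []\n",
    "    for i in range(n):\n",
    "        others = [j for j in range(n) if j != i]\n",
    "        phi = 0.0\n",
    "        # Enumerate every coalition S that excludes player i: 2**(n-1) of them\n",
    "        for size in range(n):\n",
    "            for S in combinations(others, size):\n",
    "                S = frozenset(S)\n",
    "                weight = factorial(size) * factorial(n - size - 1) / factorial(n)\n",
    "                phi += weight * (v(S | {i}) - v(S))\n",
    "        phis.append(phi)\n",
    "    return phis\n",
    "\n",
    "# Hypothetical game: features 0 and 1 contribute 3 and 1 on their own,\n",
    "# plus a bonus of 2 when both are present; feature 2 contributes nothing\n",
    "def v(S):\n",
    "    total = 3.0 * (0 in S) + 1.0 * (1 in S)\n",
    "    if 0 in S and 1 in S:\n",
    "        total += 2.0\n",
    "    return total\n",
    "\n",
    "phi = shapley_values(3, v)\n",
    "print(phi)  # approximately [4.0, 2.0, 0.0]\n",
    "```\n",
    "\n",
    "Feature 2 receives zero, the interaction bonus of 2 is split evenly between features 0 and 1, and the values sum to $v(N) - v(\\emptyset) = 6$."
   ]
  },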
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 310
    },
    "colab_type": "code",
    "id": "iVPZg-I_XjFJ",
    "outputId": "4fde0bf6-6cd6-44b5-dcfb-c906eda2066b"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting feature_perturbation = \"tree_path_dependent\" because no background data was given.\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 1440x216 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "explainer = shap.TreeExplainer(rf)\n",
    "i = 0 # Picking an individual\n",
    "shap_value = explainer.shap_values(X_test.loc[X_test_risky.index[i], :])[1]\n",
    "shap.force_plot(explainer.expected_value[1], shap_value, feature_names=X_test.columns, matplotlib=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5EIPLm1hAunF"
   },
   "source": [
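    "For this individual, age, pulse pressure, and sex were the biggest contributors to the high-risk prediction. Note how Shapley values give us greater granularity in our interpretations: they explain each individual prediction, whereas the permutation method yields a single global importance score per feature.\n",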
    "\n",
    "Feel free to change the `i` value above to explore the feature influences for different individuals."
   ]
  },
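  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The force plot relies on the local accuracy (additivity) property of Shapley values: the base value (`explainer.expected_value[1]` above) plus the sum of an individual's Shapley values equals the model's output for that individual. As a hedged illustration outside of `shap`, the sketch below uses a hypothetical linear model $f(x) = w \\cdot x + b$. Assuming the features are independent, the Shapley value of feature $j$ reduces to $w_j (x_j - E[x_j])$, with the expectation taken over a background dataset. The weights, background data, and individual are all made up.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "\n",
    "# Hypothetical linear model f(x) = w . x + b and made-up background data\n",
    "w = np.array([0.8, -0.5, 0.1])\n",
    "b = 0.2\n",
    "background = rng.normal(size=(1000, 3))\n",
    "\n",
    "def f(X):\n",
    "    return X @ w + b\n",
    "\n",
    "x = np.array([1.5, 0.3, -2.0])  # the individual being explained\n",
    "\n",
    "# Shapley values of a linear model, assuming independent features\n",
    "phi = w * (x - background.mean(axis=0))\n",
    "base_value = f(background).mean()  # plays the role of expected_value\n",
    "\n",
    "print(phi)\n",
    "# These two values match up to floating-point rounding\n",
    "print(base_value + phi.sum(), f(x[None, :])[0])\n",
    "```\n",
    "\n",
    "A force plot draws exactly this decomposition: it starts at the base value and pushes toward the final prediction by each feature's Shapley value."
   ]
  },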
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name=\"2-2-2\"></a>\n",
    "#### 2.2.2 Visualizing Feature Importance on Aggregate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "BH1p21GpY-0Y"
   },
   "source": [
    "Just like with the permutation method, we might also want to understand model output in aggregate. Shapley values allow us to do this as well. Run the next cell to compute the Shapley values for each example in the test set (this may also take a few minutes)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "gFXBwZesYkRb",
    "outputId": "69846d78-f974-4299-82f6-fa0823345c16"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting feature_perturbation = \"tree_path_dependent\" because no background data was given.\n"
     ]
    }
   ],
   "source": [
    "shap_values = shap.TreeExplainer(rf).shap_values(X_test)[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can ignore the `Setting feature_perturbation` message; it only reports which method `TreeExplainer` chose because no background data was given."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RzPYyknlZR3N"
   },
   "source": [
    "Run the next cell to see a summary plot of the Shapley values for each feature on each of the test examples. The color of each point indicates that individual's value of the feature (red for high, blue for low). The features are listed in decreasing order of mean absolute Shapley value over all the individuals in the dataset."
   ]
  },
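  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ordering used by the summary plot can be reproduced from the Shapley value matrix itself (one row per individual, one column per feature) by averaging absolute values down each column. The sketch below uses a small made-up matrix rather than the real `shap_values`; with the actual data you would pass `shap_values` and `X_test.columns` instead. Because it averages absolute values, a feature whose large positive and negative contributions cancel out on average still ranks highly.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Made-up Shapley value matrix: rows are individuals, columns are features\n",
    "feature_names = ['Age', 'Sex', 'BMI']\n",
    "shap_matrix = np.array([\n",
    "    [0.30, -0.10, 0.02],\n",
    "    [-0.20, 0.12, 0.01],\n",
    "    [0.25, -0.08, -0.03],\n",
    "])\n",
    "\n",
    "# Mean absolute Shapley value per feature (the summary plot's sort key)\n",
    "mean_abs = np.abs(shap_matrix).mean(axis=0)\n",
    "\n",
    "# Sort features from most to least important\n",
    "order = np.argsort(mean_abs)[::-1]\n",
    "ranking = [feature_names[j] for j in order]\n",
    "print(ranking)  # ['Age', 'Sex', 'BMI']\n",
    "```"
   ]
  },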
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 542
    },
    "colab_type": "code",
    "id": "msyIjYJxZMwn",
    "outputId": "cbfd0b81-e4c5-42ff-e073-4943257b921d"
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 576x626.4 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "shap.summary_plot(shap_values, X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the above plot, you may notice a high concentration of points in specific SHAP value ranges. This means that a large proportion of our test set lies in those ranges.\n",
    "\n",
    "As with the permutation method, age, sex, poverty index, and diastolic BP appear to be the most important features. Being older increases the predicted mortality risk (positive Shapley values), while being a woman (sex = 2.0) decreases it (negative Shapley values)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name=\"2-2-3\"></a>\n",
    "#### 2.2.3 Visualizing Interactions between Features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "srQhm3NLZo_I"
   },
   "source": [
    "The `shap` library also lets you visualize interactions between features using dependence plots. These plot the Shapley value of a given feature against that feature's value for each data point, and color the points by the value of a second feature. This lets us begin to explain the variation in Shapley values among individuals who share the same value of the main feature."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UZ8XWDYcZtKr"
   },
   "source": [
    "Run the next cell to see the interaction between Age and Sex. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 341
    },
    "colab_type": "code",
    "id": "RnaiY5h3Zh4s",
    "outputId": "a3bff772-dc60-4fa2-f1d7-19b67c551208"
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 540x360 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "shap.dependence_plot('Age', shap_values, X_test, interaction_index = 'Sex')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that while Age > 50 is generally bad (positive Shapley value), being a woman (red points) generally reduces the impact of age. This makes sense since we know that women generally live longer than men. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CwHcYZFXZ1R0"
   },
   "source": [
    "Run the next cell to see the interaction between Poverty index and Age."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 337
    },
    "colab_type": "code",
    "id": "xzfw5itbZwAQ",
    "outputId": "f9f229f0-4316-4687-9fe2-4d3461d4dad0"
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 540x360 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "shap.dependence_plot('Poverty index', shap_values, X_test, interaction_index='Age')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the impact of poverty index drops off quickly, and for higher-income individuals age begins to explain much of the variation in the impact of poverty index. We encourage you to try some other pairs and see what other interesting relationships you can find!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "a1tzI8JrBVa6"
   },
   "source": [
    "Congratulations! You've completed the final assignment of course 3, well done! "
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "C3M2_Assignment.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "coursera": {
   "schema_names": [
    "AI4MC3-3"
   ]
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}