{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## CEMExplainer: MNIST Example\n", "- This notebook showcases an example of how to use the CEMExplainer from [AIX360](https://github.com/IBM/AIX360) to obtain contrastive explanations i.e. *pertinent negatives (PNs)* and *pertinent postitives (PPs)* for predictions made by a model trained on MNIST data. \n", "- The CEMExplainer is an implementation of the [contrastive explanation method](https://arxiv.org/abs/1802.07623).\n", "- The default location of this notebook is aix360/examples/constrastive/ folder. This notebook uses trained models which are accessed from aix360/models/CEM/ folder." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Import statements" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n", "C:\\Users\\RONNYLUSS\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", "C:\\Users\\RONNYLUSS\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", "C:\\Users\\RONNYLUSS\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:528: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", "C:\\Users\\RONNYLUSS\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:529: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n", "C:\\Users\\RONNYLUSS\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:530: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n", "C:\\Users\\RONNYLUSS\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:535: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n", "C:\\Users\\RONNYLUSS\\Anaconda3\\lib\\site-packages\\dask\\dataframe\\utils.py:14: FutureWarning: pandas.util.testing is deprecated. 
Use the functions in the public API at pandas.testing instead.\n", " import pandas.util.testing as tm\n" ] } ], "source": [ "import os\n", "import sys\n", "from keras.models import model_from_json\n", "from PIL import Image\n", "from matplotlib import pyplot as plt\n", "import numpy as np\n", "\n", "from aix360.algorithms.contrastive import CEMExplainer, KerasClassifier\n", "from aix360.datasets import MNISTDataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Load MNIST data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# load MNIST data and normalize it in the range [-0.5, 0.5]\n", "data = MNISTDataset()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MNIST train data range : ( -0.5 , 0.5 )\n", "MNIST test data range : ( -0.5 , 0.5 )\n", "MNIST train data shape : (55000, 28, 28, 1)\n", "MNIST test data shape : (10000, 28, 28, 1)\n", "MNIST train labels shape: (55000, 10)\n", "MNIST test labels shape : (10000, 10)\n" ] } ], "source": [ "# print the value range and shape of the train and test data\n", "print(\"MNIST train data range :\", \"(\", np.min(data.train_data), \",\", np.max(data.train_data), \")\")\n", "print(\"MNIST test data range :\", \"(\", np.min(data.test_data), \",\", np.max(data.test_data), \")\")\n", "print(\"MNIST train data shape :\", data.train_data.shape)\n", "print(\"MNIST test data shape :\", data.test_data.shape)\n", "print(\"MNIST train labels shape:\", data.train_labels.shape)\n", "print(\"MNIST test labels shape :\", data.test_labels.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Load a trained MNIST model\n", "- This notebook uses a trained MNIST model. The code to train this model is available [here](https://github.com/huanzhang12/ZOO-Attack/blob/master/train_models.py). Note that the model outputs logits and does not use a softmax function. 
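\n", "- Since the model returns logits rather than probabilities, a softmax can be applied on top of its predictions whenever class probabilities are needed. A minimal sketch (illustrative only; it assumes `x` is a batch of preprocessed images shaped like the MNIST data loaded above):\n", "\n", "```python\n", "logits = mnist_model.predict(x)                       # raw class scores, shape (n, 10)\n", "shifted = logits - logits.max(axis=1, keepdims=True)  # subtract the row max for numerical stability\n", "probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)  # softmax over the 10 classes\n", "```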
" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From C:\\Users\\RONNYLUSS\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\ops\\resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Colocations handled automatically by placer.\n", "Model: \"sequential_4\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "conv2d_5 (Conv2D) (None, 26, 26, 32) 320 \n", "_________________________________________________________________\n", "activation_7 (Activation) (None, 26, 26, 32) 0 \n", "_________________________________________________________________\n", "conv2d_6 (Conv2D) (None, 24, 24, 32) 9248 \n", "_________________________________________________________________\n", "activation_8 (Activation) (None, 24, 24, 32) 0 \n", "_________________________________________________________________\n", "max_pooling2d_3 (MaxPooling2 (None, 12, 12, 32) 0 \n", "_________________________________________________________________\n", "conv2d_7 (Conv2D) (None, 10, 10, 64) 18496 \n", "_________________________________________________________________\n", "activation_9 (Activation) (None, 10, 10, 64) 0 \n", "_________________________________________________________________\n", "conv2d_8 (Conv2D) (None, 8, 8, 64) 36928 \n", "_________________________________________________________________\n", "activation_10 (Activation) (None, 8, 8, 64) 0 \n", "_________________________________________________________________\n", "max_pooling2d_4 (MaxPooling2 (None, 4, 4, 64) 0 \n", "_________________________________________________________________\n", "flatten_2 (Flatten) (None, 1024) 0 \n", "_________________________________________________________________\n", "dense_4 (Dense) (None, 200) 205000 \n", "_________________________________________________________________\n", "activation_11 (Activation) (None, 200) 0 \n", "_________________________________________________________________\n", "dense_5 (Dense) (None, 200) 40200 \n", "_________________________________________________________________\n", "activation_12 (Activation) (None, 200) 0 \n", "_________________________________________________________________\n", "dense_6 (Dense) (None, 10) 2010 \n", "=================================================================\n", "Total params: 312,202\n", "Trainable params: 312,202\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "# path to mnist related models\n", "model_path = '../../aix360/models/CEM'\n", "\n", "def load_model(model_json_file, model_wt_file):\n", " \n", " # read model json file\n", " with open(model_json_file, 'r') as f:\n", " model = model_from_json(f.read())\n", " \n", " # read model weights file\n", " model.load_weights(model_wt_file)\n", " \n", " return model\n", " \n", "\n", "# load MNIST model using its json and wt files\n", "mnist_model = load_model(os.path.join(model_path, 'mnist.json'), os.path.join(model_path, 'mnist'))\n", "\n", "# print model summary\n", "mnist_model.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Load a trained convolutional autoencoder model (optional)\n", "- This notebook uses a trained convolutional autoencoder model. 
The code to train this model is available [here](https://github.com/chunchentu/autoencoder). " ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_1\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "sequential_2 (Sequential) (None, 14, 14, 1) 2625 \n", "_________________________________________________________________\n", "conv2d_4 (Conv2D) (None, 14, 14, 16) 160 \n", "_________________________________________________________________\n", "activation_3 (Activation) (None, 14, 14, 16) 0 \n", "_________________________________________________________________\n", "up_sampling2d_1 (UpSampling2 (None, 28, 28, 16) 0 \n", "_________________________________________________________________\n", "conv2d_5 (Conv2D) (None, 28, 28, 16) 2320 \n", "_________________________________________________________________\n", "activation_4 (Activation) (None, 28, 28, 16) 0 \n", "_________________________________________________________________\n", "conv2d_6 (Conv2D) (None, 28, 28, 1) 145 \n", "=================================================================\n", "Total params: 5,250\n", "Trainable params: 5,250\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "# load the trained convolutional autoencoder model\n", "ae_model = load_model(os.path.join(model_path, 'mnist_AE_1_decoder.json'), \n", " os.path.join(model_path, 'mnist_AE_1_decoder.h5'))\n", "# print model summary\n", "ae_model.summary()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Initialize CEM Explainer to explain model predictions" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# wrap mnist_model into a framework independent class structure\n", "mymodel = KerasClassifier(mnist_model)\n", "\n", "# initialize explainer object\n", "explainer = CEMExplainer(mymodel)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Explain an input instance" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted class: [3]\n", "Predicted logits: [[-11.279338 0.7362482 -9.008648 19.396715 -8.286125\n", " 14.442826 -1.3170455 -11.587322 -0.99218464 1.0182221 ]]\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAANLElEQVR4nO3db4xV9Z3H8c8HbBNj0aAGAnZAtsFkTc3Chpg1Niubpo1LjNgH3ZTEhk3R4QEaGjdG7T6oyWaTZrN29YFpGCIpu6k2TZQF6mZbQ1B2nzQiDgJli65gGRgH/zwoPOo6fPfBHHZHnHPucM+599zO9/1KJvfe851zzjd3+HDOvb97z88RIQBz37y2GwDQH4QdSIKwA0kQdiAJwg4kcVU/d2abt/6BHosIz7S81pHd9t22f2P7HduP19kWgN5yt+PstudLOiHpa5LGJL0uaUNE/LpiHY7sQI/14sh+u6R3IuLdiPi9pJ9KWl9jewB6qE7Yb5J0etrjsWLZp9getn3Q9sEa+wJQU5036GY6VfjMaXpEjEgakTiNB9pU58g+Jmlo2uMvSjpbrx0AvVIn7K9LWml7he3PS/qWpD3NtAWgaV2fxkfEJ7YfkvQLSfMl7YiIY411BqBRXQ+9dbUzXrMDPdeTD9UA+MNB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEn29lDTmnltuuaWyvm3bttLa888/X7nu9u3bu+oJM+PIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcHVZVOo0jv7yyy9X1lesWFFaO336dGmt07oox9VlgeQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJvs+e3NatW2vVly1b1vW+33vvva7XxZWrFXbbpySdlzQp6ZOIWNNEUwCa18SR/S8i4sMGtgOgh3jNDiRRN+wh6Ze237A9PNMv2B62fdD2wZr7AlBD3dP4OyPirO1Fkl6x/V8RcWD6L0TEiKQRiS/CAG2qdWSPiLPF7TlJuyTd3kRTAJrXddhtX2N7waX7kr4u6WhTjQFoVp3T+MWSdtm+tJ3nI+LfG+kKjbnqquo/8a233lpZX758eWW90/UQTpw4UVq7//77K9dFs7oOe0S8K+lPGuwFQA8x9AYkQdiBJAg7kARhB5Ig7EASfMV1jtu8eXNlfdOmTT3d/0cffVRaGxsb6+m+8Wkc2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcbZ54ClS5eW1h544IHKdYuvKJeaN6/6eHDx4sXK+qOPPlpZR/9wZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnnwOqpk2+7bbbKtftdCnoTuPoe/furawfOnSoso7+4cgOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kwzj4HXLhwobRWdd12Sbrhhhtq7fuOO+6orK9cubK0duzYsVr7xpXpeGS3vcP2OdtHpy273vYrtt8ubhf2tk0Adc3mNP7Hku6+bNnjkvZFxEpJ+4rHAAZYx7BHxAFJH1+2eL2kncX9nZLua7gvAA3r9jX74ogYl6SIGLe9qOwXbQ9LGu5yPwAa0vM36CJiRNKIJNmu/tYFgJ7pduhtwvYSSSpuzzXXEoBe6DbseyRtLO5vlLS7mXYA9Io7fZ/Z9guS1kq6UdKEpO9L+ldJP5O0TNJvJX0zIi5/E2+mbXEa32fbtm2rrHean73TdeU7/fup2v+WLVsq10V3ImLGP1rH1+wRsaGk9NVaHQHoKz4uCyRB2IEkCDuQBGEHkiDsQBIdh94a3RlDb303NDRUWT958mRlve7Q2/j4eGntnnvuqVz38OHDlXXMrGzojSM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOHtyTz31VGX9kUceqax3mtK5ytjYWGV9+fLlXW87M8bZgeQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtmTu+666yrr69atq6yPjIxU1q+++urS2uTkZOW6nS6DvWPHjsr66OhoZX2uYpwdSI6wA0kQdiAJwg4kQdiBJAg7kARhB5JgnB217Nq1q7K+du3a0tqCBQtq7XtiYqKyvmrVqtLaBx98UGvfg6zrcXbbO2yfs3102rInbZ+xPVr8VH/yAkDrZnMa/2NJd8+w/J8iYlXx82/NtgWgaR3DHhEHJH3ch14A9FCdN+gesv1WcZq/sOyXbA/bPmj7YI19Aaip27D/SNKXJK2SNC6p9KqFETESEWsiYk2X+wLQgK7CHhETETEZERclbZd0e7NtAWhaV2G3vWTaw29IOlr2uwAGQ8dxdtsvSFor6UZJE5K+XzxeJSkknZK0OSLKJ+L+/20xzp7M5s2bS2vPPvtsrW13mjt+2bJlpbUzZ87U2vcgKxtnv2oWK26YYfFztTsC0Fd8XBZIgrADSRB2IAnCDiRB2IEkOr4bD9Rx+PDhtltAgSM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOHsf3HXXXbXWf+211xrqpHkPPvhgZf2JJ54orXX6imon8+ZxrLoSPFtAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7A1YunRpZX337t2V9QMHDlTWFy1adMU9zda9995bWe/0GYHFixdX1ufPn19a63QZ89HR0cr6+vXrK+vvv/9+ZT0bjuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kETHKZsb3dkcnbJ5aGiosn7y5MnKeqfvdffzb3S5ur2dP3++tPbYY49Vrrt3797K+vh4x1nCUyqbsrnjkd32kO39to/bPmZ7a7H8etuv2H67uF3YdNMAmjOb0/hPJP1NRPyxpD+TtMX2rZIel7QvIlZK2lc8BjCgOoY9IsYj4lBx/7yk45JukrRe0s7i13ZKuq9XTQKo74o+G2/7ZkmrJf1K0uKIGJem/kOwPeMHuG0PSxqu1yaAumYddttfkPSipO9GxO9me7HAiBiRNFJsY06+QQf8IZjV0Jvtz2kq6D+JiJeKxRO2lxT1JZLO9aZFAE3oeGT31CH8OUnHI+KH00p7JG2U9IPitvp7nHPY5ORkZb1q+EmSrr322ibbadTY2Fhl/c0336ysP/PMM6W1/fv3d9UTujOb0/g7JX1b0hHbl75g/D1NhfxntjdJ+q2kb/amRQBN6Bj2iPhPSWUv0L/abDsAeoWPywJJEHYgCcIOJEHYgSQIO5AEX3Htg06XY169enWt7T/88MOltVdffbVy3SNHjlTWn3766W5aQou6/oorgLmBsANJEHYgCcIOJEHYgSQIO5AEYQeSYJwdmGMYZweSI+xAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkOobd9pDt/baP2z5me2ux/EnbZ2yPFj/ret8ugG51vHiF7SWSlkTEIdsLJL0h6T5JfyXpQkT846x3xsUrgJ4ru3jFbOZnH5c0Xtw/b/u4pJuabQ9Ar13Ra3bbN0taLelXxaKHbL9le4fthSXrDNs+aPtgrU4B1DLra9DZ/o
Kk1yT9fUS8ZHuxpA8lhaS/09Sp/nc6bIPTeKDHyk7jZxV225+T9HNJv4iIH85Qv1nSzyPiyx22Q9iBHuv6gpO2Lek5ScenB7144+6Sb0g6WrdJAL0zm3fjvyLpPyQdkXSxWPw9SRskrdLUafwpSZuLN/OqtsWRHeixWqfxTSHsQO9x3XggOcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASHS842bAPJb037fGNxbJBNKi9DWpfEr11q8nelpcV+vp99s/s3D4YEWtaa6DCoPY2qH1J9NatfvXGaTyQBGEHkmg77CMt77/KoPY2qH1J9NatvvTW6mt2AP3T9pEdQJ8QdiCJVsJu+27bv7H9ju3H2+ihjO1Tto8U01C3Oj9dMYfeOdtHpy273vYrtt8ubmecY6+l3gZiGu+KacZbfe7anv6876/Zbc+XdELS1ySNSXpd0oaI+HVfGylh+5SkNRHR+gcwbP+5pAuS/vnS1Fq2/0HSxxHxg+I/yoUR8diA9PakrnAa7x71VjbN+F+rxeeuyenPu9HGkf12Se9ExLsR8XtJP5W0voU+Bl5EHJD08WWL10vaWdzfqal/LH1X0ttAiIjxiDhU3D8v6dI0460+dxV99UUbYb9J0ulpj8c0WPO9h6Rf2n7D9nDbzcxg8aVptorbRS33c7mO03j302XTjA/Mc9fN9Od1tRH2maamGaTxvzsj4k8l/aWkLcXpKmbnR5K+pKk5AMclPdVmM8U04y9K+m5E/K7NXqaboa++PG9thH1M0tC0x1+UdLaFPmYUEWeL23OSdmnqZccgmbg0g25xe67lfv5PRExExGREXJS0XS0+d8U04y9K+klEvFQsbv25m6mvfj1vbYT9dUkrba+w/XlJ35K0p4U+PsP2NcUbJ7J9jaSva/Cmot4jaWNxf6Ok3S328imDMo132TTjavm5a33684jo+4+kdZp6R/6/Jf1tGz2U9PVHkg4XP8fa7k3SC5o6rfsfTZ0RbZJ0g6R9kt4ubq8foN7+RVNTe7+lqWAtaam3r2jqpeFbkkaLn3VtP3cVffXleePjskASfIIOSIKwA0kQdiAJwg4kQdiBJAg7kARhB5L4X+a8OHheluJSAAAAAElFTkSuQmCC\n", "text/plain": [ "
<Figure size 432x288 with 1 Axes>
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# choose an input image\n", "image_id = 340\n", "input_image = data.test_data[image_id]\n", "\n", "# rescale values from [-0.5, 0.5] to [0, 255] for plotting\n", "plt.imshow((input_image[:,:,0] + 0.5)*255, cmap=\"gray\")\n", "\n", "# check model prediction\n", "print(\"Predicted class:\", mymodel.predict_classes(np.expand_dims(input_image, axis=0)))\n", "print(\"Predicted logits:\", mymodel.predict(np.expand_dims(input_image, axis=0)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Observation: \n", "\n", "Although the above image is classified as digit 3 by the model, it could have been classified as digit 5 as well since it has similarities to the digit 5. We now employ the CEMExplainer from AIX360 to compute pertinent positive and pertinent negative explanations, which help us understand why the image was classified as digit 3 by the model and not as digit 5. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Obtain Pertinent Negative (PN) explanation" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From C:\\Users\\RONNYLUSS\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\training\\learning_rate_decay_v2.py:321: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Deprecated in favor of operator or tf.math.divide.\n", "WARNING:tensorflow:From C:\\Users\\RONNYLUSS\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\ops\\math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.cast instead.\n", "iter:0 const:[10.]\n", "Loss_Overall:2737.2229, Loss_Attack:58.5389\n", "Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:2678.68408203125\n", "target_lab_score:19.3967, max_nontarget_lab_score:14.4428\n", "\n", "iter:500 const:[10.]\n", "Loss_Overall:2737.2229, Loss_Attack:58.5389\n", "Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:2678.68408203125\n", "target_lab_score:19.3967, max_nontarget_lab_score:14.4428\n", "\n", "iter:0 const:[100.]\n", "Loss_Overall:3152.3979, Loss_Attack:0.0000\n", "Loss_L2Dist:12.6054, Loss_L1Dist:16.5280, AE_loss:3123.264404296875\n", "target_lab_score:9.0004, max_nontarget_lab_score:29.0375\n", "\n", "iter:500 const:[100.]\n", "Loss_Overall:2977.4849, Loss_Attack:0.0000\n", "Loss_L2Dist:7.0313, Loss_L1Dist:10.1030, AE_loss:2960.3505859375\n", "target_lab_score:9.2486, max_nontarget_lab_score:28.5018\n", "\n", "iter:0 const:[55.]\n", "Loss_Overall:2840.0417, Loss_Attack:0.0000\n", "Loss_L2Dist:4.8674, Loss_L1Dist:7.2291, AE_loss:2827.9453125\n", "target_lab_score:9.7374, max_nontarget_lab_score:27.1471\n", "\n", "iter:500 const:[55.]\n", "Loss_Overall:2670.4834, Loss_Attack:0.0000\n", "Loss_L2Dist:0.8409, Loss_L1Dist:2.1313, AE_loss:2667.51123046875\n", "target_lab_score:15.5937, max_nontarget_lab_score:19.4013\n", "\n", "iter:0 const:[32.5]\n", "Loss_Overall:2644.0200, Loss_Attack:2.0429\n", "Loss_L2Dist:0.5595, Loss_L1Dist:1.8527, AE_loss:2639.56494140625\n", "target_lab_score:16.7141, max_nontarget_lab_score:17.5513\n", "\n", "iter:500 const:[32.5]\n", "Loss_Overall:2868.9355, Loss_Attack:190.2514\n", "Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:2678.68408203125\n", "target_lab_score:19.3967, max_nontarget_lab_score:14.4428\n", "\n", 
"iter:0 const:[21.25]\n", "Loss_Overall:2782.8970, Loss_Attack:117.1807\n", "Loss_L2Dist:0.0176, Loss_L1Dist:0.2093, AE_loss:2665.4892578125\n", "target_lab_score:19.1928, max_nontarget_lab_score:14.5784\n", "\n", "iter:500 const:[21.25]\n", "Loss_Overall:2803.0791, Loss_Attack:124.3951\n", "Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:2678.68408203125\n", "target_lab_score:19.3967, max_nontarget_lab_score:14.4428\n", "\n", "iter:0 const:[26.875]\n", "Loss_Overall:2738.9080, Loss_Attack:91.5859\n", "Loss_L2Dist:0.1530, Loss_L1Dist:0.9359, AE_loss:2646.233154296875\n", "target_lab_score:18.1907, max_nontarget_lab_score:15.6829\n", "\n", "iter:500 const:[26.875]\n", "Loss_Overall:2836.0073, Loss_Attack:157.3233\n", "Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:2678.68408203125\n", "target_lab_score:19.3967, max_nontarget_lab_score:14.4428\n", "\n", "iter:0 const:[24.0625]\n", "Loss_Overall:2774.3584, Loss_Attack:117.5740\n", "Loss_L2Dist:0.0524, Loss_L1Dist:0.4683, AE_loss:2656.263671875\n", "target_lab_score:18.8622, max_nontarget_lab_score:14.8760\n", "\n", "iter:500 const:[24.0625]\n", "Loss_Overall:2819.5432, Loss_Attack:140.8592\n", "Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:2678.68408203125\n", "target_lab_score:19.3967, max_nontarget_lab_score:14.4428\n", "\n", "iter:0 const:[22.65625]\n", "Loss_Overall:2784.2271, Loss_Attack:122.1559\n", "Loss_L2Dist:0.0291, Loss_L1Dist:0.2813, AE_loss:2661.7607421875\n", "target_lab_score:19.1110, max_nontarget_lab_score:14.6193\n", "\n", "iter:500 const:[22.65625]\n", "Loss_Overall:2811.3113, Loss_Attack:132.6272\n", "Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:2678.68408203125\n", "target_lab_score:19.3967, max_nontarget_lab_score:14.4428\n", "\n", "iter:0 const:[23.359375]\n", "Loss_Overall:2777.7581, Loss_Attack:118.8575\n", "Loss_L2Dist:0.0398, Loss_L1Dist:0.3873, AE_loss:2658.4736328125\n", "target_lab_score:18.9654, max_nontarget_lab_score:14.7772\n", "\n", "iter:500 const:[23.359375]\n", "Loss_Overall:2815.4272, Loss_Attack:136.7432\n", "Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:2678.68408203125\n", "target_lab_score:19.3967, max_nontarget_lab_score:14.4428\n", "\n" ] } ], "source": [ "arg_mode = \"PN\" # Find pertinent negative\n", "\n", "arg_max_iter = 1000 # Maximum number of iterations to search for the optimal PN for given parameter settings\n", "arg_init_const = 10.0 # Initial coefficient value for main loss term that encourages class change\n", "arg_b = 9 # No. 
of updates to the coefficient of the main loss term\n", "\n", "arg_kappa = 0.9 # Minimum confidence gap between the PN's (changed) class probability and the original class's probability\n", "arg_beta = 1.0 # Controls sparsity of the solution (L1 loss)\n", "arg_gamma = 100 # Controls how much to adhere to an (optionally trained) autoencoder\n", "arg_alpha = 0.01 # Penalizes L2 norm of the solution\n", "arg_threshold = 0.05 # Automatically turn off features <= arg_threshold if arg_threshold < 1\n", "arg_offset = 0.5 # The explainer assumes the classifier was trained on data normalized\n", " # to the [-arg_offset, arg_offset] range, where arg_offset is 0 or 0.5\n", "\n", "\n", "(adv_pn, delta_pn, info_pn) = explainer.explain_instance(np.expand_dims(input_image, axis=0), arg_mode, ae_model, arg_kappa, arg_b, \n", " arg_max_iter, arg_init_const, arg_beta, arg_gamma, arg_alpha, arg_threshold, arg_offset)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[INFO]kappa:0.9, Orig class:3, Perturbed class:5, Delta class: 1, Orig prob:[[-11.279338 0.7362482 -9.008648 19.396715 -8.286125 14.442826 -1.3170455 -11.587322 -0.99218464 1.0182221 ]], Perturbed prob:[[ -6.6616817 -1.9708817 -7.401487 13.478742 -6.3133864 13.78304 1.2838321 -11.600546 0.29793242 1.085611 ]], Delta prob:[[-0.11010491 1.0595146 -0.08893302 -0.25925025 -0.3346461 0.22845559 -0.099649 -0.00456608 -0.31767696 -0.56160116]]\n" ] } ], "source": [ "print(info_pn)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Obtain Pertinent Positive (PP) explanation" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "iter:0 const:[10.]\n", "Loss_Overall:1186.7104, Loss_Attack:20.4772\n", "Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:1166.2332763671875\n", "target_lab_score:-0.1036, max_nontarget_lab_score:1.0441\n", "\n", "iter:500 const:[10.]\n", "Loss_Overall:1186.7104, Loss_Attack:20.4772\n", "Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:1166.2332763671875\n", "target_lab_score:-0.1036, max_nontarget_lab_score:1.0441\n", "\n", "iter:0 const:[100.]\n", "Loss_Overall:1374.8162, Loss_Attack:224.8765\n", "Loss_L2Dist:0.0581, Loss_L1Dist:0.5667, AE_loss:1149.824951171875\n", "target_lab_score:-0.1908, max_nontarget_lab_score:1.1579\n", "\n", "iter:500 const:[100.]\n", "Loss_Overall:1179.4254, Loss_Attack:0.0000\n", "Loss_L2Dist:9.2291, Loss_L1Dist:27.3661, AE_loss:1167.4598388671875\n", "target_lab_score:10.4896, max_nontarget_lab_score:3.3207\n", "\n", "iter:0 const:[55.]\n", "Loss_Overall:1278.8578, Loss_Attack:112.6245\n", "Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:1166.2332763671875\n", "target_lab_score:-0.1036, max_nontarget_lab_score:1.0441\n", "\n", "iter:500 const:[55.]\n", "Loss_Overall:1278.8578, Loss_Attack:112.6245\n", "Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:1166.2332763671875\n", "target_lab_score:-0.1036, max_nontarget_lab_score:1.0441\n", "\n", "iter:0 const:[77.5]\n", "Loss_Overall:1324.9314, Loss_Attack:158.6982\n", "Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:1166.2332763671875\n", "target_lab_score:-0.1036, max_nontarget_lab_score:1.0441\n", "\n", "iter:500 const:[77.5]\n", "Loss_Overall:1324.9314, Loss_Attack:158.6982\n", "Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:1166.2332763671875\n", "target_lab_score:-0.1036, max_nontarget_lab_score:1.0441\n", "\n", "iter:0 const:[88.75]\n", "Loss_Overall:1347.3336, Loss_Attack:190.4549\n", 
"Loss_L2Dist:0.0195, Loss_L1Dist:0.2384, AE_loss:1156.8353271484375\n", "target_lab_score:-0.1378, max_nontarget_lab_score:1.1082\n", "\n", "iter:500 const:[88.75]\n", "Loss_Overall:1206.6337, Loss_Attack:0.0000\n", "Loss_L2Dist:7.8383, Loss_L1Dist:24.1343, AE_loss:1196.3819580078125\n", "target_lab_score:8.3027, max_nontarget_lab_score:3.7181\n", "\n", "iter:0 const:[83.125]\n", "Loss_Overall:1336.9935, Loss_Attack:176.8079\n", "Loss_L2Dist:0.0096, Loss_L1Dist:0.1385, AE_loss:1160.1622314453125\n", "target_lab_score:-0.1352, max_nontarget_lab_score:1.0918\n", "\n", "iter:500 const:[83.125]\n", "Loss_Overall:1176.8406, Loss_Attack:0.0000\n", "Loss_L2Dist:8.3247, Loss_L1Dist:24.4351, AE_loss:1166.0723876953125\n", "target_lab_score:8.1267, max_nontarget_lab_score:4.7032\n", "\n", "iter:0 const:[80.3125]\n", "Loss_Overall:1330.7097, Loss_Attack:169.8772\n", "Loss_L2Dist:0.0070, Loss_L1Dist:0.1182, AE_loss:1160.813720703125\n", "target_lab_score:-0.1306, max_nontarget_lab_score:1.0846\n", "\n", "iter:500 const:[80.3125]\n", "Loss_Overall:1175.4065, Loss_Attack:0.0000\n", "Loss_L2Dist:9.0849, Loss_L1Dist:26.8103, AE_loss:1163.640625\n", "target_lab_score:9.3781, max_nontarget_lab_score:2.3079\n", "\n", "iter:0 const:[78.90625]\n", "Loss_Overall:1327.5853, Loss_Attack:166.4040\n", "Loss_L2Dist:0.0058, Loss_L1Dist:0.1080, AE_loss:1161.1646728515625\n", "target_lab_score:-0.1282, max_nontarget_lab_score:1.0807\n", "\n", "iter:500 const:[78.90625]\n", "Loss_Overall:1168.5416, Loss_Attack:0.0000\n", "Loss_L2Dist:7.8014, Loss_L1Dist:23.5646, AE_loss:1158.3837890625\n", "target_lab_score:6.9406, max_nontarget_lab_score:5.5205\n", "\n", "iter:0 const:[78.203125]\n", "Loss_Overall:1326.0404, Loss_Attack:164.6752\n", "Loss_L2Dist:0.0053, Loss_L1Dist:0.1030, AE_loss:1161.349609375\n", "target_lab_score:-0.1270, max_nontarget_lab_score:1.0788\n", "\n", "iter:500 const:[78.203125]\n", "Loss_Overall:1183.5975, Loss_Attack:0.0000\n", "Loss_L2Dist:9.7138, Loss_L1Dist:28.4317, AE_loss:1171.0406494140625\n", "target_lab_score:9.9128, max_nontarget_lab_score:1.6840\n", "\n" ] } ], "source": [ "arg_mode = \"PP\" # Find pertinent positive\n", "arg_beta = 0.1 # Controls sparsity of the solution (L1 loss)\n", "(adv_pp, delta_pp, info_pp) = explainer.explain_instance(np.expand_dims(input_image, axis=0), arg_mode, ae_model, arg_kappa, arg_b, \n", " arg_max_iter, arg_init_const, arg_beta, arg_gamma, arg_alpha, arg_threshold, arg_offset)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[INFO]kappa:0.9, Orig class:3, Perturbed class:3, Delta class: 3, Orig prob:[[-11.279338 0.7362482 -9.008648 19.396715 -8.286125 14.442826 -1.3170455 -11.587322 -0.99218464 1.0182221 ]], Perturbed prob:[[ -5.984942 -0.3156201 -6.267382 11.657149 -3.6047158 11.557238 3.9308367 -11.3727045 -0.803853 -1.8081436]], Delta prob:[[-2.7503839 0.4277636 -1.0708491 4.933249 -1.9914135 1.1908851 -2.4917073 -0.88367814 -1.0458403 1.2483816 ]]\n" ] } ], "source": [ "print(info_pp)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Plot Pertinent Negative (PN) and Pertinent Positive (PP) explanations" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAlAAAACPCAYAAAA1FeWWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAVHklEQVR4nO3df7BcZX3H8c8XAlNa8oNAYZIQklAThIKIgohxCoyGiSkSpj+slNgwg8I0SAOhFNDqKDhqbaEgopNQMJYf2hYwIQVEyhAYFVIJJCQ05YcNmBtuEogN+SFak3z7xzk5nLN39+6ee8/unvPs+zVz5z5nn7N7vjmffbLPPefsrrm7AAAA0Lr9ul0AAABA1TCBAgAAyIkJFAAAQE5MoAAAAHJiAgUAAJATEygAAICcemICZWafMbN/KnrdFh7LzewdqeWzzGxJi/c9x8y+V0QdZVOWPMrCzL5iZpe1uO59Zjaz3TUBjTB+0argX/PcvXI/ki6QtEbSLyVtkvQtSWO6XVedOl3SO1LLT0t6f2r5MUmvS9ouabWk2TX3XyvpXd3+d4SQRb08CnrMVyS9JWmnpM2Svi3p4LhvuaRfSZqYWv/Dkl5JLf+upI2SDoqXJ8d17kz9fC61/vskrez2vuxSfnX3dSv7mZ+G+5Tx23z87pT0hqT7JI3r9n7o0r5vNvbq7qMQX/PSP5U7AmVmV0j6O0lXShot6f2SJkl6xMwOrLP+iM5WWJ+ZnSJptLs/lbp5vqIn2yhJF0m608zGpfq/G99eSnmziO9TijwGY2YXmNniHHf5qLsfLOk9kk6R9Lepvl2SPjfIfS+Q9KC7v1Vz+xh3Pzj+uW7fje7+n5JGmdnJOeoLSaN93Ww/owbjNzHY+P103DdN0hhJ/1hYodXTaD/V3UchvubVqtQEysxGSfqipEvd/Qfu/ht3f0XSxxQN/Dlm9gUzu8fM7jSz7ZIuiG+7M/U4f2Fmr5rZVjP7nJm9YmYfjvuSdc1scnzYeK6Z/dzM3jCzz6Ye531m9qSZbTOzfjP7RqP/eCR9RNLj6Rvc/Tl3371vUdIBkiamVlku6Q+HvMPaqJUs4vXKmkfh3H2jpIckHZ+6+euSzhvk1MOA50ULlqukz4tOqbOvm+1npDB+B2owfvf1/ULSvfX6ek2j/VRnHwX1mldPpSZQkj4g6bcUHSZMuPtORYHOiG+aLekeRbPhu9Lrmtlxkr4p6XxJ4xT95TWhyXY/KOkYSR+S9HkzOza+fY+kyyUdJum0uH9eg8c4QdILtTea2b+b2a8krVD05Hk61b1O0uT4P7uyaTULqZx5FM7MJkqaJenZ1M0bJd0q6QsN7lb3eSHpVTPrM7Nvm9lhNX3rJJ04zHIrrc6+brafkcX4rdFg/O7rO0zSH9fr6zWN9lOdfRTaa94AVZtAHSbpjdQMNq0/7pekJ919ibvvrXNq5E8kLXP3H7n7/0n6vKKZ8GC+6O5vuftqRedtT5Qkd1/p7k+5++74r7eFkk5v8BhjJO2ovdHdz5Y0UtET8mF335vq3pG6b9m0moVUzjyKtMTMtkn6kaK/uL5c0/8VSR81s9+vc9/a58Ubig6PT5L0XkXPjbtq7rND5XxOdMJg+3qw/Ywsxu/bBntOfT3uW61ovyzoQD1l1Wg/NdpHob3mDVD689k13pB0mJmNqDPwx8X9krRhkMcYn+5391+a2dYm292Uav9S0cVzMrNpkm6QdLKk31a0P1c2eIz/VfSkGcDdfyPpITObb2Y/c/f74659629rUl83tJqFVM48Mszsm5L+PF48UNIIMzs3Xv65u79rkLuf6+7/0ajT3V83s29IulbRRbppmedFfARg319km83s05L6zWyUu2+Pbx+pcj4nOmHAvjYzSU33M7IYv28bbPz+lbsX8i7CADQae432UWiveQNU7QjUk5J+LemP0jea2e8oOt/6aHzTYH8B9Us6MnXfgyQdOsR6viXpvyVNjS+K+4wka7Duc4oushvMCEm/l1o+VtE7ibY3WL+bWs1CKmceGe4+z93HuPsYRacN7t633OQ/31b9vaQzFR1VSmv2vNi379L/jmMV/bWHgRrtZ2QxftFuob3mDVCpCZS7v6nowsebzWymmR1gZpMl/ZukPkl3tPAw9yg6zP+B+ALFL6rFQVrHSEVvx9xpZu+U9JeDrPugUoejzeydZvYRMzso/nfMkfQHyl50d7qi6xFKp6AspO7l0VHuvk3S9ZL+pqar9nlxqpkdY2b7mdmhii6OXh7v731K+7zotkH2M1IYv+iAoF7z6qnUBEqS3P1riv4y+QdFg22FokPIH3L3X7dw/+clXSrpe4r+etohaYuiv8by+mtFh413KLqA9V8G2e4zkt40s1Pjm0zRBa9bFH0uxnxJfxavt895iq4DKKXhZhE/Rlfy6JKbFF0om/bPkmbFf7lL0tGSfqDo37BW0X44b9/KFr01eFf8cQaor95+Rg3GL9opxNe8Wube7Hq/sJnZwYrOt0519/Vt3tZZkua5+7ktrPtRSZ9w94+1s6ay6WQeZWFmX5a0xd1vbGHdeyXd5u4Ptr8yIJ9eHL9oLPTXvJ6cQMVBPapoRny9pFMlvcd7cWeUAHkA1cX4Ra+q3Cm8gsyW9Fr8M1XSxxnsXUUeQHUxftGTevIIFAAAwHAM6whU/O6NF8zsZTO7uqii0B3kGQ6yDAt5hoMswzHkI1Bmtr+kFxV95H+fpJ9KOs/d/6u48tAp5BkOsgwLeYaDLMMynE8if5+kl939fyTJzL6n6Fx4wyeCmXG+sMvcvdFntOTKkyy7r6gs43XIs8sYm+FgbIalUZ7DOYU3QdmP+O9TnS+RNLOLzOxpM3u6tg+l0jRPsqwMxmZYGJvhYGwGZDhHoOrNyAbMlN19kaRFEjPpkmuaJ1lWBmMzLIzNcDA2AzKcI1B9kiamlo9U9DZWVBN5hoMsw0Ke4SDLgAxnAvVTSVPNbEr8HUgfl3R/k/ugvMgzHGQZFvIMB1kGZMin8Nx9t5l9WtLDkvaXdHv8vUioIPIMB1mGhTzDQZZh6egHaXIut/sGeXdILmTZfUVlKZFnGTA2w8HYDEs73oUHAADQk5hAAQAA5DScjzHoedOmTUvaCxcuzPTdfffdSfvWW2/tWE0YGrIEyomxibLiCBQAAEBOTKAAAAByYgIFAACQEx9jkEP6XLwkPfDAA0l7ypQpmb4NGzY07Osm3iodIcusqucZAsZmhLGZVfU8Q8DHGAAAABSECRQAAEBOfIxBE/Pnz6/blqSjjjqq4f1effXVttWEoSkiyyVLlmT6du3albTPP//84ZaIDqu9hOHNN99M2mPGjOl0OT2L/2dRRRyBAgAAyIkJFAAAQE5MoAAAAHLiGqg6Rox4e7ccd9xxSXvSpEmZ9dLXT7z44ouZvjlz5rSpOuRRdJZ9fX1Fl4guMivs3ebIif9nUXUcgQIAAMiJCRQAAEBOnMKr4+KLL07aF154YUv32bp1a2aZUz3lQJZAObU6Njdv3py0169fn+ljbFbP9OnTk3Y6W0
l6+eWXO13OsHAECgAAICcmUAAAADkxgQIAAMiJa6AkjR8/PrP8yU9+Mmmn3+a8337Z+ebevXuT9pVXXtmm6pAHWQLl1OrYfO211zLrbd++PWlfd911baoO7XLWWWdllqdMmZK0f/zjH3e6nEJxBAoAACAnJlAAAAA5cQpPA7/t+4QTTkja6U/BTZ/mkaRly5Yl7WeeeaZN1SEPsgTKqdWxWXuqb+XKlUmbsVk9P/zhD7tdQttwBAoAACCnphMoM7vdzLaY2drUbWPN7BEzeyn+fUh7y0RRyDMcZBkW8gwHWfaGVo5ALZY0s+a2qyU96u5TJT0aL6MaFos8Q7FYZBmSxSLPUCwWWQav6TVQ7v6EmU2uuXm2pDPi9nckLZd0VYF1ddTOnTszy+mv8jj00EMb3u+0005L2lOnTs30Pf/88wVVV6zQ8yTLcLLsNaHnydgMJ0tEhnoN1BHu3i9J8e/DiysJXUCe4SDLsJBnOMgyMG1/F56ZXSTponZvB+1HlmEhz3CQZVjIsxqGOoHabGbj3L3fzMZJ2tJoRXdfJGmRJJmZN1qvm9auXZtZXrJkSdIe7FvC04ed582bl+m75JJLCqquI1rKkywrIaixCcYmY7OceWLop/DulzQ3bs+VtLSYctAl5BkOsgwLeYaDLAPTyscYfFfSk5KOMbM+M7tQ0lclzTCzlyTNiJdRAeQZDrIMC3mGgyx7QyvvwjuvQdeHCq4FHUCe4SDLsJBnOMiyN1j6I/TbvrGKnMudOHFi0l6/fn3STn9juJT9+oH+/v5M39lnn520V69eXXSJQ+bu1nyt5siy+4rKUqpOniHr5bH5k5/8JGnv2rUrs96mTZuS9po1azJ9l19+edLevXt30SUOWa+PzTPPPDNpT5o0KdM3ffr0pH3yySdn+m6//fakffPNN7epuvwa5clXuQAAAOTEBAoAACAnTuE1cf311yftBQsWZPr27t3b8H59fX1Ju/YQZjf12mmCNLJsrIp5hqaXx+bMmW9/68kVV1yR6duwYUPSHjt2bKZvwoQJSfuUU05pU3X59frYTH9ERa0ZM2Yk7eeeey7Tlz4l+9RTTxVf2BBxCg8AAKAgTKAAAABy4hReE6NHj07as2bNyvQtWrQoaR900EGZvj179iTthQsXZvrS7zRYtWpVIXW2qpdPE5BlY1XMMzS9PDbTDjzwwMzybbfdlrTT7+CSsqfwGJvlMW3atKRdezrvrbfeStrXXnttpm/p0nJ+tiin8AAAAArCBAoAACAnJlAAAAA5cQ3UMHz/+99P2meccUamb+TIkQ3vt3nz5qT97ne/O9P3+uuvF1NcA1xnUV8vZymFl2cVMTbrY2yGlWcVcQ0UAABAQZhAAQAA5MQpvIJcfPHFmeVbbrml4brpL7I96qijMn0bN24strAanCZorteylMLOsyoYm80xNtENnMIDAAAoCBMoAACAnJhAAQAA5DSi2wWEYvXq1d0uAQUhS6Ccli9fnlnu6+tL2kceeWSHq8FwHX/88ZnliRMnJu2HHnqo0+XkxhEoAACAnJhAAQAA5BT0KbzTTz+9Yd/jjz8+7Mf/1Kc+lbSvueaaTF/6LbS19tuPeWteZAmUU7vH5tFHH520582bl+nbunVr0k6f/pEYm2V1zz33JO0JEyZk+h544IGkzSk8AACAADGBAgAAyIkJFAAAQE5BXQM1fvz4zPLSpUuT9hNPPJHpO/zww1t6zHPOOSdp157rP+KII5L2/vvvn+lLf0XOqlWrMn2zZ89O2ps2bWqpjl5DlkA5tWNsjh49OmnXXhczadKkpF37lSwnnnhi0n722WczfYzN7rnjjjuS9pw5czJ96evWHn744Uzfl770pfYWVjCOQAEAAOTUdAJlZhPN7DEzW2dmz5vZ/Pj2sWb2iJm9FP8+pP3lYrjIMhyMzbCQZTgYm73B0qcn6q5gNk7SOHd/xsxGSlop6VxJF0j6hbt/1cyulnSIu1/V5LHa+q3StW9jXb9+fXrbmb5m/+56BnuMHTt2ZPquuurtXbFs2bJMX39/f+5tF+i9ZBlMluNVkbGJlvTs2FyxYkXS3rZtW6Yv/QnjtafwGJvldMMNNyTt2k8bf/DBB5P2jTfe2LGahsPd636WTdMjUO7e7+7PxO0dktZJmiBptqTvxKt9R9GTAyVHluFgbIaFLMPB2OwNuS4iN7PJkk6StELSEe7eL0VPFjOre7WgmV0k6aLhlYmikWVYyDMcZBkW8gxXyxMoMztY0r2SLnP37YN9OnOauy+StCh+jModigwRWYaFPMNBlmEhz7C1NIEyswMUPQnucvf74ps3m9m4eBY9TtKWdhXZqj179mSW09eyjBo1atiPn/7mbyn7ttmbbrop0/fYY48Ne3vtQJaRELKUqpMnmqtKlu0Ym6eeemrS3rhxY6YvPTYvu+yyTB9js5wWLFjQ7RI6opV34Zmk2yStc/cbUl33S5obt+dKWlp7X5QSWQaCsRkcsgwEY7M3tHIEarqkT0haY2b7PkXwM5K+KulfzexCST+X9KftKREFI8twMDbDQpbhYGz2gKYfY1Doxjp8Ljf9adMnnXRSw/UuvfTSzPLy5cuT9po1a5J2Vd5yOZhGb8fMiyy7r6gsJa6zKAPGZoSxmcXY7L4hf4wBAAAAsphAAQAA5MQECgAAIKegr4HCQFW9zgIDcZ1FWBib4WBshoVroAAAAArCBAoAACAnJlAAAAA5MYECAADIiQkUAABATkygAAAAcmICBQAAkBMTKAAAgJyYQAEAAOTEBAoAACAnJlAAAAA5MYECAADIiQkUAABATkygAAAAcmICBQAAkBMTKAAAgJxGdHh7b0h6VdJhcbvbeq2OSQU+Flk21olaisxSiurdpd7ah61gbA5fWeqQGJtFKEueXR+b5u4d2H7NRs2edveTO75h6ihcWWovSx1SuWrJo0x1l6WWstQxFGWpvSx1SOWqJY8y1V2WWspQB6fwAAAAcmICBQAAkFO3JlCLurTdWtQxfGWpvSx1SOWqJY8y1V2WWspSx1CUpfay1CGVq5Y8ylR3WWrpeh1duQYKAACgyjiFBwAAkFNHJ1BmNtPMXjCzl83s6g5v+3Yz22Jma1O3jTWzR8zspfj3IR2oY6KZPWZm68zseTOb361ahoMsw8lSIs94m0HkSZbhZCmRZ5mz7NgEysz2l3SLpI9IOk7SeWZ2XKe2L2mxpJk1t10t6VF3nyrp0Xi53XZLusLdj5X0fkmXxPuhG7UMCVkmKp+lRJ4plc+TLBOVz1Iiz1h5s3T3jvxIOk3Sw6nlayRd06ntx9ucLGltavkFSePi9jhJL3Synni7SyXNKEMtZNl7WZJnWHmSZThZkmf5s+zkKbwJkjaklvvi27rpCHfvl6T49+Gd3LiZTZZ0kqQV3a4lJ7KsUeEsJfIcoMJ5kmWNCmcpkWdG2bLs5ATK6tzWs28BNLODJd0r6TJ3397tenIiy5SKZymRZ0bF8yTLlIpnKZFnooxZdnIC1SdpYmr5SEmvdXD79Ww2s3GSFP/e0omNmtkBip4Id7n7fd2sZ
YjIMhZAlhJ5JgLIkyxjAWQpkafi7ZQyy05OoH4qaaqZTTGzAyV9XNL9Hdx+PfdLmhu35yo6t9pWZmaSbpO0zt1v6GYtw0CWCiZLiTwlBZMnWSqYLCXyLHeWHb74a5akFyX9TNJnO7zt70rql/QbRbP6CyUdqujq/Zfi32M7UMcHFR2CfU7SqvhnVjdqIUuyJM/w8iTLcLIkz3JnySeRAwAA5MQnkQMAAOTEBAoAACAnJlAAAAA5MYECAADIiQkUAABATkygAAAAcmICBQAAkBMTKAAAgJz+H9K/VXZbJcdDAAAAAElFTkSuQmCC\n", "text/plain": [ "
<Figure size 720x720 with 5 Axes>
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# rescale values from [-0.5, 0.5] to [0, 255] for plotting\n", "fig0 = (input_image[:,:,0] + 0.5)*255\n", "\n", "fig1 = (adv_pn[0,:,:,0] + 0.5) * 255\n", "fig2 = (fig1 - fig0) #rescaled delta_pn\n", "fig3 = (adv_pp[0,:,:,0] + 0.5) * 255\n", "fig4 = (delta_pp[0,:,:,0] + 0.5) * 255 #rescaled delta_pp\n", "\n", "f, axarr = plt.subplots(1, 5, figsize=(10,10))\n", "axarr[0].set_title(\"Original\" + \"(\" + str(mymodel.predict_classes(np.expand_dims(input_image, axis=0))[0]) + \")\")\n", "axarr[1].set_title(\"Original + PN\" + \"(\" + str(mymodel.predict_classes(adv_pn)[0]) + \")\")\n", "axarr[2].set_title(\"PN\")\n", "axarr[3].set_title(\"Original + PP\")\n", "axarr[4].set_title(\"PP\" + \"(\" + str(mymodel.predict_classes(delta_pp)[0]) + \")\")\n", "\n", "axarr[0].imshow(fig0, cmap=\"gray\")\n", "axarr[1].imshow(fig1, cmap=\"gray\")\n", "axarr[2].imshow(fig2, cmap=\"gray\")\n", "axarr[3].imshow(fig3, cmap=\"gray\")\n", "axarr[4].imshow(fig4, cmap=\"gray\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Explanation: \n", "- The PP highlights the minimum set of pixels which were present in the image for it to be classified as digit 3. Note that both the original image and PP are classified as digit 3 by the classifier. \n", "- The PN highlights a small horizontal line at the top whose presence would change the classification of the original image to digit 5 and thus should be absent for the classification to remain digit 3. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 2 }