{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Exercise 12.3\n", "## Discriminative localization (CAM)\n", "In this exercise we will create class activation maps (CAMs) for predictions made by a model trained to classify magnetic phases (see [Exercise 7_1](Exercise_7_1_solution.ipynb)).\n", " 1. Pick two correctly and two wrongly classified images from the predictions of the convolutional network.\n", " 2. Use Exercise 8.1 as a guide to extract the weights and feature maps from the trained network.\n", " 3. Create and plot the class activation maps and compare them with the images to see which regions drive the prediction.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "keras version 2.4.0\n" ] } ], "source": [ "from tensorflow import keras\n", "import numpy as np\n", "callbacks = keras.callbacks\n", "layers = keras.layers\n", "\n", "print(\"keras version\", keras.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load and prepare dataset\n", "See https://doi.org/10.1038/nphys4035 for more information on the dataset." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import gdown\n", "url = \"https://drive.google.com/u/0/uc?export=download&confirm=HgGH&id=1Ihxt1hb3Kyv0IrjHlsYb9x9QY7l7n2Sl\"\n", "output = 'ising_data.npz'\n", "gdown.download(url, output, quiet=True)\n", "\n", "f = np.load(output)\n", "n_train = 20000\n", "\n", "x_train, x_test = f[\"C\"][:n_train], f[\"C\"][n_train:]\n", "T_train, T_test = f[\"T\"][:n_train], f[\"T\"][n_train:]\n", "\n", "# label 0: ordered phase (T < Tc), label 1: disordered phase (T > Tc)\n", "Tc = 2.27\n", "y_train = np.zeros_like(T_train)\n", "y_train[T_train > Tc] = 1\n", "y_train = keras.utils.to_categorical(y_train, 2)\n", "\n", "y_test = np.zeros_like(T_test)\n", "y_test[T_test > Tc] = 1\n", "y_test = keras.utils.to_categorical(y_test, 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot data" ] }, { "cell_type": "code", "execution_count": 5, "metadata": 
{}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZkAAAEZCAYAAABFFVgWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAowklEQVR4nO3df5AcZ33n8fd3VyuvLSFsISNkyUJOZNUaKCw7OouEYIyBCqSgjI6Ujx+F5RxXyiWXAi7UxT5SKaBC7kSuAuTqjgNV+YdIHIwLw9kEAnHZpmxfysY2GPskR7YwUpBiS2uEZcsyWix974/p5VrSPrNPb/fT093zeVVNaWe6p/uZ2a/62fnO83wfc3dERERSGBl0A0REpLvUyYiISDLqZEREJBl1MiIikow6GRERSUadjIiIJKNORkREkulsJ2Nmh3K3Y2b2Qu7++2d57p+Z2SNm9qKZfWKWfc3MPm1mP81unzYzy21fa2YPmtnh7N+11bxCqVpdMZPtf6GZ3ZUde5+ZfTi3bZWZ3ZnFzD+Z2VsqeHlSsRqvMf/RzJ4ws2fN7F/M7LNmNi+3vdHx0tlOxt0XTt+AfwbemXvshlmevhP4Y+CbEafaBLwLOB94LfBO4PcAzGw+cAvwN8AZwFbgluxxaZi6YsbMlgDfBr4IvAxYDfxDbpcvAz/Itv0J8FUzO7PwC5KkarzG3Apc6O6LgNfQu9Z8KLe90fHS2U6mDHff6u5/DzwXsftG4C/dfY+77wX+Ergy23YJMA/4nLsfcff/DhhwafWtlkEqGDN/BHzH3W/I4uI5d38UwMzWABcCH3f3F9z9ZuAR4N3JGi+1KxIv7v4jd38mu2vAMXp/mLQiXoaykzGzz5vZ5ys63KuBH+bu/zB7bHrbw3587Z6Hc9ulJSqOmdcBB8zsH81sv5l9w8xWZtteDTzh7vmLTz6mpAUqjhfM7H1m9izwNL1PMl/MNjU+XubNvkv3uPsfVHi4hcDB3P2DwMLse5kTt01vf0mF55caVBwzK+j99flWen91/gW9lMfrCcfM8grPL4lVHC+4+98Cf2tm5wJXAPuyTY2Pl6H8JFOxQ8Ci3P1FwKHs08uJ26a3x6RUpLteAL7u7ve7+8+BTwK/YWYvRTEjfbj748A2YPpTUuPjRZ1MedvofXyddn722PS21+ZHm9EbHLANGWYPA/kUav7nbcCvmFn+024+pkTmAb+a/dz4eFEnMwMzGzOzcXrvzzwzGzez0cDuXwL+yMyWm9lZwEeB67Nt3wWOAh8ys1PM7A+zx+9I13oZhIIxcx2wIRvePgb8KXCPux9098eAh4CPZ8fYQO8Pk5treBlSkyLxYmb/zsxenv38KuA/A7cDtCJe3L3zN2AX8Jbc/S8AX+iz//X0/rrM367Mtr2BXjpsel+jl1M/kN3+ArDc9guAB+mlSL4PXDDo90O3wcZM9tjvA3uBnwHfAM7ObVtF7w+UF4Ad+Xbo1sxb4mvMdfS+g3k+O89/A8bbEi+WNVJERKRySpeJiEgy6mRERCQZdTIiIpJMqU7GzN5mZjvMbKeZXV1Vo6S7FDNShOKl/eb8xX823O4xerOW9wD3A+919+3VNU+6RDEjRSheuqFMWZmLgJ3u/gSAmd0IXAYEA2C+neLjLChxSpmrn/M8U37EZt8zqUIxo3gZnDbGCwxvzBw7fQHjr3iBVfMP8eOphRx5apyRZw7X2oZQzJTpZJYDP8nd3wOs7/eEcRaw3t5c4pQyV/f57YNuAhSMGcXL4LQxXmB4Y+bY2gvYuXGUifMeY3LbGtZe/yIj9zxUaxtCMZO8QKaZbaK35grjnJb6dNJyihcpSjEDY9v3sPq6FexaPMHqn04xtmMvRwfdqEyZTmYvcHbu/orsseO4+xZgC8AiW6yZn8Nt1phRvEiOrjGRjk5OMjI5yanT9wfamuOVGV12P3CumZ2TrfT4Hnoru
ImEKGakCMVLB8z5k4y7v5gVfPwOMApc6+6NqfwpzaOYkSIUL91Q6jsZd/8W8K2K2iJDQDEjRShe2k8z/kVEJBl1MiIikow6GRERSUadjIiIJKNORkREklEnIyIiyaiTERGRZNTJiIhIMupkREQkGXUyIiKSjDoZERFJRp2MiIgko05GRESSUScjIiLJqJMREZFkSq0nIyIyLEaXvpyp81Zw5Iwxxg9MMbZ9D0cnJwfdrMbTJxkRkQhT563giSuNNVdtY+fGUX4xsWLQTWoFdTIiIhGOnDHGmyZ2cM3Ke3jDeY9xZPHYoJvUCkqXiYhEGD8wxR3bJ7jCR7h72xpW/3Rq0E1qBXUyIiIRxrbvYfV1K9i1eILVP51ibMdejg66US0wa7rMzK41s/1m9n9zjy02s9vM7PHs3zPSNlPaRDEjRbQlXo5OTjJy9w849ZbvMXLPQ/rSP1LMdzLXA2874bGrgdvd/Vzg9uy+yLTrqSFmRpe+nKOXXMjhDes59sYLGD3zzLKHlMG4Hl1jOmvWTsbd7wIOnPDwZcDW7OetwLuqbZa0WV0xo9E+3aBrTLfNdXTZUnd/Mvv5KWBpRe2R7qo8ZjTap9N0jemI0l/8u7ubmYe2m9kmYBPAOKeVPZ10QL+YKRIvGu1zsi5OGNQ1pt3m+klmn5ktA8j+3R/a0d23uPs6d183xilzPJ10QFTMFImX3mifo+zaPMHq619kbMfe6lvdMh1KIeoa0xFz7WRuBTZmP28EbqmmOdJhlceMRvucrEMpRF1jOmLWdJmZfRm4BFhiZnuAjwObgZvM7IPAbuDylI2UdkkZM11MB1UpJoXYtPdQ15hum7WTcff3Bja9ueK2SEekjJnpdNCbJrZxx/YJVl+3ghF1Mr8UM2Gwae+hrjHdphn/0iq9dNA2rll5D1f4CLsWT3DqoBvVIEcnJxmZnPzlezLTjHS9h1IndTLSKhpRNjf5FNnk2lHesWDfoJskQ0KdjLSK6kfNTT5F9o4F+9iw6CFgwaCbJUNAnYy0Skw6SE6WT5H1qIOReqiTERkC+TRjnlKOJ2va6Lu2UycjMgTyacY8pRxP1rTRd22nTkZkCJyYZjxuW+2taTaNvquWOhmREpRa6R6NYKyWOhmREpRa6R6NYKyWOhmREpRa6R6NYKyWOhmREpRaEelPnYxICUqtiPSnTkakBKVWRPpTJyMi0jBdGrU410XLREQkkQ6tcKpORkSkaTq0wqnSZdJe+ZRCXtvTCyJdGrWoTkZaKz8RMk+TIqXtujRqUZ2MtNbJ5et7NClS2q5LoxbVyYgMSJdGEImE6It/kQHp0ggikZBZOxkzO9vM7jSz7Wa2zcw+nD2+2MxuM7PHs3/PSN9caTrFS7wujSAqQzHTbTHpsheBj7r7983sJcCDZnYbcCVwu7tvNrOrgauBq9I1VVqitnhp+2qPXRpBVJKuMR02ayfj7k8CT2Y/P2dmjwLLgcuAS7LdtgLfRQEw9OqMl7av9tilEURl6BrTbYW++DezVcAFwH3A0iw4AJ4CllbbNGm71PHS9tUeuzSCqCq6xnRP9Bf/ZrYQuBn4iLs/m9/m7g544HmbzOwBM3vgFxwp1VhpD8WLFKWY6aaoTsbMxuj98m9w969lD+8zs2XZ9mXA/pme6+5b3H2du68b45Qq2iwNp3iRohQz3RUzusyAa4BH3f0zuU23AhuznzcCt1TfPGkbxYsUpZjptpjvZF4PfAB4xMweyh77GLAZuMnMPgjsBi5P0kJpG8WLFKWY6bCY0WX3ABbY/OZqmyNtp3iRohQz3aYZ/yIikoxql4k0gOqYSVfpk4xIA6iOmXSVOhmRBlAdM+kqpctEGkB1zKSr1MmINIDqmElXqZMRaQDVMZOuUicj0hIagSZtpC/+RVpCI9CkjdTJiLSERqBJGyldJlKCUlgi/emTjEgJSmGJ9KdORqQEpbBE+lO6TKSEOidRasKmtJE6GZES6pxEqQmb0kbqZERKqHMSp
SZsShvpOxkREUlGnYyIiCSjTkZERJJRJyMiIsmokxERkWRm7WTMbNzMvmdmPzSzbWb2yezxc8zsPjPbaWZfMbP56ZsrTad4kaIUM90W80nmCHCpu58PrAXeZmavAz4NfNbdVwM/Az6YrJXSJooXKUox02GzdjLecyi7O5bdHLgU+Gr2+FbgXSkaKO2ieJGiFDPdFvWdjJmNmtlDwH7gNuBHwDPu/mK2yx5geZIWSusoXqQoxUx3RXUy7n7U3dcCK4CLgInYE5jZJjN7wMwe+AVH5tZKaRXFixSlmOmuQqPL3P0Z4E7g14HTzWy6LM0KYG/gOVvcfZ27rxvjlDJtlZZRvEhRipnuiRlddqaZnZ79fCrwVuBReoHwO9luG4FbErVRWkTxIkUpZrotpkDmMmCrmY3S65Rucve/M7PtwI1m9ingB8A1Cdsp7aF4kaIUMx1m7l7fycwmgeeBp2s76eAtoRmv95XufuagG1FEFi+7ac57WIemvNbWxQvoGjNgM8ZMrZ0MgJk94O7raj3pAA3b601hmN7DYXqtqQzbe9j016uyMiIikow6GRERSWYQncyWAZxzkIbt9aYwTO/hML3WVIbtPWz06639OxkRERkeSpeJiEgytXYyZvY2M9uRle6+us5zp2ZmZ5vZnWa2PStX/uHs8cVmdpuZPZ79e8ag29oWXY4XUMyk0OWYaWu81JYuyyZaPUZvNu8e4H7gve6+vZYGJGZmy4Bl7v59M3sJ8CC9qrFXAgfcfXMW9Ge4+1WDa2k7dD1eQDFTta7HTFvjpc5PMhcBO939CXefAm4ELqvx/Em5+5Pu/v3s5+folcVYTu81bs12U7nyeJ2OF1DMJNDpmGlrvNTZySwHfpK739nS3Wa2CrgAuA9Y6u5PZpueApYOql0tMzTxAoqZigxNzLQpXvTFf8XMbCFwM/ARd382v817uUkN55PjKGakiLbFS52dzF7g7Nz9YOnutjKzMXq//Bvc/WvZw/uyXOp0TnX/oNrXMp2PF1DMVKzzMdPGeKmzk7kfONfMzjGz+cB7gFtrPH9SZmb0qsQ+6u6fyW26lV6ZclC58iI6HS+gmEmg0zHT1nipuwrzbwOfA0aBa939z2s7eWJm9pvA3cAjwLHs4Y/Ry5neBKykV1H4cnc/MJBGtkyX4wUUMyl0OWbaGi+a8S8iIsnoi38REUlGnYyIiCSjTkZERJJRJyMiIsmokxERkWTUyYiISDLqZEREJBl1MiIikow6GRERSUadjIiIJKNORkREklEnIyIiyaiTERGRZDrbyZjZodztmJm9kLv//lme+2dm9oiZvWhmn4g414Vmdld27H1m9uHctlVmdqeZHTazfzKzt1Tw8iSBucaMmb3czL5sZv9iZgfN7P+Y2fo++59uZlvNbH92+8QJ2xUzLVBjvLwpi4eDZrZrhu2NjpfOdjLuvnD6Bvwz8M7cYzfM8vSdwB8D35ztPGa2BPg28EXgZcBq4B9yu3wZ+EG27U+Ar5rZmYVfkCRXImYW0lsw69eAxcBW4JvZMrkz+SxwGrAKuAj4gJn9bm67YqYFaoyX54Frgf8U2N7seHH3zt+AXcBb5vC8vwE+Mcs+/wX468C2NcAR4CW5x+4G/v2g3xPd0sRM7vnPAr8W2PY08K9y9z8G3K2Yae8tZbzk9nkLsOuExxofL539JNOPmX3ezD5f0eFeBxwws3/MUh/fMLOV2bZXA0+4+3O5/X+YPS4tUiRmzGwtMJ/eJ+Lgbif8/JrsZ8VMBySIl5DGx8u8QTdgENz9Dyo83ArgQuCt9JZF/Qt6H19fT+9j8cET9j8ILK/w/FKD2Jgxs0XAXwOfdPcTf/fTvg1cbWYbgaXAv6WXPgPFTCdUHC/9ND5ehvKTTMVeAL7u7ve7+8+BTwK/YWYvBQ4Bi07YfxHwHNI5ZnYq8A3gXnf/r312/RC9uHkcuIXeHyV7sm2KmSFRIF76aXy8qJMp72HAc/fzP28DfsXMXpJ77Pzsc
ekQMzsF+N/0Oovf67evux9w9/e7+yvc/dX0/h9+L9usmBkCReJlFo2PF3UyMzCzMTMbp/f+zDOzcTMbDex+HbDBzNaa2Rjwp8A97n7Q3R8DHgI+nh1jA/Ba4OYaXobUJPu9f5Xep5ON7n5slv1/1cxeZmajZvZ2YBPwKQDFTPfNIV5GsuvRWO+ujZvZfGhJvAx65EEdN04Y+QF8AfhCn/2vp/eJJH+7Mtv2BuDQCfv/PrAX+Bm9j79n57atAr5LL6B2UGIEim7NjBngjVmMHKaXvpi+vWGmmAEuB/4l2/8h4LdOOJ5ipmW3xPFyyQzXo++2JV4sa6SIiEjllC4TEZFk1MmIiEgypToZM3ubme0ws51mdnVVjZLuUsxIEYqX9pvzdzLZaKvH6E1C3EOvFs973X17dc2TLlHMSBGKl24o80nmImCnuz/h7lPAjcBl1TRLOkoxI0UoXjqgTFmZ5cBPcvf3AMFy1QDz7RQfZ0GJU8pc/ZznmfIjNvueSRWKmSWLR33V2WPHPfbjqYUceWqckWcOc+z0BYy/4gVWzT8UPGF+/6rkzzuo9sScK0ao/bt+8guePnC0VfECM8dMG8XESVUxUJVQzCSvXWZmm+hNNmOc01hvb059SpnBfX77oJsQJR8vp71iARNb/g1rFuzn8pc+yJqxBXzq6QluuPlSljx8lMm1o/zuv76Nq172ePB4+f1jjB+YYmz7Ho5OTgb3Obb2AnZuHGXivMeY3LaGtde/yMg9Dx33eL7Nj04d5qaD6/jR4SXH7T+69OVMnbeCI2ccf1GMacPhS9ez5qptXLPynqjXFXLF7ovZtXmCU2/53nHHvOi3fjL7kxsiHzMrl8/je985e8At6i8fD/k4ycv/XvLyMRMT/3UKxUyZTmYvkP9trsgeO467bwG2ACyyxZqUM9xmjZl8vCw8Y4Xv2jzB/Wtfw8i7nY8t2cHlL32QkXc7P3r7mbxjwT42LHoI+nw6zu8f447tE6y+bgUjfS7wY9v3sPq6FexaPMHqn04xtmMvR094PN/mmw6u4ytfvYQlD7943P5T563giSuNN00cXwEkpg1DovA1Zt35442/xuTjIR8nMfIxExP/TVCmk7kfONfMzqH3i38P8L5KWiVdVShmRp45zKm3fI8lI+uzTmIHa8YWZP8hp/9T9v8PdvL+/V3hI+xaPMGpffY5OjnJyOTkL/c5OsPj+Tb/6PASljz84i//Kp3e/8gZY7xp4uRPIzFtGBKdvMbk4yEfJzFOjplmdzBQopNx9xfN7A+B7wCjwLXu3piibNI8c42Z8QNT3LF9git8JJheyKcg8kL7h6xZsJ/7176GJSPHp/5DKaxQymty7SjvWLCv7zHz+4TEHL+MfNuqOmZVunqNiXnPy8RM05T6TsbdvwV8q6K2yBCYS8yE0lB5+RREXtF0RCi9FkphhVJe+VRG6Jgx6Y6Y45dRNP1Yty5eY2Le8zIx0zRDuWiZtEsoDZV3YkpqWtF0RCi9FkphhVJePQv6HjO/T0jM8csomn6U8mLe8zIx0zTqZKTT8qm2vJg0Wj4Fd/e2Naz+6VTfc4VSdiGhNoTSKUWPX/S8wyBmZJdUS52MdFo+1ZYXk0YLjQqL2T9GqA2hdErR4xc97zAoM7JL5kadjHTaiSPBpsWk0UKjwmL3n02oDaF0StHjFz3vMCgzskvmRp2MNNax0xdw+NL1UZMTi44KC41YK5oiS9GG1AZ13iZo8mi6rlInI401/ooXWHPVtqjJiUVHhRWdOBmjqjakNqjzNkHTR9N1kToZaaxV8w9xzcp7oiYnFh0VVnTiZIyq2pDaoM7bBBpNVz91MtJYP55ayBW7L+aBvSsZO2cevqFciiM/sTGfwsqnj2JSZCfWj5pt9FfMMVOk70Jtzgu1/8dTt8zpPMNCo9TiqZORxjry1Di7Nk8wds48jl58kDVn7S6V4shPbMynsEK1yGKOEzP6K+aYKdJ3oTbnhdp/5KnvzOEsw0Oj1
OKpk5HGmq5d5hvWs+as3aXrNeUnNuZTWKFaZDHHyben3+iv2Y6ZIn3Xv815J7d/xKtbGqGLNEotnjoZGRoxabFQSi2UIsvLj1wqWussL1T3LGaUXcxrT5GOa7uite80Sq0nJsWqTkaGRkxaLJRSiymxnh+5VLTWWV6o7lmZJQBSp+ParmjtO41S64lJsaqTkaERkxYLpdRiSqznRy7NrdZZ3sl1z8osAZA6Hdd2RWvfaZRaT0yKVZ2MNF6o/lheVSmgoiPN6hQz2TOU7pP+mlZav6o6eKnl37djd9w74z7qZKTxQvXH8qpKARUdaVanmMmeoXSf9Ne00vpV1cFLLf++7dj+woz7qJORxgvVH8urKgVUdKRZnWIme4bSfdJf00rrV1UHL7X8+3bR/EMz7qNORhprunZZjJiVKIumlWImXRZNU8Sk/kKqSok0OSXYdikmaRYdkdg06mSksaZrl8WIWYmyaFopZtJl0TRFTOovpKqUSJNTgm2XYpJm0RGJTaNORhprunZZvP4rURZNK8VMuiyapohJ/YVUlRJpckqw7VJM0iw6IrFp1MlIY03XLgspk46ISRnFlMQPpb+KpqHKpERCEwOrmsgp8aqapNmlyZ7qZKSxpmuXhZRJR8SkjGJK4ofSX0XTUGVSIqGJgVVN5JR4VU3S7NJkz1k7GTO7FngHsN/dX5M9thj4CrAK2AVc7u4/S9dMaZOqYma6dllImXRETMoopiR+v/RXkTRUmZRIaGJgVRM5U+vSNaaqSZpdmuwZ80nmeuB/AF/KPXY1cLu7bzazq7P7V1XfPGmp66kwZmLK1OeFRvgUrS1WZ4qpqtUqyyw3MEDXo2tMo4X+T1VSu8zd7zKzVSc8fBlwSfbzVuC7KAAkU3XMxJSpzwuN8ClaW6zOFFNVq1WWWW5gUHSNab7Q/6mUtcuWuvuT2c9PAUvneBwZHnOOmZgy9XmhET5Fa4vVmWKqarXKMssNNIyuMQ0S+j9VS+0yd3cz89B2M9sEbAIY57Syp5MO6Bcz+XiZf+rpHH77+qjVJ4ummJpWqyqvzITNhqfF5qTINWblco1lSiFmBGPVtcv2mdkyd3/SzJYB+0M7uvsWYAvAIlscDBTpvKiYycfLkvOW+Jqr4lafLJpialqtqrwyEzabnBYraE7XmHXnj+sak0DMCMaqa5fdCmwENmf/akFwmU3hmDl+Mmb/1SeLppiaVqsqr8yETWhdWixE15gGiRnBOOfaZWb2ZXpfwC0xsz3Ax+n94m8ysw8Cu4HLS76GzhnmkuuDiJnQ5MqqRm3lj/PA3pWMnTMP3xBXV63oKpllSvfHHDNm/1DqIwVdY7otZnTZewOb3lxxWzplmEuuDyJmQpMrqxq1lT/O2DnzOHrxQdactTvquUVXySxTuj/mmDH7h1IfKega0236liwRlVyvV2hyZVWjtvLH8Q3rWXPW7ui6akVXySxTuj/mmDH7h1IfIkWpk0lE5dTLm6l2Wer3ssyqmlVNhExRiyxmNF3MxDoZTmWWMFAnk4jKqZc3U+2y1O9lmVU1q5oImaIWWcxoupiJdTKcyixhoE4mEZVTLy9Uuyzle1lmVc2qJkKmqEUWM5ouZmKdDKcySxiok5FOCI3CKrq6ZSitlDeoCZuhVF7oNcrwCKVq88qMriyz9IA6GemE0CisoqtbhtJKeYOasBlK5YVeowyPUKo2r8zoyjJLD6iTkU4IjcIqurpl/7RSXv0X8VAqL/QaZXj0S9VOKzO6sszSA+pkpBPKTMYsOmGz6FICMUsVxKQ7YlbwDAmVZ485vlSrzEitGEWXxkhNnYx0QpnJmEUnbBZdSiBmqYKYdEfMCp4hofLsMceXapUZqRWj6NIYqamTkU4oMxmz6ITNoksJxCxVEJPuyL+ufq99JqHy7DHHl2qVGakVo+jSGKmpk5FWKVPLq6p0QT4t9uyqUe7d+8rStdFCo9rK1DE78TiaHNxdRUd/pU7Z5amTkVYpU8urqnRBP
i12795XMu+ul7Jr14LKRu/klaljlqfJwd1WdPRX6pRdnjoZaZXytbzK/7V2Ulps14LSqY/QqLYydczyNDm424qO/kqdsstTJyOdVnTkWGikVui5odRZfjmAMmm6mDpmJ7Z/mJaViJE6NVT0+E1IbZWZXFmUOhnptKIjx0IjtULPDaXO8ssBlEnTxdQxyxu2ZSVipE4NFT1+E1JbZSZXFqVORjqt6Mix0Eit0HNDqbOTlwOY23/gmDpmeVpW4mSpU0NFj9+E1FaZyZVFqZORoREaYRWavJYXk1KoMwURM6Gy6CqZXZX691Ln7z2kztFiRamTkaERGmEVmryWF5NSqDMFETOhsugqmV2V+vdS5+89pM7RYkWpk5GhERph1X/yWl7/C0edKYiYCZVFV8nsqtS/lzp/7yF1jhYrSp2MtErRlFdMaihmBFpMbbG8mOOkmLyZp7L/3RCTjkudstPKmDI0iqa8YlJDMSPQYmqL5cUcJ8XkzTyV/e+GmHRc6pRd0pUxzexs4EvAUsCBLe7+V2a2GPgKsArYBVzu7j+b42uQjkgdL0VTXjGpoZgRaLG1xabFHCfF5M2TNb+D0TWmv5h0XOqUXZm4DdcG//9eBD7q7q8CXgf8BzN7FXA1cLu7nwvcnt0XqSxejp2+gMMb1nPsjRcwemb4L3bIpbx2X8ynnp7gsV88X/qFPDp1mE9Ovoordl/M3dvWcEouNXf0kgtPalv+8cm1o6wJpDUm187ru88Q0jWm4crE7ayfZNz9SeDJ7OfnzOxRYDlwGXBJtttW4LvAVcWaLl1TZbyMv+IF1lwVV6er6KTLGDErUcaswpnXhJFITaNrTPPVtjKmma0CLgDuA5ZmwQHwFL2PuiK/VDZeVs0/FF2nq+ikyxgxK1HGrMKZ14SRSE2ma0wz1bIyppktBG4GPuLuz5rZL7e5u5uZB563CdgEMM5p0Q2TdqsiXlYu74VnaMXJkKJl7YsuB9D2svkxI9OO3XFvjS3qqTJmpFrJR5eZ2Ri9X/4N7v617OF9ZrbM3Z80s2XA/pme6+5bgC0Ai2zxjEEi3VJVvKw7f9yB4IqTIUXL2hddDqDtZfNjRqbt2P5CjS2qPmakWqlHlxlwDfCou38mt+lWYCOwOfv3luJNl65JES+hFSdDipa1L7ocQNvL5seMTLto/qHa2qNrTPOVGV0W80nm9cAHgEfM7KHssY/R+8XfZGYfBHYDlxduuXRRo+IltJpkTIoslKYrs0JlaFJn6hL0MefK7//jqVqv57XFTCjt0+TaX6nFvPYykz1jRpfdA1hg85ujzyRDoWnxUtVIsKpWqAxN6kxdgj7mXPn9jzz1ncraMZs6YyaU9mly7a/UYl57baPLRNqmqpFgVa1QGZrUmboEfcy58vuP+OHK2tEkobRPk2t/pRbz2msZXdZ0MeXah63Eedv9eGohV+y++LjHypToL7MiYcwospgaaKGRXUXbEyPU5qKrhXZJE8ryN03q96QznUxMufZhK3HedkeeGmfX5uNXqCxTor/MioQxo8hiJoSGRnYVbU+MUJtTTFxtC02GPVnq96QznUxMufZhK3HediPPHJ5TKfuThVaTjF+RMH/ekJgJof1HdhVrT4yZ2pxi4mpbaDLsyVK/J53pZPIpgJA2TpwbZsdOX8DhS8MjuwaV7sjH2gN7VzJ2zjx8Q9xE0byY0V+hlF2K9HA+bTKIyZiD1MY0WltGxHWmk8mnAELaOHFumM1Uu6zoxMkU8rE2ds48jl58kDVn7U4y0iymflpI0fbk0yZ1T8YctDam0doyIq4zncyJE+SC+9XSGqnCTLXLik6cTCEfa75hPWvO2p1spFlM/bSQou3Jp03qnIzZBG1Mo7VlRFxnOhnpriav+heqYxYawRUzYi2mNlqZ4zch5ZjK9IjEJqSPUqezmpDii5nAq05GGq/Jq/6F6piFRnDFjFiLqY1W5vhNSDmmMj0isQnpo9TprCak+GIm8KqTkcZr8qp/o
TpmoRFcMSPWYmqjlTl+E1KOqUyPSGxC+ih1OqsJKb6YCbyd72TK1JmS9iiTmgilHWKWGGhafBVdFqFrpkckNiEN2IR0VmoxIxI738mUqTMl7VEmNRFKO8QsMdC0+Cq6LELXTI9IbEIasAnprNRiRiR2vpMpU2dK2qNMaiKUdohZYqBp8VV0WYSumR6R2DPYC3oT0lmpxYxI7Hwn0/ZVDCVO6tRETM2xovvEpLOK1mcLtXkYU2dVa8vkxxRi6uYN7eiytq9iKHFSpyZiao4V3ScmnVW0PluozcOYOqtaWyY/phBTN29oR5e1fRVDiZM6NRFTc6zoPjHprKL12UJtHsbUWdXaMvkxhZi6eUM7ukzar0yaoszqkDH7p1amLH9oYmaXS/2nnIwZk5Jte0otFP9lvmpQJyONVyZNUWZ1yJj9UytTlj80MbPLpf5TTsaMScm2PaUWiv8yXzWok5HGK5OmKLM6ZMz+qZUpyx+amNnlUv8pJ2PGpGTbnlLrlxab61cN6mSk8cqMHCs6wqpLI7JCo93a/rraIrT8SOqUb9Ook5HGKzNyrOgIqy6NyAqNdmv762qL0PIjqVO+TTNrJ2Nm48BdwCnZ/l9194+b2TnAjcDLgAeBD7i7JqEMuRTxUmbkWNERVl0akRUa7da019XVa0xo+ZHUKd+mifkkcwS41N0PmdkYcI+Z/T3wR8Bn3f1GM/sC8EHgfyVsq7RDbfFSVYn+0HPLpOlSjDKqalRYA2tqVRYzg65dFkq3tmV5hVCKNcaca5e5uwPT9QLGspsDlwLvyx7fCnwCdTJDr854qapEf+i5ZdJ0KUYZVTUqrGk1taqMmUHXLgulW9uyvEIoxRqjVO0yMxul93F1NfA/gR8Bz7j7dJJwD7A88NxNwCaAcU4r1Ghpp6riZeXy/uFZVYn+0HPLpOlSjDKqalRYE2tqVRkzg6xdFkq3tmV5hf4TivsrVbvM3Y8Ca83sdODrwET/Zxz33C3AFoBFtthjnyftVVW8rDt/vG+8hNI+RSeUpZicmG/bs6tGuXfvK08aZVRVLb22jz6C+mKmTlWlW+uM2xQKjS5z92fM7E7g14HTzWxe9pfGCmBvigZKe6WOl1Dap+iEshSTE/Ntu3fvK5l310vZtev4C0BVtfTaPvoor0vXmKrSrXXGbQoxo8vOBH6R/fJPBd4KfBq4E/gdeqM/NgIzl+CUoVJnvITSPkUnlKWYnHhS2mTXgkonuOW1ffRRV68xVaVbob64TSHmk8wyYGuWMx0BbnL3vzOz7cCNZvYp4AfANQnbKe2heJGiFDMdZr2BHTWdzGwSeB54uraTDt4SmvF6X+nuxYeMDFAWL7tpzntYh6a81tbFC+gaM2AzxkytnQyAmT3g7utqPekADdvrTWGY3sNheq2pDNt72PTXOzL7LiIiInOjTkZERJIZRCezZQDnHKRhe70pDNN7OEyvNZVhew8b/Xpr/05GRESGh9JlIiKSTK2djJm9zcx2mNlOM7u6znOnZmZnm9mdZrbdzLaZ2Yezxxeb2W1m9nj27xmDbmtbdDleQDGTQpdjpq3xUlu6LJto9Ri92bx7gPuB97r79loakJiZLQOWufv3zewl9Ir9vQu4Ejjg7puzoD/D3a8aXEvboevxAoqZqnU9ZtoaL3V+krkI2OnuT2QLD90IXFbj+ZNy9yfd/fvZz88Bj9KrGnsZvTLlZP++ayANbJ9OxwsoZhLodMy0NV7q7GSWAz/J3Q+W7m47M1sFXADcByx19yezTU8BSwfVrpYZmngBxUxFhiZm2hQv+uK/Yma2ELgZ+Ii7P5vfli3OpOF8chzFjBTRtnips5PZC5ydu9+60t2zyZaOvRm4wd2/lj28L8ulTudU9w+qfS3T+XgBxUzFOh8zbYyXOjuZ+4FzzewcM5sPvAe4tcbzJ2VmRq9K7KPu/pncplvplSmHFpYrH6BOxwsoZhLodMy0NV7qrsL828DngFHgWnf/89pOnpiZ/SZwN
/AIcCx7+GP0cqY3ASvpVRS+3N0PDKSRLdPleAHFTApdjpm2xotm/IuISDL64l9ERJJRJyMiIsmokxERkWTUyYiISDLqZEREJBl1MiIikow6GRERSUadjIiIJPP/AK3l+0UQ13+JAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "for i,j in enumerate(np.random.choice(n_train, 6)):\n", " plt.subplot(2,3,i+1)\n", " image = x_train[j]\n", " plot = plt.imshow(image)\n", " plt.title(\"T: %.2f\" % T_train[j])\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Definition of the model\n", "\n", "Define a CNN for discriminative localization. Note that the CNN must apply `GlobalAveragePooling2D` after the convolutional part and use only a single fully connected layer as output. With this architecture, each class score is a weighted sum of the spatially averaged feature maps of the last convolutional layer, which is what makes the class activation maps computable." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "reshape (Reshape) (None, 32, 32, 1) 0 \n", "_________________________________________________________________\n", "conv2d (Conv2D) (None, 32, 32, 16) 160 \n", "_________________________________________________________________\n", "conv2d_1 (Conv2D) (None, 32, 32, 16) 2320 \n", "_________________________________________________________________\n", "max_pooling2d (MaxPooling2D) (None, 16, 16, 16) 0 \n", "_________________________________________________________________\n", "conv2d_2 (Conv2D) (None, 16, 16, 32) 4640 \n", "_________________________________________________________________\n", "conv2d_3 (Conv2D) (None, 16, 16, 32) 9248 \n", "_________________________________________________________________\n", "global_average_pooling2d (Gl (None, 32) 0 \n", "_________________________________________________________________\n", "dropout (Dropout) (None, 32) 0 \n", "_________________________________________________________________\n", "dense (Dense) (None, 2) 66 \n", 
"=================================================================\n", "Total params: 16,434\n", "Trainable params: 16,434\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "model = keras.models.Sequential()\n", "model.add(layers.InputLayer(input_shape=(32, 32)))\n", "model.add(layers.Reshape((32, 32,1)))\n", "model.add(layers.Convolution2D(16, (3, 3), padding='same', activation='relu'))\n", "model.add(layers.Convolution2D(16, (3, 3), padding='same', activation='relu'))\n", "model.add(layers.MaxPooling2D((2, 2)))\n", "model.add(layers.Convolution2D(32, (3, 3), padding='same', activation='relu'))\n", "model.add(layers.Convolution2D(32, (3, 3), padding='same', activation='relu'))\n", "model.add(layers.GlobalAveragePooling2D())\n", "model.add(layers.Dropout(0.25))\n", "model.add(layers.Dense(2, activation='softmax'))\n", "\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare model for training" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "model.compile(\n", " loss='binary_crossentropy',\n", " optimizer=keras.optimizers.Adam(0.001),\n", " metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/50\n", "282/282 - 10s - loss: 0.1006 - accuracy: 0.9611 - val_loss: 0.0612 - val_accuracy: 0.9750\n", "Epoch 2/50\n", "282/282 - 14s - loss: 0.0476 - accuracy: 0.9814 - val_loss: 0.0415 - val_accuracy: 0.9825\n", "Epoch 3/50\n", "282/282 - 23s - loss: 0.0464 - accuracy: 0.9810 - val_loss: 0.0555 - val_accuracy: 0.9785\n", "Epoch 4/50\n", "282/282 - 24s - loss: 0.0480 - accuracy: 0.9804 - val_loss: 0.0364 - val_accuracy: 0.9845\n", "Epoch 5/50\n", "282/282 - 24s - loss: 0.0465 - accuracy: 0.9809 - val_loss: 0.0379 - val_accuracy: 0.9840\n", "Epoch 6/50\n", "282/282 - 25s - loss: 0.0443 - 
accuracy: 0.9811 - val_loss: 0.0435 - val_accuracy: 0.9835\n", "\n", "Epoch 00006: ReduceLROnPlateau reducing learning rate to 0.0006700000318232924.\n", "Epoch 7/50\n", "282/282 - 23s - loss: 0.0421 - accuracy: 0.9832 - val_loss: 0.0377 - val_accuracy: 0.9840\n", "Epoch 8/50\n", "282/282 - 25s - loss: 0.0424 - accuracy: 0.9823 - val_loss: 0.0388 - val_accuracy: 0.9835\n", "\n", "Epoch 00008: ReduceLROnPlateau reducing learning rate to 0.0004489000252215192.\n", "Epoch 9/50\n", "282/282 - 25s - loss: 0.0396 - accuracy: 0.9828 - val_loss: 0.0390 - val_accuracy: 0.9835\n", "Epoch 00009: early stopping\n" ] } ], "source": [ "results = model.fit(x_train, y_train,\n", " batch_size=64,\n", " epochs=50,\n", " verbose=2,\n", " validation_split=0.1,\n", " callbacks=[\n", " callbacks.EarlyStopping(patience=5, verbose=1),\n", " callbacks.ReduceLROnPlateau(factor=0.67, patience=2, verbose=1)]\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluate training" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(1, (12, 4))\n", "plt.subplot(1, 2, 1)\n", "plt.plot(results.history['loss'])\n", "plt.plot(results.history['val_loss'])\n", "plt.ylabel('loss')\n", "plt.xlabel('epoch')\n", "plt.legend(['train', 'val'], loc='upper right')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(1, (12, 4))\n", "plt.subplot(1, 2, 1)\n", "plt.plot(results.history['accuracy'])\n", "plt.plot(results.history['val_accuracy'])\n", "plt.ylabel('accuracy')\n", "plt.xlabel('epoch')\n", "plt.legend(['train', 'val'], loc='upper right')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create class activation maps" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Task\n", "Use Exercise 8.1 as a guide to extract the weights and feature maps from the trained network. \n", "\n", "First, extract the activations of the last convolutional layer." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "conv = model.layers[-4] # last conv layer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, create the class activation maps by omitting the global average pooling operation and applying the weights of the single classification layer to the extracted activations.\n", "\n", "For class $c$, with dense-layer weights $w_k^c$ and feature maps $f_k(x, y)$ of the last convolutional layer, the class activation map is $M_c(x, y) = \\sum_k w_k^c f_k(x, y)$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot the class activation maps for examples just below and above the critical temperature" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Task\n", "Plot the CAMs for wrongly and correctly classified images. \n", "\n", "The CAMs have the 16 × 16 resolution of the last convolutional layer, half that of the 32 × 32 input. You can pass `interpolation='bilinear'` to `plt.imshow` to upsample them for display."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 4 }