{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CS 20 : TensorFlow for Deep Learning Research\n", "## Lecture 07 : ConvNet in TensorFlow\n", "same contents, but different style with [Lec07_ConvNet mnist with Weight initialization and Drop out.ipynb](https://nbviewer.jupyter.org/github/aisolab/CS20/blob/master/Lec07_ConvNet%20in%20Tensorflow/Lec07_ConvNet%20mnist%20with%20Weight%20initialization%20and%20Drop%20out.ipynb)\n", "\n", "### ConvNet mnist with Weight initialization and Drop out\n", "- Creating the **data pipeline** with `tf.data`\n", "- Using `tf.keras`, alias `keras` and `eager execution`\n", "- Creating the model as **Class** by subclassing `tf.keras.Model`\n", "- Initializing weights of model with **He initialization** by `tf.keras.initializers.he_uniform`\n", "- Training the model with **Drop out** technique by `tf.keras.layers.Dropout`\n", "- Using tensorboard" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.12.0\n" ] } ], "source": [ "import os, sys\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "%matplotlib inline\n", "\n", "print(tf.__version__)\n", "tf.enable_eager_execution()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load and Pre-process data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "(x_train, y_train), (x_tst, y_tst) = tf.keras.datasets.mnist.load_data()\n", "x_train = x_train / 255\n", "x_train = x_train.reshape(-1, 28, 28, 1).astype(np.float32)\n", "x_tst = x_tst / 255\n", "x_tst = x_tst.reshape(-1, 28, 28, 1).astype(np.float32)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(55000, 28, 28, 1) (55000,)\n", "(5000, 28, 
28, 1) (5000,)\n" ] } ], "source": [ "tr_indices = np.random.choice(range(x_train.shape[0]), size = 55000, replace = False)\n", "\n", "x_tr = x_train[tr_indices]\n", "y_tr = y_train[tr_indices].astype(np.int32)\n", "\n", "x_val = np.delete(arr = x_train, obj = tr_indices, axis = 0)\n", "y_val = np.delete(arr = y_train, obj = tr_indices, axis = 0).astype(np.int32)\n", "\n", "print(x_tr.shape, y_tr.shape)\n", "print(x_val.shape, y_val.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define SimpleCNN class by high-level api" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class SimpleCNN(keras.Model):\n", " def __init__(self, num_classes):\n", " super(SimpleCNN, self).__init__()\n", " self.__conv1 = keras.layers.Conv2D(filters=32, kernel_size=[5,5], padding='same',\n", " kernel_initializer=keras.initializers.he_uniform(),\n", " bias_initializer=keras.initializers.he_uniform(),\n", " activation=tf.nn.relu)\n", " self.__conv2 = keras.layers.Conv2D(filters=64, kernel_size=[5,5], padding='same',\n", " kernel_initializer=keras.initializers.he_uniform(),\n", " bias_initializer=keras.initializers.he_uniform(),\n", " activation=tf.nn.relu)\n", " self.__pool = keras.layers.MaxPooling2D()\n", " self.__flatten = keras.layers.Flatten()\n", " self.__dropout = keras.layers.Dropout(rate =.5)\n", " self.__dense1 = keras.layers.Dense(units=1024, activation=tf.nn.relu, \n", " kernel_initializer=keras.initializers.he_uniform(),\n", " bias_initializer=keras.initializers.he_uniform())\n", " self.__dense2 = keras.layers.Dense(units=num_classes,\n", " kernel_initializer=keras.initializers.he_uniform(),\n", " bias_initializer=keras.initializers.he_uniform())\n", " \n", " def call(self, inputs, training=False):\n", " conv1 = self.__conv1(inputs)\n", " pool1 = self.__pool(conv1)\n", " conv2 = self.__conv2(pool1)\n", " pool2 = self.__pool(conv2)\n", " flattened = self.__flatten(pool2)\n", " fc = self.__dense1(flattened)\n", " 
if training:\n", " fc = self.__dropout(fc, training=training)\n", " score = self.__dense2(fc)\n", " return score" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preparing for training a model" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "550\n" ] } ], "source": [ "# hyper-parameter\n", "lr = .001\n", "epochs = 10\n", "batch_size = 100\n", "total_step = int(x_tr.shape[0] / batch_size)\n", "print(total_step)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<BatchDataset shapes: ((?, 28, 28, 1), (?,)), types: (tf.float32, tf.int32)> <BatchDataset shapes: ((?, 28, 28, 1), (?,)), types: (tf.float32, tf.int32)>\n" ] } ], "source": [ "## create input pipeline with tf.data\n", "# for train\n", "tr_dataset = tf.data.Dataset.from_tensor_slices((x_tr, y_tr))\n", "tr_dataset = tr_dataset.shuffle(buffer_size = 10000)\n", "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n", "\n", "# for validation\n", "val_dataset = tf.data.Dataset.from_tensor_slices((x_val,y_val))\n", "val_dataset = val_dataset.batch(batch_size = batch_size)\n", "print(tr_dataset, val_dataset)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "## create training op\n", "opt = tf.train.AdamOptimizer(learning_rate = lr)\n", "\n", "## create model \n", "cnn = SimpleCNN(num_classes=10)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def loss_fn(model, x, y, training):\n", " score = model(x, training=training)\n", " return tf.losses.sparse_softmax_cross_entropy(labels=y, logits=score)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training a model" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "## for tensorboard\n", "# creating check point (Object-based 
saving)\n", "checkpoint_dir = '../graphs/lecture07/convnet_mnist_drop_out_kde/'\n", "checkpoint_prefix = os.path.join(checkpoint_dir, 'cnn')\n", "checkpoint = tf.train.Checkpoint(cnn=cnn)\n", "\n", "# create writer for tensorboard\n", "summary_writer = tf.contrib.summary.create_file_writer(logdir=checkpoint_dir)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch : 1, tr_loss : 0.189, val_loss : 0.051\n", "epoch : 2, tr_loss : 0.055, val_loss : 0.044\n", "epoch : 3, tr_loss : 0.037, val_loss : 0.041\n", "epoch : 4, tr_loss : 0.028, val_loss : 0.042\n", "epoch : 5, tr_loss : 0.024, val_loss : 0.035\n", "epoch : 6, tr_loss : 0.020, val_loss : 0.033\n", "epoch : 7, tr_loss : 0.017, val_loss : 0.037\n", "epoch : 8, tr_loss : 0.012, val_loss : 0.045\n", "epoch : 9, tr_loss : 0.012, val_loss : 0.034\n", "epoch : 10, tr_loss : 0.013, val_loss : 0.036\n" ] }, { "data": { "text/plain": [ "'../graphs/lecture07/convnet_mnist_drop_out_kde/cnn-1'" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tr_loss_hist = []\n", "val_loss_hist = []\n", "\n", "for epoch in range(epochs):\n", "\n", " avg_tr_loss = 0\n", " avg_val_loss = 0\n", " tr_step = 0\n", " val_step = 0\n", "\n", " with summary_writer.as_default(), tf.contrib.summary.always_record_summaries(): # for tensorboard\n", " # for training\n", " for x_mb, y_mb in tr_dataset:\n", " with tf.GradientTape() as tape:\n", " tr_loss = loss_fn(cnn, x_mb, y_mb, training = True) \n", " grads = tape.gradient(target=tr_loss, sources=cnn.variables)\n", " opt.apply_gradients(grads_and_vars=zip(grads, cnn.variables))\n", " tf.contrib.summary.scalar(name='tr_loss', tensor=tr_loss)\n", " avg_tr_loss += tr_loss\n", " tr_step += 1\n", " else:\n", " avg_tr_loss /= tr_step\n", " tr_loss_hist.append(avg_tr_loss)\n", " \n", " # for validation\n", " for x_mb, y_mb in val_dataset:\n", " val_loss = loss_fn(cnn, x_mb, 
y_mb, training = False)\n", " tf.contrib.summary.scalar(name='val_loss', tensor=val_loss)\n", " avg_val_loss += val_loss\n", " val_step += 1\n", " else:\n", " avg_val_loss /= val_step\n", " val_loss_hist.append(avg_val_loss)\n", "# if (epoch + 1) % 5 == 0:\n", " print('epoch : {:3}, tr_loss : {:.3f}, val_loss : {:.3f}'.format(epoch + 1, avg_tr_loss, avg_val_loss))\n", "\n", "checkpoint.save(file_prefix=checkpoint_prefix)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.legend.Legend at 0x7fbc9259c048>" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAD8CAYAAAB3u9PLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xd4XPWd7/H3d0a9j1VcVG1s3JskGyeEakIMSQwkgCGkkLsbns2GZXN3szdk7y5JWHiWPE8uYXPDkjgJbAolXCeULE6cBUxJFhMXjLFcsHCT5KJiS1Zv871/nCNpJMvWSB7pSJrv63nOo5nfKfObAZ/PnF85I6qKMcYY4/O6AsYYY8YHCwRjjDGABYIxxhiXBYIxxhjAAsEYY4zLAsEYYwxggWCMMcZlgWCMMQawQDDGGOOK8boCw5GVlaVFRUVeV8MYYyaU7du316pq9lDbTahAKCoqYtu2bV5XwxhjJhQRORLOdtZkZIwxBrBAMMYY47JAMMYYA0ywPgRjzOTR2dlJZWUlbW1tXldl0khISCAvL4/Y2NgR7W+BYIzxRGVlJampqRQVFSEiXldnwlNV6urqqKysZObMmSM6hjUZGWM80dbWRmZmpoVBhIgImZmZF3TFZYFgjPGMhUFkXejnGRWB8MLOKn65JaxhuMYYE7WiIhA2lZ3gR2984HU1jDHjSH19Pf/+7/8+7P2uv/566uvrR6FG3ouKQCguCFBxqpXqMzaawRjjOFcgdHV1nXe/jRs3kpGRMVrV8lR0BEJhAIAdR097XBNjzHhx77338sEHH7Bs2TJWrFjBZZddxtq1a1mwYAEAN954IyUlJSxcuJD169f37ldUVERtbS2HDx9m/vz5fOlLX2LhwoVce+21tLa2evV2IiIqhp0unJFGXIyPHUfrWbNoutfVMcYM8O3flrHn2JmIHnPBjDS++cmF51z/0EMPsXv3bnbu3Mlrr73Gxz/+cXbv3t07ZPPxxx9nypQptLa2smLFCj796U+TmZnZ7xgHDhzg6aef5sc//jG33norv/71r/nsZz8b0fcxlqIiEOJj/CzJTWf7EbtCMMYMbuXKlf3G73//+9/nueeeA6CiooIDBw6cFQgzZ85k2bJlAJSUlHD48OExq+9oiIpAACgpDPDEnw7T3tVNfIzf6+oYY0Kc75v8WElOTu59/Nprr/Hyyy/z1ltvkZSUxJVXXjno+P74+Pjex36/f8I3GUVFHwI4/Qgd3UF2V0X2stQYMzGlpqbS2Ng46LqGhgYCgQBJSUns27ePLVu2jHHtvB
E1VwjFBW7H8pHTlLidzMaY6JWZmcmll17KokWLSExMZOrUqb3r1qxZww9/+EPmz5/P3LlzWbVqlYc1HTtREwjZqfEUTEmykUbGmF5PPfXUoOXx8fH87ne/G3RdTz9BVlYWu3fv7i3/2te+FvH6jbWoaTICpx9h25HTqKrXVTHGmHEnqgKhuDBATWM7lacndsePMcaMhugKhAJndqE1GxljzNnCCgQRWSMi+0WkXETuHWT95SKyQ0S6ROTmkPKrRGRnyNImIje66/5DRA6FrFsWubc1uLlTU0mO87PD5iMYY8xZhuxUFhE/8CjwUaAS2CoiL6rqnpDNjgJ3Av16VVR1M7DMPc4UoBz4Q8gm/6CqGy7kDQxHjN/HsoIMttsVgjHGnCWcK4SVQLmqHlTVDuAZ4IbQDVT1sKruAoLnOc7NwO9UtWXEtY2AkoIAe4830tx+/htYGWNMtAknEHKBipDnlW7ZcN0GPD2g7EER2SUi3xOR+MF2EpG7RGSbiGyrqakZwcv2t7wwQHdQebdyct6+1hgzelJSUgA4duwYN99886DbXHnllWzbtu28x3nkkUdoaen7bjxebqk9Jp3KIjIdWAxsCin+BjAPWAFMAb4+2L6qul5VS1W1NDs7+4LrUpzvTEp756j3H74xZmKaMWMGGzaMvLV7YCCMl1tqhxMIVUB+yPM8t2w4bgWeU9XOngJVPa6OduAJnKapUZeeFMucnBS70Z0xhnvvvZdHH3209/m3vvUtHnjgAVavXk1xcTGLFy/mhRdeOGu/w4cPs2jRIgBaW1u57bbbmD9/PjfddFO/+xl9+ctfprS0lIULF/LNb34TcG6ad+zYMa666iquuuoqoO+W2gAPP/wwixYtYtGiRTzyyCO9rzcWt9oOZ6byVmCOiMzECYLbgM8M83Vux7ki6CUi01X1uDg/AnojsHvQPUdBSWGA35edIBhUfD77TVdjPPe7e+HEe5E95rTFcN1D591k3bp1fPWrX+UrX/kKAM8++yybNm3innvuIS0tjdraWlatWsXatWvP+XvFjz32GElJSezdu5ddu3ZRXFzcu+7BBx9kypQpdHd3s3r1anbt2sU999zDww8/zObNm8nKyup3rO3bt/PEE0/w9ttvo6pccsklXHHFFQQCgTG51faQVwiq2gXcjdPcsxd4VlXLROR+EVkLICIrRKQSuAX4kYiU9ewvIkU4VxivDzj0kyLyHvAekAU8cOFvJzzFBQHqWzo5WNs8Vi9pjBmHli9fTnV1NceOHePdd98lEAgwbdo0/vEf/5ElS5ZwzTXXUFVVxcmTJ895jDfeeKP3xLxkyRKWLFnSu+7ZZ5+luLiY5cuXU1ZWxp49e851GAD++Mc/ctNNN5GcnExKSgqf+tSnePPNN4GxudV2WPcyUtWNwMYBZfeFPN6K05Q02L6HGaQTWlWvHk5FIyn0F9Rm56R4VQ1jTI8hvsmPpltuuYUNGzZw4sQJ1q1bx5NPPklNTQ3bt28nNjaWoqKiQW99PZRDhw7x3e9+l61btxIIBLjzzjtHdJweY3Gr7aiaqdxjVlYyGUmxNkHNGMO6det45pln2LBhA7fccgsNDQ3k5OQQGxvL5s2bOXLkyHn3v/zyy3tvkrd792527doFwJkzZ0hOTiY9PZ2TJ0/2u1neuW69fdlll/H888/T0tJCc3Mzzz33HJdddlkE3+35Rc3dTkP5fEJxQcA6lo0xLFy4kMbGRnJzc5k+fTp33HEHn/zkJ1m8eDGlpaXMmzfvvPt/+ctf5otf/CLz589n/vz5lJSUALB06VKWL1/OvHnzyM/P59JLL+3d56677mLNmjXMmDGDzZs395YXFxdz5513snKlM8bmL//yL1m+fPmY/RKbTKQ7f5aWlupQ43vD9YNXD/DdP7zPu/ddS3pSbESOaYwJ3969e5k/f77X1Zh0BvtcRWS7qpYOtW9UNhlBXz/COxV2lWCMMRDFgbA0LwO/T6wfwRhjXFEbCMnxMc
yfnmo3ujPGQxOpyXoiuNDPM2oDAZz5CDuP1tPVfb578hljRkNCQgJ1dXUWChGiqtTV1ZGQkDDiY0TlKKMeJYUBfv7WEfafbGThjHSvq2NMVMnLy6OyspJI3LTSOBISEsjLG3RKWFiiOhCKC3omqNVbIBgzxmJjY5k5c6bX1TAhorrJKC+QSE5qvHUsG2MMUR4IIjZBzRhjekR1IIDTj3D0VAs1je1eV8UYYzwV9YEQeqM7Y4yJZlEfCIty04jz+6wfwRgT9aI+EOJj/CzKTbN+BGNM1Iv6QACnH2FXVQMdXTZBzRgTvSwQcAKhoytI2bEGr6tijDGesUCgb4KaNRsZY6KZBQKQk5ZAXiDRRhoZY6JaWIEgImtEZL+IlIvIvYOsv1xEdohIl4jcPGBdt4jsdJcXQ8pnisjb7jF/JSJxF/52Rq6k0JmgZjfaMsZEqyEDQUT8wKPAdcAC4HYRWTBgs6PAncBTgxyiVVWXucvakPLvAN9T1dnAaeAvRlD/iCkpDHDyTDvHGkb+I9jGGDORhXOFsBIoV9WDqtoBPAPcELqBqh5W1V1AWMN0RESAq4ENbtHPgBvDrvUosH4EY0y0CycQcoGKkOeVblm4EkRkm4hsEZGek34mUK+qXUMdU0TucvffNpq3yZ03LZWkOL9NUDPGRK2xuP11oapWicgs4FUReQ8Ie3ynqq4H1gOUlpaOWgN/jN/H0rwMu0IwxkStcK4QqoD8kOd5bllYVLXK/XsQeA1YDtQBGSLSE0jDOuZoKSkMsOf4GVo6uobe2BhjJplwAmErMMcdFRQH3Aa8OMQ+AIhIQETi3cdZwKXAHnWG8mwGekYkfQF4YbiVj7SSwgDdQWVXpU1QM8ZEnyEDwW3nvxvYBOwFnlXVMhG5X0TWAojIChGpBG4BfiQiZe7u84FtIvIuTgA8pKp73HVfB/5ORMpx+hR+Gsk3NhLLCzIA61g2xkSnsPoQVHUjsHFA2X0hj7fiNPsM3O+/gcXnOOZBnBFM40ZGUhwXZSdbx7IxJirZTOUBSgoD7DhqE9SMMdHHAmGAksIAp1s6OVTb7HVVjDFmTFkgDFBSaBPUjDHRyQJhgFlZKaQlxNiN7owxUccCYQCfTyh2b3RnjDHRxAJhECUFAQ5UN9HQ2ul1VYwxZsxYIAyipDCAKuysqPe6KsYYM2YsEAaxND8Dn1jHsjEmulggDCI5PoZ509JsgpoxJqpYIJxDSWGAnRX1dAdtgpoxJjpYIJxDSWGApvYu3j/Z6HVVjDFmTFggnIP9gpoxJtpYIJxD/pREslLirR/BGBM1LBDOQUQoKcywGcvGmKhhgXAeJYUBDte1UNvU7nVVjDFm1FkgnEdPP4I1GxljooEFwnksyk0n1i9st2YjY0wUsEA4j4RYP4ty0+0KwRgTFSwQhlBSEGBXZQMdXUGvq2KMMaMqrEAQkTUisl9EykXk3kHWXy4iO0SkS0RuDilfJiJviUiZiOwSkXUh6/5DRA6JyE53WRaZtxRZxYUB2ruC7Dl+xuuqGGPMqBoyEETEDzwKXAcsAG4XkQUDNjsK3Ak8NaC8Bfi8qi4E1gCPiEhGyPp/UNVl7rJzhO9hVNkvqBljokU4VwgrgXJVPaiqHcAzwA2hG6jqYVXdBQQHlL+vqgfcx8eAaiA7IjUfI1PTEsjNSLR+BGPMpBdOIOQCFSHPK92yYRGRlUAc8EFI8YNuU9L3RCT+HPvdJSLbRGRbTU3NcF82IkoKAzZBzRgz6Y1Jp7KITAd+AXxRVXuuIr4BzANWAFOArw+2r6quV9VSVS3Nzvbm4qK4IIPjDW0cq2/15PWNMWYshBMIVUB+yPM8tywsIpIGvAT8b1Xd0lOuqsfV0Q48gdM0NS6VFE4BrB/BGDO5hRMIW4E5IjJTROKA24AXwzm4u/1zwM9VdcOAddPdvwLcCOweTsXH0rzpqSTG+i0QjDGT2pCBoKpdwN3AJmAv8KyqlonI/SKyFkBEVohIJXAL8CMRKX
N3vxW4HLhzkOGlT4rIe8B7QBbwQETfWQTF+n0szU/nHetHMMZMYjHhbKSqG4GNA8ruC3m8FacpaeB+vwR+eY5jXj2smnqsuCDA+jcO0trRTWKc3+vqGGNMxNlM5TCVFAboCiq7Kuu9rooxxowKC4QwLe/5BTVrNjLGTFIWCGGakhzHrOxkdhyxKwRjzORkgTAMxQXOBDVV9boqxhgTcRYIw1BSGOBUcweH61q8rooxxkScBcIw2I3ujDGTmQXCMMzOTiE1Icbua2SMmZQsEIbB5xOWFwTszqfGmEnJAmGYSgoC7D/ZyJm2Tq+rYowxEWWBMEwlhQFUYedRG35qjJlcLBCGaWl+Oj6xjmVjzORjgTBMqQmxXDw11TqWjTGTjgXCCJQUBth5tJ7uoE1QM8ZMHhYII1BSGKCxvYsD1Y1eV8UYYyLGAmEEbIKaMWYyskAYgYIpSWSlxNmN7owxk4oFwgiIuBPUrGPZGDOJWCCMUElhgEO1zdQ1tXtdFWOMiQgLhBHq6UfYYRPUjDGTRFiBICJrRGS/iJSLyL2DrL9cRHaISJeI3Dxg3RdE5IC7fCGkvERE3nOP+X0RkQt/O2NncW46sX6xZiNjzKQxZCCIiB94FLgOWADcLiILBmx2FLgTeGrAvlOAbwKXACuBb4pIwF39GPAlYI67rBnxu/BAQqyfBTPSbaSRMWbSCOcKYSVQrqoHVbUDeAa4IXQDVT2sqruA4IB9Pwb8l6qeUtXTwH8Ba0RkOpCmqlvU+fmxnwM3XuibGWslBQHerains3vg2zbGmIknnEDIBSpCnle6ZeE417657uORHHPcKCkM0N4VZM+xM15XxRhjLti471QWkbtEZJuIbKupqfG6Ov0UF2YAWD+CMWZSCCcQqoD8kOd5blk4zrVvlft4yGOq6npVLVXV0uzs7DBfdmxMT09kRnqC9SMYYyaFcAJhKzBHRGaKSBxwG/BimMffBFwrIgG3M/laYJOqHgfOiMgqd3TR54EXRlB/zxUX2i+oGWMmhyEDQVW7gLtxTu57gWdVtUxE7heRtQAiskJEKoFbgB+JSJm77yngX3BCZStwv1sG8NfAT4By4APgdxF9Z2OkpDDAsYY2jje0el0VY4y5IDHhbKSqG4GNA8ruC3m8lf5NQKHbPQ48Pkj5NmDRcCo7HvVOUDtSz8eXJHpcG2OMGblx36k83s2fnkZCrM/6EYwxE54FwgWK9ftYkpfBdhtpZIyZ4CwQIqCkMEBZVQNtnd1eV8UYY0bMAiECSgoCdAWVXZUNXlfFGGNGzAIhApYX2AQ1Y8zEZ4EQAZkp8czMSraOZWPMhGaBECHFBc4ENedefcYYM/FYIERISWGAuuYOjtS1eF0VY4wZEQuECLEb3RljJjoLhAiZk5NKanyM9SMYYyYsC4QI8fuEZQUZFgjGmAnLAiGCSgoD7D/ZSGNbp9dVMcaYYbNAiKDiggCq8G6FTVAzxkw8FggRtKwgAxGs2cgYMyFZIERQWkIsc6em2o3ujDETkgVChBUXBnjn6GmCQZugZoyZWCwQIqy4IEBjWxflNU1eV8UYY4bFAiHCen5BzfoRjDETjQVChBVlJjElOc4CwRgz4YQVCCKyRkT2i0i5iNw7yPp4EfmVu/5tESlyy+8QkZ0hS1BElrnrXnOP2bMuJ5JvzCsi0nujO2OMmUiGDAQR8QOPAtcBC4DbRWTBgM3+AjitqrOB7wHfAVDVJ1V1maouAz4HHFLVnSH73dGzXlWrI/B+xoXiwgwO1jZzqrnD66oYY0zYwrlCWAmUq+pBVe0AngFuGLDNDcDP3McbgNUiIgO2ud3dd9IrKXD6Ed6x4afGmAkknEDIBSpCnle6ZYNuo6pdQAOQOWCbdcDTA8qecJuL/nmQAJmwluRlEOMT60cwxkwoY9KpLCKXAC2qujuk+A5VXQxc5i6fO8e+d4nINhHZVlNTMwa1vXCJcX4WzkizQDDGTCjhBEIVkB/yPM8tG3QbEYkB0o
G6kPW3MeDqQFWr3L+NwFM4TVNnUdX1qlqqqqXZ2dlhVHd8WF4Q4N3Kejq7g15XxRhjwhJOIGwF5ojITBGJwzm5vzhgmxeBL7iPbwZeVfe3JEXEB9xKSP+BiMSISJb7OBb4BLCbSaSkMEBbZ5B9xxu9rooxxoRlyEBw+wTuBjYBe4FnVbVMRO4XkbXuZj8FMkWkHPg7IHRo6uVAhaoeDCmLBzaJyC5gJ84Vxo8v+N2MI30T1E55XBNjjAlPTDgbqepGYOOAsvtCHrcBt5xj39eAVQPKmoGSYdZ1QpmRkcj09AS2H63nzku9ro0xxgzNZiqPIpugZoyZSCwQRlFxYYCq+lZONLR5XRVjjBmSBcIo6ulH2GET1IwxE4AFwihaMD2N+BifzUcwxkwIFgijKC7Gx9K8DAsEY8yEYIEwypYXZlB2rIG2zm6vq2KMMedlgTDKSgoCdHYru6savK6KMcaclwXCKCu2X1AzxkwQFgijLCslnqLMJAsEY8y4Z4EwBooLAuw4Wo97eydjjBmXLBDGQHFhgNqmdipOtXpdFWOMOScLhDHQe6O7o3ajO2PM+GWBMAYunppKSnyM9SMYY8Y1C4Qx4PcJy/Iz2H6k3uuqGGPMOVkgjJHiwgD7T5yhqb3L66oYY8ygLBDGSElhgKDCuxV2lWCMGZ8sEMbIsvwMROAPZSds+KkxZlyyQBgj6YmxXLtgKj976wjr1m9h/wn7rWVjzPhigTCGHrujhH/91GLeP9nI9d9/kwdf2mN9CsaYcSOsQBCRNSKyX0TKReTeQdbHi8iv3PVvi0iRW14kIq0istNdfhiyT4mIvOfu830RkUi9qfHK5xNuX1nAq39/JbeU5PHjNw9xzf95nf/cdcyakYwxnhsyEETEDzwKXAcsAG4XkQUDNvsL4LSqzga+B3wnZN0HqrrMXf4qpPwx4EvAHHdZM/K3MYTKbXDsHegeH9/GpyTH8dCnl/Cbv/4wmSlx3P3UO3z+8T9zsKbJ66oZY6JYOFcIK4FyVT2oqh3AM8ANA7a5AfiZ+3gDsPp83/hFZDqQpqpb1Plq/HPgxmHXPlyvfBvWXwkPFcDPPgmvPgjlL0Obt7ekLi4I8OLdH+Hbaxeys6KeNY+8yXc37ae1w347wRgz9mLC2CYXqAh5Xglccq5tVLVLRBqATHfdTBF5BzgD/JOqvuluXzngmLnDr36YbnwMjm6Biredv29+FzQICExdCPmXOEvBJZBRCGPYeuX3CV/4cBHXL57Ov27cyw82l/PcO1V8a+1CPrpg6pjVwxhjwgmEC3EcKFDVOhEpAZ4XkYXDOYCI3AXcBVBQUDCyWqTnweKbnQWgvdFpRqr4M1RsgV3PwrafOutSpjnBkL/K+TttCfhjR/a6w5CdGs/D65Zx64p87nthN1/6+TZWz8vhW2sXkj8ladRf3xhjwgmEKiA/5HmeWzbYNpUiEgOkA3Vuc1A7gKpuF5EPgIvd7fOGOCbufuuB9QClpaWR6XmNT4WLrnIWgGA3VO8JuYp4G/a84KyLTYLcEvcKYhXklUJiICLVGMyqWZm8dM9lPPGnQzzy8gGuefh17r5qNnddMYv4GP+ova4xxshQo1vcE/z7wGqck/ZW4DOqWhayzVeAxar6VyJyG/ApVb1VRLKBU6raLSKzgDfd7U6JyJ+Be4C3gY3A/1XVjeerS2lpqW7btm3Eb3ZYzhzrC4iKt+H4LtBuQCB7Xv+riMDMUWlmOt7QygP/uZeX3jvOzKxkvr12IZdfnB3x1zHGTG4isl1VS4fcLpzhjiJyPfAI4AceV9UHReR+YJuqvigiCcAvgOXAKeA2VT0oIp8G7gc6gSDwTVX9rXvMUuA/gETgd8Df6BCVGdNAGKijGaq2O1cPFVugYiu0u53SyTkhAbHKaWaKiYvYS7/xfg3ffLGMQ7XNXL94Gv/8iQVMT0+M2PGNMZNbRANhvPA0EAYKBqFmb//O6vojzrqYBJhR3BcS+SshacoFvVx7Vz
frXz/IDzaX4/cJf7t6Dv/jIzOJ9dvcQjOJBINwYhcc/iNkXgSzPwr+0e7qnPwsELzQeKKvD6JiCxx/F4Lu3IesuU5AZF3sXFGkuEtyjhMWvvD6BypOtfDt35bx8t5q5uSk8C83LmLVrMyhdzRmvGquhQ9edYaCf/AqNNf0rUvOgaXrYNlnIWeed3Wc4CwQxoOOFji2o39fxGBzH8QHSVmQMhVSst3AyHae9zzuCZGkTPD5eXnPSb712zIqT7dy0/JcvnH9PHJSE8b+PRozXN1dULnVDYBX4NhOQJ3/ty9aDbOvgZmXOeU7n4T3f+98scotgWV3wKJPQ2KG1+9iQrFAGI9Uof0MNFU7S3M1NNVA08m+x83Vfeu7288+Rm945NCdlMW+pkTeOuGn3pfBykXzuHTpfPxp0/qFhzGea6iE8lecEDj4utP/Jn6nOXW2GwLTloJvkCbQphrY9SsnHKr3OE2y8z4Byz8LM68YfB/TjwXCRNcbHucJDLcs2FSNr7vt7GOIzwmFfk1U2c7ij3WG22rQXbqd1+x5Hva6AeuHs87ndzrgCz7knBjsW9/k0dkGR/+7LwRq9jnlabl9ATDziuH9N1d1bkGz80l47/85V9vp+bD0dlh2O0yZNTrvZRKwQIgmqmhbA6/v2MPTm7fja6nhY4XCx4p8JLbXOW2yTSf7AqVrkPAYSHzONzjxOYuv57FEbl1XK1TvdftZBHIWOKO0epb0/DGdNW4ugCrUfeCc/MtfdjqFu1rBHw+FH3YCYPY1kD03Mv9NO9tg/0vwzpNOvwMKhR+B5XfAghsgLvnCX2MSsUCIUk3tXXz/lQM8/sdDpCTE8PU181hXmo/P5/4jVIWOJucb+vlO3mMldDjv0bec2eMd7m9FpOW64fAhZ2Lg1IXWBDaetDfCoTfcEHilb5Rd5uy+ACi8FOJGeaZ9QxW8+7Rz5XDqIMSlwMIbnY7oglX2pQILhKj3/slG/un53fz50CmW5WfwwI2LWJSb7nW1hhbshpNlbkf8FjjyFjQec9bFp0HeCicgClY5nYyjfbIxfVThxHtOR3D5K06AB7ucE/DMK9ymoNUQKPKufke3wM5fQtnzzhefKRfBss84zUrpo3e7tPHOAsGgqjy/s4oHX9rHqeZ2PruqkL+/di7piaN/b6aIUYWGCucf+tG3nL/VewEFXwxMX9oXEPmrnBFZJnJaTrlDQl9xgqDppFM+dXFfX0D+JRGdiBkRHc3O7WfeeRKO/NG58p11ldOkNPfjEBtdI/IsEEyvhtZOHv7Dfn6x5QhTkuP4xnXzuWl5bl8z0kTTetqZKd4TEFXb+0ZkZc7uC4eCDzmTm6zJIHzdXc5Q6Z6+gKodgDr377roaicALroaUqd5XdPwnToIO592mpUaKiAhHRbf4gxhnbE8Kv7/sEAwZ9ld1cA/Pb+bnRX1ZCbHccXcbK6am8PlF2dPrKuGgbranTHrFVv6riRaTzvrkrL6+iFG4bYiE4aq85k0HneWM8ediZQ9zxvd500n+/qXckvdvoDVzolzovffBINw6HWnr2Hvb53BFTkLnGBYsm78Xl32jDiMSxnxfwMLBDOoYFD5fdkJ/lB2gtfer6G+pRO/TygpDHD1vByunpfDnJwUJvQvmgaDUHeg7wri6BY4fchZF5Po3LG2ZyRT3kpISPO2vheqvck9uR8LOcmfcG7QGPp8sHktiVMgbYbzjT91GqROd06Ss6684NutjGut9VD2G6dJqWqb0/w452NOk9Kca0fnlvddHdBW77x26+m+pW3A87PWNzjDt/9mh3PFOwLnAloeAAAMg0lEQVQWCGZI3UFlZ8VpXt1Xzav7ath7/AwAuRmJXDUvm6vn5fChWVkkxk3wb4bgnBCPhlxBnHjP+UcmPqfjMT4VYhOdJSbBue15z/PYRCdIYhPPUZbktEnHJoXsm+Csv5BJU13tzjf2MyHf4M866R/vG5UVKjYZ0qY7J/jU6c7JvvfE75alTI26tv
RBVe9zrhp2/cr5vJOznSuGZXfA1AG/FtwzSm/gifucJ/X6vvUd5/uJXHGashIznOa5niUh5PnS2yF5ZLepsUAww3a8oZXX9tfw6r5q/lReS0tHN/ExPj58USZXz8vhqnk55AUmyaie9ibnm+HRLc6ops5WZ+ly/3a2OGPde8q6O0b2OjEJ/UOiJ2gGlsUkOCN2QptxWurOPp4v1jmZp00POblPg9SQk33adCfgzPB0dzn9Jjt/Cft/D8FOp/M8NrH/ST94nt9m98f1P6EPPKn3nvAz+q9PSB/VJjkLBHNB2ru6efvgKV7dV83m/dUcqWsB4OKpKVw11wmHksJA9NxttbvLDYs2Jyy62kJCo8UNjsHKWvvCpt82rWeHkPgGP7mHnvQTA3arhrHQXAfvPQv7XnJO1AkDvrmf65t8bOK47KS2QDARdbCmqTcc/nzoFJ3dSmpCDJdfnM3Vc3O4Ym42WSnxXlfTGDMICwQzahrbOvlTea0bEDXUNLYjAkvyMrh6rtMxvXBG2sQd1mrMJGOBYMZEMKjsOX7G7Ziu5t3KelQhOzWeKy92OqY/MieL1IQJPKzVmAnOAsF4oq6pndffdzqm33i/hjNtXcT6hRVFU3o7pmdlJU/sYa3GTDAWCMZzXd1Bth85zav7q9m8r5r3TzrD7gqmJHH1vByuuDib2TkpTEtPiJ7OaWM8ENFAEJE1wL8BfuAnqvrQgPXxwM+BEqAOWKeqh0Xko8BDQBzQAfyDqr7q7vMaMB1odQ9zrapWn68eFggTW+XpFjbvr2GzO6y1vSsIgE9galoCuRmJ5AYS+/3NCySSm5E0OeZCGOORiAWCiPiB94GPApXAVuB2Vd0Tss1fA0tU9a9E5DbgJlVdJyLLgZOqekxEFgGbVDXX3ec14GuqGvYZ3gJh8mjr7GbH0dNUnmqlsr6VqtOtVNW3UFXfyvH6NrqC/f+/zEyO6wuLAcGRl5FEWmKMNUMZcw7hBkJMGMdaCZSr6kH3wM8ANwB7Qra5AfiW+3gD8AMREVV9J2SbMiBRROJVdZA59CaaJMT6+fBFWTDITPzuoHLyTBtVvUHRSqX79/2TjWzeX01bZ7DfPinxMX1XFIGzQyM7Jd4Cw5ghhBMIuUBFyPNK4JJzbaOqXSLSAGQCtSHbfBrYMSAMnhCRbuDXwAM6kTo0zKjx+4QZGYnMyEhkRdHZ61WVuuaO3rAYGBpbD5/iTFv/2aRxMb6+q4ueK4ueZqkpSUxPS7BhsibqhRMIF0xEFgLfAa4NKb5DVatEJBUnED6H0w8xcN+7gLsACgoKxqC2ZrwTEbJS4slKiWdp/uC/yXumrdMJip7QcIOjsr6VV/ZVU9vU/yI1LsZH4ZQkCjOTmZmVRFFWMjMzkynMSrawMFEjnECoAvJDnue5ZYNtUykiMUA6TucyIpIHPAd8XlU/6NlBVavcv40i8hRO09RZgaCq64H14PQhhPe2TLRLS4glbXos86cPfifTts5ujrlXFRWnWzhS18Kh2maO1DXzxoEaOrr6mqTiY3wUZvaERTJFmckUZTqhMc3Cwkwi4QTCVmCOiMzEOfHfBnxmwDYvAl8A3gJuBl5VVRWRDOAl4F5V/VPPxm5oZKhqrYjEAp8AXr7gd2NMmBJi/czKTmFWdspZ64JB5fiZNo7UNnOorpnDtc0crmvhcG0zr78/eFgUuWFRmJlMUVYSM7OSmZpqYWEmliEDwe0TuBvYhDPs9HFVLROR+4Ftqvoi8FPgFyJSDpzCCQ2Au4HZwH0icp9bdi3QDGxyw8CPEwY/juD7MmbEfD7p7Wv48Oysfuu6g8qJM20crm3mUG1fWBysbea1/TV0dPeFRUKsj8IpTkAUZSZT1HN1kZVkYWHGJZuYZkyEdAeV4w2tHK5t4VBdM0dqmzlc5wRHxanWs8LCaXpKpjAriZmZyeQFkshKjSMrJZ5AUhx+CwwTIZEcdmqMCYPfJ+QFksgLJPGROWdfWRyrb+VwXV/z0+HaZg
5UN/LKvpN0dvf/YuYTmJIcT1ZKnNuB7v5NjSczOY6s1Hiy3Y71zJQ4m+ltIsICwZgx4PcJ+VOSyJ+SxGVz+q/rCYuq+lZqm9qpbWyntqmDuuZ2aho7qG1q53BdM7VN7WfNv+iRnhjbLzSyU/qCo1+gpMTbrG9zThYIxngsNCyG0tze5YRGkxMWdc3t1LqhUdvUTl1TB3uPneGNpnYa2wb/Za/kOH+/oMh0gyLbDY3cQCIXZaeQHG+nh2hj/8WNmUCS42NIjo+hMDN5yG3bOrupa+6gzg2L2sYOatzQ6AmQQ7XNbDt8mlMtHQzsTpyRnsDsqanMzk5hdo6zzMlJIZAcN0rvznjNAsGYSSoh1t87WmooXd1BTrV0UNvYwdFTLZRXN1Je3UR5TRNPHzpFa2d377aZyXFc5AbE7OwU5kx1Hk9LS7Dbg0xwFgjGGGL8PnJSE8hJTWDBjDRgWu+6YFA51tDKgeomPqhucoKiuomXdh2nobWzd7uU+Bguyk5mdk5q7xXF7JwUCqYk2YipCcICwRhzXr6Q0VNXzc3pLVdVaps6eq8kPqhu4kB1I38sr+HXOyp7t4uL8TErK9m5qgi5opiZlUx8jHVwjycWCMaYERERslPjyU6N50MXZfZbd6at0w2IvquK3VUNbHzveG9fhU+cH0tyriT6X1WkeNyhraqoQlCVblWCQegKBgkGoVu193FoWXcwSPc5yrqD6ixDlAWDSlew5zWdxz1/b1uRP+r9NxYIxpiIS0uIZXlBgOUFgX7lbZ3dHKxpprzGCYmeq4rX36/pNxdjWloCyfF+FMA9MSvuX8VdnDINWR9a3ret+xfncVBBCTkOblnIMcajjy7IsUAwxkweCbF+FsxIc/sp+nR1B93O7J7mp2baOrsRca5EBOeKouexiDjrAF/P437b9q3v21bcY5yjzH3MgGP4feIsIvh8Qoyv7+/AMr9I7/aDlYUeq+85+H0+p8zfc0yIGVAWHzP6kw8tEIwxnovx+3pvNnjt0JubUWLz3Y0xxgAWCMYYY1wWCMYYYwALBGOMMS4LBGOMMYAFgjHGGJcFgjHGGMACwRhjjGtC/aayiNQAR0a4exZQG8HqTHT2efSxz6I/+zz6mwyfR6GqZg+10YQKhAshItvC+ZHpaGGfRx/7LPqzz6O/aPo8rMnIGGMMYIFgjDHGFU2BsN7rCowz9nn0sc+iP/s8+ouazyNq+hCMMcacXzRdIRhjjDmPqAgEEVkjIvtFpFxE7vW6Pl4RkXwR2Swie0SkTET+1us6jQci4heRd0TkP72ui9dEJENENojIPhHZKyIf8rpOXhGR/+n+O9ktIk+LSILXdRptkz4QRMQPPApcBywAbheRBd7WyjNdwN+r6gJgFfCVKP4sQv0tsNfrSowT/wb8XlXnAUuJ0s9FRHKBe4BSVV0E+IHbvK3V6Jv0gQCsBMpV9aCqdgDPADd4XCdPqOpxVd3hPm7E+cee622tvCUiecDHgZ94XReviUg6cDnwUwBV7VDVem9r5akYIFFEYoAk4JjH9Rl10RAIuUBFyPNKovwkCCAiRcBy4G1va+K5R4D/BQS9rsg4MBOoAZ5wm9B+IiLJXlfKC6paBXwXOAocBxpU9Q/e1mr0RUMgmAFEJAX4NfBVVT3jdX28IiKfAKpVdbvXdRknYoBi4DFVXQ40A1HZ5yYiAZyWhJnADCBZRD7rba1GXzQEQhWQH/I8zy2LSiISixMGT6rqb7yuj8cuBdaKyGGcpsSrReSX3lbJU5VApar2XDVuwAmIaHQNcEhVa1S1E/gN8GGP6zTqoiEQtgJzRGSmiMThdAy96HGdPCEigtM+vFdVH/a6Pl5T1W+oap6qFuH8f/Gqqk76b4HnoqongAoRmesWrQb2eFglLx0FVolIkvvvZjVR0MEe43UFRpuqdonI3cAmnJECj6tqmcfV8sqlwOeA90Rkp1v2j6q60cM6mfHlb4An3S9PB4EvelwfT6jq2yKyAdiBMz
rvHaJgxrLNVDbGGANER5ORMcaYMFggGGOMASwQjDHGuCwQjDHGABYIxhhjXBYIxhhjAAsEY4wxLgsEY4wxAPx/0K+BgTFf7N8AAAAASUVORK5CYII=\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(tr_loss_hist, label = 'train')\n", "plt.plot(val_loss_hist, label = 'validation')\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "test acc: 99.19%\n" ] } ], "source": [ "yhat = np.argmax(cnn.predict(x_tst), axis=-1)\n", "print('test acc: {:.2%}'.format(np.mean(yhat == y_tst)))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }