{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 07 : ConvNet in TensorFlow\n",
    "Same content as [Lec07_ConvNet mnist with Weight initialization and Batch norm.ipynb](https://nbviewer.jupyter.org/github/aisolab/CS20/blob/master/Lec07_ConvNet%20in%20Tensorflow/Lec07_ConvNet%20mnist%20with%20Weight%20initialization%20and%20Batch%20norm.ipynb), written in a different style.\n",
    "\n",
    "### ConvNet mnist with Weight initialization and Batch norm\n",
    "- Creating the **data pipeline** with `tf.data`\n",
    "- Using `tf.keras`, alias `keras` and `eager execution`\n",
    "- Creating the model as **Class** by subclassing `tf.keras.Model`\n",
    "- Initializing the weights of the model with **He initialization** via `tf.keras.initializers.he_uniform`\n",
    "- Training the model with **Batch Normalization** technique by `tf.keras.layers.BatchNormalization`\n",
    "- Using tensorboard"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.12.0\n"
     ]
    }
   ],
   "source": [
    "import os, sys\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "%matplotlib inline\n",
    "\n",
    "print(tf.__version__)\n",
    "tf.enable_eager_execution()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load and Pre-process data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the MNIST train/test arrays through the Keras dataset helper.\n",
    "(x_train, y_train), (x_tst, y_tst) = tf.keras.datasets.mnist.load_data()\n",
    "\n",
    "# Divide pixel values by 255, add a trailing channel axis so every image is\n",
    "# 28x28x1, and cast to float32 before feeding the model.\n",
    "x_train = (x_train / 255).reshape(-1, 28, 28, 1).astype(np.float32)\n",
    "x_tst = (x_tst / 255).reshape(-1, 28, 28, 1).astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(55000, 28, 28, 1) (55000,)\n",
      "(5000, 28, 28, 1) (5000,)\n"
     ]
    }
   ],
   "source": [
    "# Hold out part of the official training set for validation.\n",
    "# A dedicated, seeded RandomState makes the split identical on every run\n",
    "# without touching NumPy's global random state.\n",
    "split_seed = 0\n",
    "n_train = 55000\n",
    "split_rng = np.random.RandomState(split_seed)\n",
    "tr_indices = split_rng.choice(x_train.shape[0], size = n_train, replace = False)\n",
    "\n",
    "x_tr = x_train[tr_indices]\n",
    "y_tr = y_train[tr_indices].astype(np.int32)\n",
    "\n",
    "# Every sample whose index was not drawn above goes to the validation set.\n",
    "x_val = np.delete(arr = x_train, obj = tr_indices, axis = 0)\n",
    "y_val = np.delete(arr = y_train, obj = tr_indices, axis = 0).astype(np.int32)\n",
    "\n",
    "print(x_tr.shape, y_tr.shape)\n",
    "print(x_val.shape, y_val.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the SimpleCNN class with the high-level API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleCNN(keras.Model):\n",
    "    \"\"\"Small ConvNet classifier: two conv / batch-norm / max-pool stages, then two dense layers.\n",
    "\n",
    "    All kernels are He-uniform initialized. Biases start at zero, because He\n",
    "    initialization scales by fan-in and is only meaningful for weight kernels.\n",
    "\n",
    "    Args:\n",
    "        num_classes: number of units in the output layer (one score per class).\n",
    "    \"\"\"\n",
    "    def __init__(self, num_classes):\n",
    "        super(SimpleCNN, self).__init__()\n",
    "\n",
    "        def init_kwargs():\n",
    "            # build a fresh kernel initializer object for every layer\n",
    "            return dict(kernel_initializer=keras.initializers.he_uniform(),\n",
    "                        bias_initializer='zeros')\n",
    "\n",
    "        self.__conv1 = keras.layers.Conv2D(filters=32, kernel_size=[5,5], padding='same',\n",
    "                                           activation=tf.nn.relu, **init_kwargs())\n",
    "        self.__conv2 = keras.layers.Conv2D(filters=64, kernel_size=[5,5], padding='same',\n",
    "                                           activation=tf.nn.relu, **init_kwargs())\n",
    "        # the same pooling and flatten layer objects are reused wherever needed in call()\n",
    "        self.__pool = keras.layers.MaxPooling2D()\n",
    "        self.__flatten = keras.layers.Flatten()\n",
    "        # one batch-norm layer per normalized activation (conv1, conv2, dense1)\n",
    "        self.__bnc1 = keras.layers.BatchNormalization(momentum=.9)\n",
    "        self.__bnc2 = keras.layers.BatchNormalization(momentum=.9)\n",
    "        self.__bnd = keras.layers.BatchNormalization(momentum=.9)\n",
    "        self.__dense1 = keras.layers.Dense(units=1024, activation=tf.nn.relu, **init_kwargs())\n",
    "        # output layer has no activation: it returns raw scores (logits)\n",
    "        self.__dense2 = keras.layers.Dense(units=num_classes, **init_kwargs())\n",
    "\n",
    "    def call(self, inputs, training=False):\n",
    "        \"\"\"Compute class scores for a batch of images.\n",
    "\n",
    "        Args:\n",
    "            inputs: batch of images fed to the first convolution.\n",
    "            training: forwarded to every batch-norm layer to select its\n",
    "                training or inference behaviour.\n",
    "\n",
    "        Returns:\n",
    "            The output of the final dense layer (no activation applied).\n",
    "        \"\"\"\n",
    "        # batch norm is applied to the ReLU output of each conv / dense layer\n",
    "        conv1 = self.__bnc1(self.__conv1(inputs), training=training)\n",
    "        pool1 = self.__pool(conv1)\n",
    "        conv2 = self.__bnc2(self.__conv2(pool1), training=training)\n",
    "        pool2 = self.__pool(conv2)\n",
    "        flattened = self.__flatten(pool2)\n",
    "        fc = self.__bnd(self.__dense1(flattened), training=training)\n",
    "        score = self.__dense2(fc)\n",
    "        return score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparing for training a model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "550\n"
     ]
    }
   ],
   "source": [
    "# hyper-parameters\n",
    "batch_size = 100  # samples per mini-batch\n",
    "epochs = 5        # passes over the training split\n",
    "lr = .001         # learning rate handed to the optimizer\n",
    "\n",
    "# number of full mini-batches in one pass over the training split\n",
    "total_step = x_tr.shape[0] // batch_size\n",
    "print(total_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?, 28, 28, 1), (?,)), types: (tf.float32, tf.int32)> <BatchDataset shapes: ((?, 28, 28, 1), (?,)), types: (tf.float32, tf.int32)>\n"
     ]
    }
   ],
   "source": [
    "## create input pipeline with tf.data\n",
    "# training split: shuffled through a 10000-element buffer, then batched\n",
    "tr_dataset = (tf.data.Dataset.from_tensor_slices((x_tr, y_tr))\n",
    "              .shuffle(buffer_size = 10000)\n",
    "              .batch(batch_size = batch_size))\n",
    "\n",
    "# validation split: batched in stored order, no shuffling\n",
    "val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size = batch_size)\n",
    "print(tr_dataset, val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "## create model\n",
    "cnn = SimpleCNN(num_classes=10)\n",
    "\n",
    "## create optimizer (Adam, using the learning rate from the hyper-parameter cell)\n",
    "opt = tf.train.AdamOptimizer(learning_rate = lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(model, x, y, training):\n",
    "    \"\"\"Return the sparse softmax cross-entropy of `model` on one batch.\n",
    "\n",
    "    Args:\n",
    "        model: callable invoked as `model(x, training=training)`; its output is used as logits.\n",
    "        x: batch of inputs forwarded to `model`.\n",
    "        y: class labels for `x`, passed as `labels` to the loss.\n",
    "        training: forwarded to `model` as its `training` keyword.\n",
    "    \"\"\"\n",
    "    logits = model(x, training=training)\n",
    "    return tf.losses.sparse_softmax_cross_entropy(labels=y, logits=logits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training a model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# checkpoints and tensorboard event files share one output directory\n",
    "checkpoint_dir = '../graphs/lecture07/convnet_mnist_batch_norm_kde/'\n",
    "\n",
    "# writer for tensorboard summaries\n",
    "summary_writer = tf.contrib.summary.create_file_writer(logdir=checkpoint_dir)\n",
    "\n",
    "# object-based checkpoint that tracks the model; files are saved under the 'cnn' prefix\n",
    "checkpoint = tf.train.Checkpoint(cnn=cnn)\n",
    "checkpoint_prefix = os.path.join(checkpoint_dir, 'cnn')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   1, tr_loss : 0.100, val_loss : 0.047\n",
      "epoch :   2, tr_loss : 0.034, val_loss : 0.045\n",
      "epoch :   3, tr_loss : 0.021, val_loss : 0.035\n",
      "epoch :   4, tr_loss : 0.017, val_loss : 0.053\n",
      "epoch :   5, tr_loss : 0.011, val_loss : 0.051\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'../graphs/lecture07/convnet_mnist_batch_norm_kde/cnn-1'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tr_loss_hist = []\n",
    "val_loss_hist = []\n",
    "\n",
    "# Step counter shared by the optimizer and the summary ops: apply_gradients\n",
    "# increments it and tf.contrib.summary.scalar uses it as the x-axis value.\n",
    "# Without it every summary would be written at step 0.\n",
    "global_step = tf.train.get_or_create_global_step()\n",
    "\n",
    "for epoch in range(epochs):\n",
    "\n",
    "    avg_tr_loss = 0\n",
    "    avg_val_loss = 0\n",
    "    tr_step = 0\n",
    "    val_step = 0\n",
    "\n",
    "    with summary_writer.as_default(), tf.contrib.summary.always_record_summaries(): # for tensorboard\n",
    "        # for training\n",
    "        for x_mb, y_mb in tr_dataset:\n",
    "            # only the forward pass needs to be recorded on the tape\n",
    "            with tf.GradientTape() as tape:\n",
    "                tr_loss = loss_fn(cnn, x_mb, y_mb, training=True)\n",
    "            # differentiate w.r.t. trainable variables only: the batch-norm moving\n",
    "            # statistics are not trainable and would only yield None gradients\n",
    "            train_vars = cnn.trainable_variables\n",
    "            grads = tape.gradient(target=tr_loss, sources=train_vars)\n",
    "            opt.apply_gradients(grads_and_vars=zip(grads, train_vars), global_step=global_step)\n",
    "            tf.contrib.summary.scalar(name='tr_loss', tensor=tr_loss)\n",
    "            avg_tr_loss += tr_loss\n",
    "            tr_step += 1\n",
    "        avg_tr_loss /= tr_step\n",
    "        tr_loss_hist.append(avg_tr_loss)\n",
    "\n",
    "        # for validation\n",
    "        for x_mb, y_mb in val_dataset:\n",
    "            val_loss = loss_fn(cnn, x_mb, y_mb, training=False)\n",
    "            avg_val_loss += val_loss\n",
    "            val_step += 1\n",
    "        avg_val_loss /= val_step\n",
    "        val_loss_hist.append(avg_val_loss)\n",
    "        # the global step does not move during validation, so log a single\n",
    "        # epoch-average point instead of many mini-batch points at the same step\n",
    "        tf.contrib.summary.scalar(name='val_loss', tensor=avg_val_loss)\n",
    "\n",
    "    print('epoch : {:3}, tr_loss : {:.3f}, val_loss : {:.3f}'.format(epoch + 1, avg_tr_loss, avg_val_loss))\n",
    "\n",
    "checkpoint.save(file_prefix=checkpoint_prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x7f187422f128>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl8VdW58PHfk5mMZAICCSRAEsI8JEhFBGQQsaJWJq292l611Vprte2lfe+tw23f2r6KWqdqq61ttUCxttShDALiCAkoyGBCGBOGkATIQOZkvX/sQxJCQkI4OfsMz/fz2R/P2Wedc55sOc9ae6+11xJjDEoppXyDn90BKKWUch1N+kop5UM06SullA/RpK+UUj5Ek75SSvkQTfpKKeVDNOkrpZQP0aSvlFI+RJO+Ukr5kAC7A2grLi7OJCcn2x2GUkp5lK1bt5YYY+I7K+d2ST85OZmcnBy7w1BKKY8iIoe6Uk4v7yillA/RpK+UUj5Ek75SSvkQt7umr5TyLvX19RQWFlJTU2N3KF4hJCSExMREAgMDu/X+LiV9EZkDPA34A783xjzW5vUrgaeA0cBiY8zKVq/dBvy34+nPjTGvditSpZRHKiwsJCIiguTkZETE7nA8mjGG0tJSCgsLSUlJ6dZndHp5R0T8geeAa4DhwM0iMrxNscPA7cDrbd4bAzwEXAZMBB4SkehuRaqU8kg1NTXExsZqwncCESE2NvaSzpq6ck1/IpBvjNlvjKkDlgHXty5gjDlojNkBNLV579XAWmPMSWPMKWAtMKfb0SqlPJImfOe51GPZlaQ/ACho9bzQsa8rLuW9F+VERQ2P/ms3ZVX1PfHxSinlFdxi9I6I3CUiOSKSU1xc3K3PKK2s45WPDvC7D/Y7OTqllCc7ffo0zz///EW/b+7cuZw+fboHIrJXV5L+ESCp1fNEx76u6NJ7jTEvGWMyjTGZ8fGd3kXcroyESL46OoFXPjpASWVttz5DKeV9Okr6DQ0NF3zfO++8Q+/evXsqLNt0JelnA6kikiIiQcBiYFUXP381MFtEoh0duLMd+3rED2alUVPfyG837uupr1BKeZglS5awb98+xo4dS1ZWFlOmTGHevHkMH26NR7nhhhuYMGECI0aM4KWXXmp+X3JyMiUlJRw8eJCMjAzuvPNORowYwezZs6murrbrz7lknQ7ZNMY0iMi9WMnaH3jFGLNLRB4Fcowxq0QkC3gTiAauE5FHjDEjjDEnReR/sSoOgEeNMSd76G9hSHw4XxufyJ8+PcQdUwbTLyqkp75KKdUNj/xrF7uPljv1M4f3j+Sh60Z0+Ppjjz3Gzp07+fzzz9m4cSPXXnstO3fubB7y+MorrxATE0N1dTVZWVncdNNNxMbGnvMZe/fu5a9//Su/+93vWLhwIW+88Qa33nqrU/8OV+nSNX1jzDvGmDRjzBBjzC8c+35mjFnleJxtjEk0xoQZY2KNMSNavfcVY8xQx/aHnvkzWnx/RirGGJ5Zv7env0op5YEmTpx4zhj33/zmN4wZM4ZJkyZRUFDA3r3n546UlBTGjh0LwIQJEzh48KCrwnU6r7sjNykmlEVZSSzbUsB3pg4hKSbU7pCUUg4XapG7SlhYWPPjjRs3sm7dOj755BNCQ0OZNm1au2Pgg4ODmx/7+/t79OUdtxi942zfuyoVfz/hqXXa2lfK10VERFBRUdHua2VlZURHRxMaGsqXX37Jp59+6uLoXM8rk37fyBC+MWkQb35WSP6JSrvDUUrZKDY2lsmTJzNy5Eh+9KMfnfPanDlzaGhoICMjgyVLljBp0iSbonQdMcbYHcM5MjMzjTMWUSmtrGXKrzcwfVgfnrtlvBMiU0p1x549e8jIyLA7DK/S3jEVka3GmMzO3uuVLX2A2PBgvjU5hbd3HHP6aAGllPJUXpv0Ae68cjCRIQEsXZtrdyhKKeUWvDrpR/UK5NtT
h7Buzwk+O3zK7nCUUsp2Xp30AW6/PJnYsCCeWJNndyhKKWU7r0/6YcEB3D1tCB/ml/DJvlK7w1FKKVt5fdIHuHXSIPpGBvPEmlzcbbSSUkq5kk8k/ZBAf753VSo5h07xfl73pm5WSvmG8PBwAI4ePcr8+fPbLTNt2jQ6G1r+1FNPUVVV1fzcXaZq9omkD7AwM4nE6F48sSZPW/tKqU7179+flStXdl6wA22TvrtM1ewzST8owI/7Z6bxxZEyVu8qsjscpZSLLFmyhOeee675+cMPP8zPf/5zZsyYwfjx4xk1ahT//Oc/z3vfwYMHGTlyJADV1dUsXryYjIwMbrzxxnPm3rn77rvJzMxkxIgRPPTQQ4A1idvRo0eZPn0606dPB1qmagZYunQpI0eOZOTIkTz11FPN3+eKKZy9bsK1C7lhbH+e35jP0rW5zBreF38/XbdTKZd6dwkc/8K5n9lvFFzzWIcvL1q0iPvvv5/vfve7AKxYsYLVq1dz3333ERkZSUlJCZMmTWLevHkdrj/7wgsvEBoayp49e9ixYwfjx7fc5f+LX/yCmJgYGhsbmTFjBjt27OC+++5j6dKlbNiwgbi4uHM+a+vWrfzhD39g8+bNGGO47LLLmDp1KtHR0S6ZwtlnWvoAAf5+PDArjbyiSt7acdTucJRSLjBu3DhOnDjB0aNH2b59O9HR0fTr14+f/vSnjB49mpkzZ3LkyBGKijq+ArBp06bm5Dt69GhGjx7d/NqKFSsYP34848aNY9euXezevfuC8Xz44YfceOONhIWFER4ezte+9jU++OADwDVTOPtUSx9g7sgEMhL28eTaPOaOSiDQ36fqPaXsdYEWeU9asGABK1eu5Pjx4yxatIjXXnuN4uJitm7dSmBgIMnJye1OqdyZAwcO8Pjjj5OdnU10dDS33357tz7nLFdM4exzGc/PT3hwVhoHS6v4+7ZCu8NRSrnAokWLWLZsGStXrmTBggWUlZXRp08fAgMD2bBhA4cOHbrg+6+88kpef/11AHbu3MmOHTsAKC8vJywsjKioKIqKinj33Xeb39PRlM5TpkzhH//4B1VVVZw5c4Y333yTKVOmOPGvvTCfa+kDzMjow5ik3vzmvXxuGDeA4AB/u0NSSvWgESNGUFFRwYABA0hISODrX/861113HaNGjSIzM5Nhw4Zd8P1333033/zmN8nIyCAjI4MJEyYAMGbMGMaNG8ewYcNISkpi8uTJze+56667mDNnDv3792fDhg3N+8ePH8/tt9/OxIkTAbjjjjsYN26cy1bj8tqplTvz4d4Sbn15M4/MG8Ftlyf3+Pcp5at0amXn06mVu2Hy0FguS4nh2Q35VNc12h2OUkq5hM8mfRHhh1enU1xRy58+OWh3OEop5RI+m/QBspJjmJoWzwvv76Oipt7ucJTyWu52GdmTXeqx9OmkD/DD2emcrqrnlQ8P2h2KUl4pJCSE0tJSTfxOYIyhtLSUkJCQbn+GT47eaW1UYhRXj+jL7z/Yz22XD6J3aJDdISnlVRITEyksLKS4WCc7dIaQkBASExO7/X6fT/oAD8xKZ83uTby4aT//NefCQ7eUUhcnMDCQlJQUu8NQDj5/eQcgvV8E88b0548fHeRERffvplNKKXenSd/h/plp1DU28cLGfXaHopRSPUaTvkNKXBjzxyfy2qeHOXra+fNdKKWUO9Ck38p9M1MBeGZ9vs2RKKVUz9Ck38qA3r24eWISf8sp4FDpGbvDUUopp9Ok38Z3pw8lwF94et1eu0NRSimn06TfRp/IEG77SjJvfn6EvUXnT4uqlFKeTJN+O74zdQhhQQE8uS7P7lCUUsqpNOm3IzosiG9dkcI7Xxxn55Eyu8NRSimn0aTfgTumpBDVK5Cla7W1r5TyHpr0OxAZEsi3pw5m/Zcn2HrolN3hKKWUU3Qp6YvIHBHJFZF8EVnSzuvBIrLc8fpmEUl27A8UkVdF5AsR2SMiP3Fu+D3r9suTiQsP4ok1
uXaHopRSTtFp0hcRf+A54BpgOHCziAxvU+w/gVPGmKHAk8CvHPsXAMHGmFHABODbZysETxAaFMA904by8b5SPs4vsTscpZS6ZF1p6U8E8o0x+40xdcAy4Po2Za4HXnU8XgnMEBEBDBAmIgFAL6AOKHdK5C5yy2UDSYgK4fE1uTofuFLK43Ul6Q8AClo9L3Tsa7eMMaYBKANisSqAM8Ax4DDwuDHm5CXG7FIhgf5876pUth0+zYbcE3aHo5RSl6SnO3InAo1AfyAFeFBEBrctJCJ3iUiOiOS440ILCzITGRgTyhNr8mhq0ta+UspzdSXpHwGSWj1PdOxrt4zjUk4UUArcAvzbGFNvjDkBfARktv0CY8xLxphMY0xmfHz8xf8VPSzQ34/7Z6ay62g5/9513O5wlFKq27qS9LOBVBFJEZEgYDGwqk2ZVcBtjsfzgfXGugB+GLgKQETCgEnAl84I3NWuHzuA1D7hLF2bR6O29pVSHqrTpO+4Rn8vsBrYA6wwxuwSkUdFZJ6j2MtArIjkAw8AZ4d1PgeEi8gurMrjD8aYHc7+I1zB3094YFYa+Scq+efnbU90lFLKM4i7jUjJzMw0OTk5dofRrqYmw3XPfkhFTQPvPTiVQH+9t00p5R5EZKsx5rzL521p1roIfn7CD2enc/hkFX/LKbQ7HKWUumia9C/StPR4xg/szTPr91JT32h3OEopdVE06V8kEeGHV6dzrKyG1zcftjscpZS6KJr0u+HyIXFcPiSW5zfmU1XXYHc4SinVZZr0u+nB2emUVNbxx48P2h2KUkp1mSb9bpowKJqrhvXhxff3U1Zdb3c4SinVJZr0L8EDs9Ioq67n5Q8P2B2KUkp1iSb9SzByQBRzR/Xj5Q/2c/JMnd3hKKVUpzTpX6IHZqVRXd/Ii+/vszsUpZTqlCb9SzS0TwQ3jB3Aq58c5ER5jd3hKKXUBWnSd4Lvz0ylodHw3IZ8u0NRSqkL0qTvBINiw1iQmcTrWw5TeKrK7nCUUqpDmvSd5L4ZQxERnnlPW/tKKfelSd9JEqJ68fXLBrJyWyEHSs7YHY5SSrVLk74T3TNtKEH+fjy1Ls/uUJRSql2a9J0oPiKY2ycns2r7UXKPV9gdjlJKnUeTvpN9+8rBhAcFsHRtrt2hKKXUeTTpO1nv0CDumDKY1buK2FF42u5wlFLqHJr0e8C3rkgmOjSQJ9botX2llHvRpN8DIkIC+c7UIbyfV0z2wZN2h6OUUs006feQ//hKMvERwfy/1bm42+LzSinfpUm/h/QK8ufe6UPZcuAkH+aX2B2OUkoBmvR71OKJSQzo3YvH1+Rpa18p5RY06feg4AB/7psxlO0Fp1m354Td4SillCb9nnbT+ESSY0N5Yk0uTU3a2ldK2UuTfg8L8PfjB7PS+PJ4BW9/cczucJRSPk6TvgtcN7o/6X0jeHJdHg2NTXaHo5TyYZr0XcDPT/jBrDT2F5/hzc+O2B2OUsqHadJ3katH9GXUgCiefm8vdQ3a2ldK2UOTvouICA/OTqPwVDXLcwrsDkcp5aM06bvQ1LR4spKjeXb9XmrqG+0ORynlgzTpu5DV2k+nqLyWv3x6yO5wlFI+SJO+i00aHMuU1Die37iPM7UNdoejlPIxmvRt8ODsdE6eqeMPHx2wOxSllI/RpG+DsUm9mZnRlxc37aesqt7ucJRSPkSTvk0enJ1GRU0Dv/tgv92hKKV8SJeSvojMEZFcEckXkSXtvB4sIssdr28WkeRWr40WkU9EZJeIfCEiIc4L33NlJETy1dEJvPLRAUora+0ORynlIzpN+iLiDzwHXAMMB24WkeFtiv0ncMoYMxR4EviV470BwF+A7xhjRgDTAL2e4fCDWWnU1DfywsZ9doeilPIRXWnpTwTyjTH7jTF1wDLg+jZlrgdedTxeCcwQEQFmAzuMMdsBjDGlxhgdoO4wJD6cr41P5M+fHuJ4WY3d4SilfEBXkv4AoPUtpIWOfe2WMcY0AGVA
LJAGGBFZLSLbROTH7X2BiNwlIjkiklNcXHyxf4NH+/6MVJqM4dkNe+0ORSnlA3q6IzcAuAL4uuO/N4rIjLaFjDEvGWMyjTGZ8fHxPRySe0mKCWVRVhLLswsoOFlldzhKKS/XlaR/BEhq9TzRsa/dMo7r+FFAKdZZwSZjTIkxpgp4Bxh/qUF7m+9dlYqfCE+/p619pVTP6krSzwZSRSRFRIKAxcCqNmVWAbc5Hs8H1htrUdjVwCgRCXVUBlOB3c4J3Xv0jQzhG5MG8fdthewrrrQ7HKWUF+s06Tuu0d+LlcD3ACuMMbtE5FERmeco9jIQKyL5wAPAEsd7TwFLsSqOz4Ftxpi3nf9neL67pw0hJNCfJ9fm2R2KUsqLidUgdx+ZmZkmJyfH7jBs8fjqXJ7dkM87901heP9Iu8NRSnkQEdlqjMnsrJzeketG7rxyMJEhASzV1r5Sqodo0ncjUb0CuevKwazbU8Rnh0/ZHY5Sygtp0ncz35ycQkxYkLb2lVI9QpO+mwkLDuCeaUP4YG8Jn+4vtTscpZSX0aTvhm6dNIi+kcE8sSYXd+toV0p5Nk36bigk0J97r0ol++Ap3s/zrWkplFI9S5O+m1qUmURidC+eWJOnrX2llNNo0ndTQQF+fH9GKl8cKWP1riK7w1FKeQlN+m7sxnEDGBwfxtK1uTQ2aWtfKXXpNOm7sQB/P34wM428okre2nHU7nCUuji1FXCmBBrq7I5EtRJgdwDqwq4dlcBzG/J5cm0e145KIMBf62nl5qpPw0dPw6cvQEO1tS8wFEKiHFvvVo8dW6929p0tGxwJ/pqqnEWPpJvz8xMenJ3OnX/K4Y1thSzKGmh3SEq1r74atrwEHyyFmtMwcj4kXQY1ZdbzmtOOx2VQeRxKcluem6YLf3ZQRBcqiQ4qjuBI8NPG0lma9D3AzIw+jEnqzW/ey+eGcQMIDvC3OySlWjQ2wOd/gY2PQcUxSJ0NV/0PJIzu2vubmqCusqUCaN5Ot7OvzDqTOF0ANTut57VlnXyBQEhkO5VDF884gsJB5JIPk7vQpO8BRIQfzk7jGy9vYdmWAm67PNnukJSykvWef8L6n0NpvtWqv+llSJ58cZ/j5+dIypGcu15TV+NotPoPLlRJtN138kBL+bpO1rAQ/w4uPbWuJHp3fMYR2MutKg1N+h7iiqFxXJYSw7Mb8lmYmUSvIG3tK5sYA/vWw3uPwLHt0Gc43LwM0ubYk9z8/K3k26t3997f2AC15S2VQHuVRNszj5Kilsf1nSxz6hfYeb/F2ccxg2FAzy4uqEnfQ4gIP7w6nQW//YQ/fXKQb08dYndIyhcV5sC6h+HgB9B7INz4EoyabyVeT+UfAKEx1tYdDbVQU95O5dBO5XG2Qjld0FKmsdXoppE3wfxXnPN3dUCTvgfJSo5halo8v31/H7dcNpCIkEC7Q1K+ojgX3nsUvnwLQuPgml/DhNshINjuyOwXEAzh8dbWHfU1LRWEf5BzY2uHdml7mAdnp3Gqqp5XPjxodyjKF5wugH98F56fBPvfh+n/B77/OVz2bU34zhIYAhH9ID4dYlJ6/Ou0pe9hRif25uoRffn9B/u57fJB9A7t+ZaB8kFnSuGDJyD7d4DApHvgigcgLNbuyNQl0pa+B3pgVjqVdQ28uGm/3aEob1NbCRt/BU+Pgc0vwOiF8L2tcPUvNOF7CW3pe6D0fhHMG9OfP350kG9NTiE+Qk+z1SVqqIWtf4T3fw1VJZBxnTXWPj7d7siUk2lL30PdPzONusYmnt+Yb3coypM1NcL2ZfBsJrz7Y+iTAXesh0V/0YTvpbynpV95wvrHGxzh2CJbPW61z0vm8EiJC2P++ERe+/Qwd04ZTP/evewOSXkSYyD3XVj/v3BiNySMheuehsHT3epGIuV83pEBAU4dgrX/03m5wND2K4N2n3dQcQRHQID9HajfmzGUv39WyDPr8/nl10bZHY7yFAc/ssbaF26B2KGw4FUYfr0mex/h
PUk/MRN+csS6Hbt5K3dsbfdVnLudOmjdXHH2NdPY+ff5B1vJP+RCFUTb/e08Dgju9o8tMTqUWyYO5LXNh/nO1MEMig3r1ucoH3FshzXWPn8tRCRYLfuxt3rN2a/qGu/5vy0CweHWRkL3P8cYa7bADiuJ9ioSx1ZW0PK4phya6jv/Pr/A8yuCkAucYZyzRfG9iZH8M6eOp9fmsXTxuO7/3cp7le6DDf8Xdq60bvmf9ShMvMuaE0b5HO9J+s4iAkGh1hbRt/ufY4w1IqKjs4vmfe28VnEMSvJanjfUdPg1ccDn/tCwx4/GX0bgHxLV/hlHSOT5FUdIFPQf75joSnmdiuPWaJxtr1p3ek55EC6/r/tz1CivoEm/p4hYd9oFhnT/9uyzGupapp5tp/KoqjjNn9/fSVqIYXpyr5azkaoSOHWgpWx7E0P5B0HKlZA+19oiL+EsSbmH1ouYNNVb0yVc+SPrrk/l8zTpe4KAIAjoeEKoUOBMYx7ffG8vby28gpEDotr/nMYGqGtVYVSegH3vwZdvw9sPWNuACVbyH3YtxA/Tzj1P0nYRk1ELYPpPrZkblXIQY9xrwe3MzEyTk5Njdxgep7ymnim/2sCEQdG8cnvWxb3ZGGtCrdy34ct34Ijj+EenWMl/2LXWXOmePJOiN7vURUyUVxCRrcaYzM7KaUvfS0SGBPLtqYP59b9z2XroFBMGRXf9zSLQZ5i1TXkQyo9B3rtWBbDlJfjkWQiNteZLT58LQ66y+jyUvZy1iInyKdrS9yJVdQ1c+esNpPWN4PU7JznnQ2srIH+dVQHsXW31KwT0giHTHf0A10BYnHO+S3VNe4uYzPiZfYuYKLegLX0fFBoUwD3ThvLoW7v5OL+Ey4c6IRkHR8CIG62tsR4OfWz1AeS+Y23iZ7Uwz/YDxOriLj3qvEVMXrSu3eulN9VF2tL3MjX1jUx/fCMJUSG8cfflSE+1/IyB4184KoC3rcdgdf6erQD6j7fWP1WXru0iJlN/rIuYqHN0taWvSd8Lvb75MD998wv+cHsW04f1cc2Xnj5szeXy5VvWbf6mEcL7WZd/hl1rDQvVBHXxThdYHbTbX4fAMJh8H0y62zoDU6oVpyZ9EZkDPA34A783xjzW5vVg4E/ABKAUWGSMOdjq9YHAbuBhY8zjF/ouTfqXrr6xiRlPvE9ESAD/uvcK/PxcfJ23+hTsXWudBeSvs+4xCAqHoTOtCiB1FvS6iI5mX9R2EZOJd+oiJuqCnHZNX0T8geeAWUAhkC0iq4wxu1sV+0/glDFmqIgsBn4FLGr1+lLg3Yv5A1T3Bfr7cf/MVB5YsZ3Vu45zzSgX33DVK9pafGP0Quuu5AObWvoBdv8D/AJg0GSrAkifC72TXBufO6uthE+eg4+fgfozMPYWmLpEj5Fymk5b+iLyFawW+tWO5z8BMMb8slWZ1Y4yn4hIAHAciDfGGBG5AZgMnAEqtaXvGo1Nhquf2gTA6vuvxN/Vrf32NDXB0W0tFUDxl9b+fqNg2FetCqDfKN8cgaKLmKhL5MzROwOAglbPC4HLOipjjGkQkTIgVkRqgP/COkv4YVcCV87h7yc8MCuNe17bxqrtR7hxXKLdIVmduomZ1jbzIWsisLMVwMbHYOMvIWqgox9grnU24B9od9Q9q6kRvvgbbPiF1S+SPAVmPgKJE+yOTHmpnh6y+TDwpDGm8kKjSETkLuAugIEDB/ZwSL5jzoh+jOgfyZNr9/LV0f0J9HezkTSxQ6yOycn3QWUx5P3bqgC2vQpbXrQmhEu92qoAhs70rs7L8xYxGaOLmCiX6ErSPwK0vqCY6NjXXplCx+WdKKwO3cuA+SLya6A30CQiNcaYZ1u/2RjzEvASWJd3uvOHqPP5+QkPzk7jW3/MYeXWQm6e6MYVang8jP+GtdWdgX0bHPcCvAtfrHBMDDfVqgDS53r25GGtFzGJGQIL/ggZ1+vwVuUSXbmmHwDkATOwkns2cIsxZlerMt8FRhlj
vuPoyP2aMWZhm895GL2m73LGGG564WOOldWw4YfTCAn0sJt4mhqhYLN1GejLt61ZQwEGZDoqgGut696e0Dpuu4jJtCUw9uvefwlLuYSzh2zOBZ7CGrL5ijHmFyLyKJBjjFklIiHAn4FxwElgsTFmf5vPeBhN+rb4OL+EW36/mYeuG843J6fYHU73GWN1/p6tAI5us/bHDGmpAJImut/dqW0XMZnygC5iopxOb85S57jld5+SV1TBph9PJzTIS2bfKD/quCHsbWtYaFO9dbdq+hyrAhgy3d7E2nYRk0l36yImqsdo0lfn2HroFDe98DE/npPOPdOG2h2O89WUWzeC5b4DeWug1jEx3NAZVh9A2hzX3diki5goG+iEa+ocEwZFc9WwPrz4/n5unTSIyBAvu44cEgkjv2ZtDXVw6COrAvjyHWtqCPGDpEmO9QHm9szCIrqIifIA2tL3ITuPlPHVZz7kvhmpPDArze5wXMMYa/rhsxVA0dmJ4TJa+gH6j7u0kTO6iIlyA3p5R7Xrnte2simvhE0/nk5MWJDd4bjeqYMt/QCHPrYmhotIsG4IS78WUqZ0fWK49hYxmfGQLmKibKFJX7Vrb1EFs5/axF1TBvOTuRl2h2OvqpOwd41jYrj3rLlugiIgdaZVAaTOar/TVRcxUW5Ir+mrdqX2jeDGsQN49ZOD/OcVKfSJDLE7JPuExsCYxdZWX+OYGO4t60xg15vWxHDJV1gVwLC5EJWoi5goj6ctfR90qPQMM554n69fNpBHrh9pdzjup6kJjmx1VADvQEmetT9mMJzcr4uYKLekLX3VoUGxYSzITOL1LYe5ZlQCl6XE9NwKW57Izw+Ssqxt1iNQkm+tDrb/fRhzsy5iojyatvR91NHT1Xz1mQ85eaaOoX3CWZSZxI3jBxAXri1XpTyRduSqTlXWNvD2jqMszy5g2+HTBPgJs4b3ZWFWElemxrvHHPxKqS7RpK8uyt6iCpZnF/D3z45w8kwdCVEhLJiQyILMJJJiQu0OTynVCU36qlvqGppYt6eI5dkFbNpbDMDkIXEszEpi9vC+njdLp1I+QpO+umRHTlezMqeQFTkFHDldTe/QQG4YO4BFWUlkJETaHZ5SqhVN+sppmpoMH+0rYXmk/pHdAAAOaklEQVR2AWt2FVHX2MSYxCgWZiUxb0x/IrxtHh+lPJAmfdUjTp2p483PjrA8u4Dcogp6Bfozd1QCi7KSyEqO1qGfStlEk77qUcYYtheWsTy7gH9tP0plbQOD48JYmJXETeMTiY/QoZ9KuZImfeUyVXUNvL3jGMuzC8g5dIoAP+GqYX1YPNEa+hngbguyK+WFNOkrW+SfqORvOQW8sa2Qkso6+kYGs2BCEgszkxgYq0M/leopmvSVreobm3hvzwlW5BSwMfcETQa+MjiWxROTuHpEPx36qZSTadJXbuNYWTVvbC1keU4BBSeriQwJ4IZx1tDPEf2j7A5PKa+gSV+5naYmw6f7S1meU8C7O49T19DEyAGRLMoayLwx/YnqpUM/leouTfrKrZ2uquOfnx9lWXYBe46VExzg1zz0U2f9VOriadJXHsEYw84j5SzLPsyqz49SUdtAcmwoC7OSmD8+0bcXeVHqImjSVx6nuq6Rd3ceY1l2AVsOnMTfT5ieHs+irIFMT9ehn0pdiCZ95dEOlJxhRU4BK7cWUlxRS3xEMPMnJLIwM4mUuDC7w1PK7WjSV16hvrGJjbnFLM8+zIbcYhqbDBNTYliclcQ1IxPoFaRDP5UCTfrKCxWV1/DGtkJWZBdwsLSKiOAArh/Xn0WZAxk5IFI7f5VP06SvvJYxhs0HTrI8u4B3vjhGbUMTGQmRLM5K4oaxA4gK1aGfyvdo0lc+oay6nlXbj7I8+zA7j5QTFODHNSP7sSgziUmDY/HTJR+Vj9Ckr3zOziNlrMgp4B+fHaG8poGBMaEszExk/oQk+kXp0E/l3TTpK59VU9/I6l3HWbalgE/2l+InMC29
Dwszk5iR0YdAHfqpvJAmfaWAQ6UtQz+LymuJCw/ipvGJLMxKYkh8uN3hKeU0mvSVaqWhsYlNe4tZtqWA9V+eoKHJkJUczcLMJK4dnUBoUIDdISp1STTpK9WB4opa/r6tkOXZBewvOUN4cADXjenPoqwkxiRG6dBP5ZE06SvVCWMMOYdOsWyLNfSzur6RYf0iWJiZxI3jBhAdFmR3iEp1mSZ9pS5CRU09/9p+jOXZh9leWEaQvx+zR/RlUVYSk4fE6dBP5facmvRFZA7wNOAP/N4Y81ib14OBPwETgFJgkTHmoIjMAh4DgoA64EfGmPUX+i5N+spue46Vszy7gH98foTTVfUkRvdiwYQkFmQm0r93L7vDU6pdTkv6IuIP5AGzgEIgG7jZGLO7VZl7gNHGmO+IyGLgRmPMIhEZBxQZY46KyEhgtTFmwIW+T5O+chc19Y2s2V3EiuwCPswvQQT6RYbQLyqEfpEh9G31uPV/dSlIZYeuJv2uDFmYCOQbY/Y7PngZcD2wu1WZ64GHHY9XAs+KiBhjPmtVZhfQS0SCjTG1XfhepWwVEujPvDH9mTemPwUnq1i1/Sj7i89QVF5DXlEFH+wtobK24bz3RfUKJCHKUSlEhtA3KoSENhVFdGigdhgrW3Ql6Q8AClo9LwQu66iMMaZBRMqAWKCkVZmbgG2a8JUnSooJ5bvTh563v7K2geNlNdZWXkNRecvj42U17D5WTkllLW1PqIMC/KwzA0el0C8ymL6RISRE9aJflPW4T0QIQQF6I5lyLpcMThaREcCvgNkdvH4XcBfAwIEDXRGSUk4RHhzA0D7hDO3T8Y1e9Y1NFFfUNlcEx8sclUN5DcfKathReJo1ZTXUNjSd99648GD6RQW3nCWcvZQU1VJhRAQH6FmD6rKuJP0jQFKr54mOfe2VKRSRACAKq0MXEUkE3gT+wxizr70vMMa8BLwE1jX9i/kDlHJ3gf5+9O/d64KdwMYYyqrrmyuCojZnDoWnqtl66BSnqurPe29YkL/jbOHcy0mtK4m48GD8dQSSomtJPxtIFZEUrOS+GLilTZlVwG3AJ8B8YL0xxohIb+BtYIkx5iPnha2UdxEReocG0Ts0iGH9IjssV1PfyInyWkflUO2oFGopcjzffOAkReU1NDSd23by9xP6RAS3f7bQqiNaF6Xxfp0mfcc1+nuB1VhDNl8xxuwSkUeBHGPMKuBl4M8ikg+cxKoYAO4FhgI/E5GfOfbNNsaccPYfopQvCAn0Z2BsKANjQzss09RkKD1T56gIHGcMrc4c8osr+Si/hIoOOqFb9zP0i+rlqCRaKoyYsCC9nOTB9OYspXzUmdqGc/oZ2uuILu6gE7pvZEs/Q/OlpFZnDn0jtRPa1Zw5ZFMp5YXCggMYEh9+wdlGGxqbKK6sPadiaH3msPNIGev2FFFT314ndBDJsWGk9YsgvW8EaX0jSOsbTmx4cE/+WaoTmvSVUh0K8PcjIaoXCVEX7oQur244r5/heHk1+4rP8PaOY7xefbi5fFx4kKMCiCC9X0tlEBGiy1y6giZ9pdQlERGiQgOJCg0kvV/Eea8bYyiuqCW3qILc4xXkFVWQV1TJipwCquoam8v1jwo556wgvV8EQ/uE6x3OTqZJXynVo0SEPpEh9IkMYUpqfPP+pibDkdPV5BVVkFtUQd7xCnKLKvl4Xyl1jnsWRLAuEfUNP+fsICUuTFdA6yZN+kopW/j5CUkxoSTFhDIjo2/z/obGJg6drHJUAtaZQe7xCtbtOUGjYyhqoL8wOC7ccWYQTmpf6wwhKSZU70fohCZ9pZRbCfD3a+5gvmZUQvP+mvpG9hefYe+JlstEnxec4l/bjzaXCQn0I7XP2TOC8OYzg36RITrM1EGTvlLKI4QE+jO8fyTD+59789qZ2gb2nqgkz1ER5BZV8GF+MW9sK2wuExEcQJqj0zi9b3hz34EvjiTSpK+U8mhhwQGMTerN2KTe5+w/XVVHXlFlq/6CCt7d
eYy/bmmZyiIuPIjUPi2jiNL7WZeKIr14JJEmfaWUV+odGsTElBgmpsQ07zPGUFxZS97xcyuDv+UUcKadkUTNncd9rZFE3jBNhSZ9pZTPEBH6RFjTVl+RGte8v6nJcLTMMZLoeGVz53HbkUSDYkKb+wnOdh6nxIV51N3HmvSVUj7Pz09IjA4lMTqUq4adO5Lo8MmqcyuDogre+7JlJFGAnzA4Pqz5jODsGcJANx1JpElfKaU6EODvx+D4cAbHhzNnZMv+2gZrJFFe85DSSnYUlvHWjmPNZYID/Eh13F9wtjJI7xtBQpS9I4k06Sul1EUKDvAnIyGSjITzRxLln2jpL8g7UcnH+aX8fVvLEiQRwQGk9g1v6Tx2VAhxLhpJpElfKaWcJCw4gDFJvRnTZiRRWVU9ea3uL8g9XsG/dx7nr1taVqKNDQvixnED+O+vDu/RGDXpK6VUD4sKDSQrOYas5HNHEpVU1jVXAnlFFSRcYHU1Z9Gkr5RSNhAR4iOCiY8IZvLQuM7f4CSeM85IKaXUJdOkr5RSPkSTvlJK+RBN+kop5UM06SullA/RpK+UUj5Ek75SSvkQTfpKKeVDxBhjdwznEJFi4NAlfEQcUOKkcJxJ47o4GtfF0bgujjfGNcgYE99ZIbdL+pdKRHKMMZl2x9GWxnVxNK6Lo3FdHF+OSy/vKKWUD9Gkr5RSPsQbk/5LdgfQAY3r4mhcF0fjujg+G5fXXdNXSinVMW9s6SullOqARyZ9EZkjIrkiki8iS9p5PVhEljte3ywiyW4S1+0iUiwinzu2O1wU1ysickJEdnbwuojIbxxx7xCR8W4S1zQRKWt1vH7moriSRGSDiOwWkV0i8v12yrj8mHUxLpcfMxEJEZEtIrLdEdcj7ZRx+W+yi3HZ9Zv0F5HPROStdl7r2WNljPGoDfAH9gGDgSBgOzC8TZl7gN86Hi8GlrtJXLcDz9pwzK4ExgM7O3h9LvAuIMAkYLObxDUNeMuG45UAjHc8jgDy2vl/6fJj1sW4XH7MHMcg3PE4ENgMTGpTxo7fZFfisus3+QDwenv/r3r6WHliS38ikG+M2W+MqQOWAde3KXM98Krj8UpghvT88vNdicsWxphNwMkLFLke+JOxfAr0FpEEN4jLFsaYY8aYbY7HFcAeYECbYi4/Zl2My+Ucx6DS8TTQsbXtLHT5b7KLcbmciCQC1wK/76BIjx4rT0z6A4CCVs8LOf8ffnMZY0wDUAbEukFcADc5LgesFJGkHo6pq7oaux2+4jg9f1dERrj6yx2n1uOwWomt2XrMLhAX2HDMHJcrPgdOAGuNMR0eLxf+JrsSF7j+N/kU8GOgqYPXe/RYeWLS92T/ApKNMaOBtbTU5qp927BuLR8DPAP8w5VfLiLhwBvA/caYcld+94V0Epctx8wY02iMGQskAhNFZKQrvrczXYjLpb9JEfkqcMIYs7Unv+dCPDHpHwFa18aJjn3tlhGRACAKKLU7LmNMqTGm1vH098CEHo6pq7pyTF3OGFN+9vTcGPMOECgiLllBWkQCsRLra8aYv7dTxJZj1llcdh4zx3eeBjYAc9q8ZMdvstO4bPhNTgbmichBrEvAV4nIX9qU6dFj5YlJPxtIFZEUEQnC6uhY1abMKuA2x+P5wHrj6BWxM64213znYV2TdQergP9wjEiZBJQZY47ZHZSI9Dt7LVNEJmL9e+3xROH4zpeBPcaYpR0Uc/kx60pcdhwzEYkXkd6Ox72AWcCXbYq5/DfZlbhc/Zs0xvzEGJNojEnGyhHrjTG3tinWo8cqwFkf5CrGmAYRuRdYjTVi5hVjzC4ReRTIMcaswvph/FlE8rE6Che7SVz3icg8oMER1+09HReAiPwVa1RHnIgUAg9hdWphjPkt8A7WaJR8oAr4ppvENR+4W0QagGpgsQsqb7BaY98AvnBcDwb4KTCwVWx2HLOuxGXHMUsAXhURf6xKZoUx5i27f5NdjMuW32RbrjxWekeuUkr5EE+8
vKOUUqqbNOkrpZQP0aSvlFI+RJO+Ukr5EE36SinlQzTpK6WUD9Gkr5RSPkSTvlJK+ZD/DyWqyTHorp4tAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# explicit figure/axes interface; title and axis labels let the plot stand alone\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(tr_loss_hist, label = 'train')\n",
    "ax.plot(val_loss_hist, label = 'validation')\n",
    "ax.set(title = 'Average loss per epoch', xlabel = 'epoch (0-based)', ylabel = 'cross-entropy loss')\n",
    "ax.legend()\n",
    "plt.show()  # render the figure without echoing the Legend repr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test acc: 98.86%\n"
     ]
    }
   ],
   "source": [
    "# class scores for the whole test set; the predicted label is the arg-max over the last axis\n",
    "tst_scores = cnn.predict(x_tst)\n",
    "yhat = np.argmax(tst_scores, axis = -1)\n",
    "tst_acc = np.mean(yhat == y_tst)\n",
    "print('test acc: {:.2%}'.format(tst_acc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}