{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CS 20 : TensorFlow for Deep Learning Research\n", "## Lecture 07 : ConvNet in TensorFlow\n", "Specification of SimpleCNN is same that of [Lec07_ConvNet mnist by low-level.ipynb](https://nbviewer.jupyter.org/github/aisolab/CS20/blob/master/Lec07_ConvNet%20in%20Tensorflow/Lec07_ConvNet%20mnist%20by%20low-level.ipynb)\n", "### ConvNet mnist by high-level\n", "- Creating the **data pipeline** with `tf.data`\n", "- Using `tf.contrib.slim`, alias `slim`\n", "- Creating the model as **Class** with `slim`\n", "- Training the model with **Drop out** technique by `slim.dropout`\n", "- Using tensorboard" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.12.0\n" ] } ], "source": [ "from __future__ import absolute_import, division, print_function\n", "import os, sys\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "%matplotlib inline\n", "\n", "slim = tf.contrib.slim\n", "print(tf.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load and Pre-process data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "(x_train, y_train), (x_tst, y_tst) = tf.keras.datasets.mnist.load_data()\n", "x_train = x_train / 255\n", "x_train = x_train.reshape(-1, 28, 28, 1).astype(np.float32)\n", "x_tst = x_tst / 255\n", "x_tst = x_tst.reshape(-1, 28, 28, 1).astype(np.float32)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(55000, 28, 28, 1) (55000,)\n", "(5000, 28, 28, 1) (5000,)\n" ] } ], "source": [ "tr_indices = np.random.choice(range(x_train.shape[0]), size = 55000, replace = False)\n", "\n", "x_tr = x_train[tr_indices]\n", "y_tr = y_train[tr_indices].astype(np.int32)\n", "\n", "x_val = np.delete(arr = x_train, obj = tr_indices, axis = 0)\n", "y_val = np.delete(arr = y_train, obj = tr_indices, axis = 0).astype(np.int32)\n", "\n", "print(x_tr.shape, y_tr.shape)\n", "print(x_val.shape, y_val.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define SimpleCNN class by high-level api (slim)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class SimpleCNN:\n", " def __init__(self, X, y, n_of_classes):\n", " \n", " self._X = X\n", " self._y = y\n", " self._is_training = tf.placeholder(dtype = tf.bool)\n", "\n", " with slim.arg_scope([slim.conv2d, slim.fully_connected], activation_fn = tf.nn.relu,\n", " weights_initializer = tf.truncated_normal_initializer(),\n", " biases_initializer = tf.truncated_normal_initializer()):\n", " with slim.arg_scope([slim.conv2d], kernel_size = [5, 5], stride = 1, padding = 'SAME'):\n", " with slim.arg_scope([slim.max_pool2d], kernel_size = [2, 2], stride = 2, padding = 'SAME'):\n", " \n", " conv1 = slim.conv2d(inputs = self._X, num_outputs = 32, scope = 'conv1')\n", " pool1 = slim.max_pool2d(inputs = conv1, scope = 'pool1')\n", " conv2 = slim.conv2d(inputs = pool1, num_outputs = 64, scope = 'conv2')\n", " pool2 = slim.max_pool2d(inputs = conv2, scope = 'pool2')\n", " flattened = slim.flatten(inputs = pool2)\n", " fc = slim.fully_connected(inputs = flattened, num_outputs = 1024, scope = 'fc1')\n", " dropped = slim.dropout(inputs = fc, keep_prob = .5, is_training = self._is_training)\n", " self._score = slim.fully_connected(inputs = dropped, num_outputs = n_of_classes,\n", " activation_fn = None, scope = 'score')\n", " self.ce_loss = self._loss(labels = self._y, logits = self._score, scope = 'ce_loss')\n", " \n", " with tf.variable_scope('prediction'):\n", " self._prediction = tf.argmax(input = self._score, axis = -1)\n", " \n", " def _loss(self, labels, logits, scope):\n", " with tf.variable_scope(scope):\n", " ce_loss = tf.losses.sparse_softmax_cross_entropy(labels = labels, logits = logits)f.reduce_mean(\n", " return ce_loss\n", " \n", " def predict(self, sess, x_data, is_training = True):\n", " feed_prediction = {self._X : x_data, self._is_training : is_training}\n", " return sess.run(self._prediction, feed_dict = feed_prediction)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a model of SimpleCNN" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "550\n" ] } ], "source": [ "# hyper-parameter\n", "lr = .01\n", "epochs = 30\n", "batch_size = 100\n", "total_step = int(x_tr.shape[0] / batch_size)\n", "print(total_step)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<BatchDataset shapes: ((?, 28, 28, 1), (?,)), types: (tf.float32, tf.int32)>\n", "<BatchDataset shapes: ((?, 28, 28, 1), (?,)), types: (tf.float32, tf.int32)>\n" ] } ], "source": [ "## create input pipeline with tf.data\n", "# for train\n", "tr_dataset = tf.data.Dataset.from_tensor_slices((x_tr, y_tr))\n", "tr_dataset = tr_dataset.shuffle(buffer_size = 10000)\n", "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n", "tr_iterator = tr_dataset.make_initializable_iterator()\n", "print(tr_dataset)\n", "\n", "# for validation\n", "val_dataset = tf.data.Dataset.from_tensor_slices((x_val,y_val))\n", "val_dataset = val_dataset.batch(batch_size = batch_size)\n", "val_iterator = val_dataset.make_initializable_iterator()\n", "print(val_dataset)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "## define Iterator\n", "# tf.data.Iterator.from_string_handle의 output_shapes는 default = None이지만 꼭 값을 넣는 게 좋음\n", "handle = tf.placeholder(dtype = tf.string)\n", "iterator = tf.data.Iterator.from_string_handle(string_handle = handle,\n", " output_types = tr_iterator.output_types,\n", " output_shapes = tr_iterator.output_shapes)\n", "\n", "x_data, y_data = iterator.get_next()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "## connecting data pipeline with model\n", "cnn = SimpleCNN(X = x_data, y = y_data, n_of_classes = 10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create training op and train model" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "## create training op\n", "opt = tf.train.AdamOptimizer(learning_rate = lr)\n", "\n", "# equal to 'var_list = None'\n", "training_op = opt.minimize(loss = cnn.ce_loss)\n", "\n", "#for tensorboard\n", "loss_summ = tf.summary.scalar(name = 'loss', tensor = cnn.ce_loss)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "## for tensorboard\n", "tr_writer = tf.summary.FileWriter('../graphs/lecture07/convnet_mnist_high/train/', graph = tf.get_default_graph())\n", "val_writer = tf.summary.FileWriter('../graphs/lecture07/convnet_mnist_high/val/', graph = tf.get_default_graph())\n", "saver = tf.train.Saver()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch : 5, tr_loss : 7.688, val_loss : 9.466\n", "epoch : 10, tr_loss : 3.639, val_loss : 3.486\n", "epoch : 15, tr_loss : 1.158, val_loss : 1.244\n", "epoch : 20, tr_loss : 0.483, val_loss : 0.416\n", "epoch : 25, tr_loss : 0.467, val_loss : 0.459\n", "epoch : 30, tr_loss : 0.455, val_loss : 0.542\n" ] }, { "data": { "text/plain": [ "'../graphs/lecture07/convnet_mnist_high/cnn/'" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))\n", "sess = tf.Session(config = sess_config)\n", "sess.run(tf.global_variables_initializer())\n", "tr_handle, val_handle = sess.run(fetches = [tr_iterator.string_handle(), val_iterator.string_handle()])\n", "\n", "tr_loss_hist = []\n", "val_loss_hist = []\n", "\n", "for epoch in range(epochs):\n", "\n", " avg_tr_loss = 0\n", " avg_val_loss = 0\n", " tr_step = 0\n", " val_step = 0\n", "\n", " # for mini-batch training\n", " sess.run(tr_iterator.initializer) \n", " try:\n", " while True:\n", " _, tr_loss, tr_loss_summ = sess.run(fetches = [training_op, cnn.ce_loss, loss_summ],\n", " feed_dict = {handle : tr_handle, cnn._is_training : True})\n", " avg_tr_loss += tr_loss\n", " tr_step += 1\n", " \n", " except tf.errors.OutOfRangeError:\n", " pass\n", "\n", " # for validation\n", " sess.run(val_iterator.initializer)\n", " try:\n", " while True:\n", " val_loss, val_loss_summ = sess.run(fetches = [cnn.ce_loss, loss_summ],\n", " feed_dict = {handle : val_handle, cnn._is_training : False})\n", " avg_val_loss += val_loss\n", " val_step += 1\n", " \n", " except tf.errors.OutOfRangeError:\n", " pass\n", "\n", " avg_tr_loss /= tr_step\n", " avg_val_loss /= val_step\n", " tr_writer.add_summary(summary = tr_loss_summ, global_step = epoch + 1)\n", " val_writer.add_summary(summary = val_loss_summ, global_step = epoch + 1)\n", " tr_loss_hist.append(avg_tr_loss)\n", " val_loss_hist.append(avg_val_loss)\n", " \n", " if (epoch + 1) % 5 == 0:\n", " print('epoch : {:3}, tr_loss : {:.3f}, val_loss : {:.3f}'.format(epoch + 1, avg_tr_loss, avg_val_loss))\n", "\n", "tr_writer.close()\n", "val_writer.close()\n", "saver.save(sess = sess, save_path = '../graphs/lecture07/convnet_mnist_high/cnn/')" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.legend.Legend at 0x7f71c02823c8>" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3WuQXGW97/Hvvy/Tncx0ksnFmBsmHjgSEy4JI8ZCEIxagEcCys3SbaDQ7KLwgNvj3ub4Bj2lVXhUQKoUK4rseErB7CAmZxduRAxejoIkGkJC0ARMdib3BHK/zPT0/7xYT8/0THpmeu7p1b9PVdda61mXedZ08us1z1r9PObuiIhIfCVGugIiIjK0FPQiIjGnoBcRiTkFvYhIzCnoRURiTkEvIhJzCnoRkZhT0IuIxJyCXkQk5lIjXQGAiRMn+syZM0e6GiIiVWXdunUH3H1Sb9udFUE/c+ZM1q5dO9LVEBGpKma2vZLtKmq6MbN/MrNNZrbRzB4zs6yZzTKzF8xsq5n91MzqwraZsLw1rJ/Z/9MQEZGB6jXozWwacDfQ5O5zgSRwK/B14AF3Pxd4E7gj7HIH8GYofyBsJyIiI6TSm7EpYJSZpYDRwG7g/cDKsH45cH2YXxSWCesXmpkNTnVFRKSvem2jd/edZvZN4D+Bk8AvgXXAIXfPh82agWlhfhqwI+ybN7PDwATgwCDXXUTOUq2trTQ3N3Pq1KmRrkosZLNZpk+fTjqd7tf+vQa9mTUSXaXPAg4B/wZc3a+f1vm4S4AlAOecc85ADyciZ5Hm5mZyuRwzZ85Ef9APjLtz8OBBmpubmTVrVr+OUUnTzQeAv7v7fndvBX4GXAaMC005ANOBnWF+JzADIKwfCxwsU/ll7t7k7k2TJvX6dJCIVJFTp04xYcIEhfwgMDMmTJgwoL+OKgn6/wQWmNno0Na+EHgFWAPcGLZZDKwK86vDMmH9r13DWInUHIX84Bno77LXoHf3F4huqv4ZeDnsswz4IvB5M9tK1Ab/SNjlEWBCKP88sHRANezBi9ve4BtPv0qhoM8REZHuVPTUjbvf6+7nu/tcd/8Hdz/t7q+7+6Xufq673+Tup8O2p8LyuWH960NV+Zd2HOI7a17jWEu+941FpGYcOnSI7373u33e79prr+XQoUNDUKORVdV93TRkolsEx04p6EWkQ3dBn8/3nBVPPfUU48aNG6pqjZizoguE/splo0eNjp1W0ItIh6VLl/Laa69x8cUXk06nyWazNDY28uqrr/K3v/2N66+/nh07dnDq1CnuuecelixZAnR0x3Ls2DGuueYa3vve9/KHP/yBadOmsWrVKkaNGjXCZ9Y/VR30Ddmo+kdPtY5wTUSkO1/5v5t4ZdeRQT3mO6eO4d6PzOl2/X333cfGjRtZv349zz33HB/+8IfZuHFj++OJP/zhDxk/fjwnT57kXe96Fx/72MeYMGFCp2Ns2bKFxx57jO9///vcfPPNPPHEE3zyk58c1PMYLtUd9Jli0OuKXkS6d+mll3Z6Bv2hhx7iySefBGDHjh1s2bLljKCfNWsWF198MQCXXHIJ27ZtG7b6DraqDvox4YpeTTciZ6+erryHS319ffv8c889x69+9Sv++Mc/Mnr0aK688sqyz6hnMpn2+WQyycmTJ4elrkOhum/GZnVFLyJnyuVyHD16tOy6w4cP09jYyOjRo3n11Vd5/vnnh7l2w6+qr+j11I2IlDNhwgQuu+wy5s6dy6hRo5g8eXL7uquvvprvfe97zJ49m3e84x0sWLBgBGs6PKo66OvrUpjBUTXdiEgXP/nJT8qWZzIZfvGLX5RdV2yHnzhxIhs3bmwv/8IXvjDo9RtOVd10k0gYDXUpPXUjItKDqg56iNrp1XQjItK9qg/6XDalp25ERHpQ9UHfkEnpqRsRkR5Uf9Bn07oZKyLSg6oP+lw2xTHdjBUR6Vb1B72abkRkgBoaGgDYtWsXN954Y9ltrrzyStauXdvjcR588EFOnDjRvny2dHtc9UHfkNHNWBEZHFOnTmXlypX93r9r0J8t3R73GvRm9g4zW1/yOmJmnzOz8Wb2jJltCdPGsL2Z2UNmttXMNpjZ/KE8gVw2zYmWNto0ypSIBEuXLuU73/lO+/KXv/xlvvrVr7Jw4ULmz5/PBRdcwKpVq87Yb9u2bcydOxeAkydPcuuttzJ79mxuuOGGTn3d3HnnnTQ1NTFnzhzuvfdeIOoobdeuXVx11VVcddVVQNTt8YEDBwC4//77mTt3LnPnzuXBBx9s/3mzZ8/mM5/5DHPmzOFDH/rQkPSp0+s3Y939r8DFAGaWJBr8+0miIQKfdff7zGxpWP4icA1wXni9G3g4TIdEsb+bY6fyjB2dHqofIyL99YulsOflwT3mWy+Aa+7rdvUtt9zC5z73Oe666y4AVqxYwdNPP83dd9/NmDFjOHDgAAsWLOC6667rdjzWhx9+mNGjR7N582Y2bNjA/Pkd16xf+9rXGD9+PG1tbSxcuJANGzZw9913c//997NmzRomTpzY6Vjr1q3j0Ucf5YUXXsDdefe738373vc+Ghsbh6U75L423SwEXnP37cAiYHkoXw5cH+YXAT/yyPPAODObMii1LSNX7Kr4tG7Iikhk3rx57Nu3j127dvHSSy/R2NjIW9/6Vr70pS9x4YUX8oEPfICdO3eyd+/ebo/x29/+tj1wL7zwQi688ML2dStWrGD+/PnMmzePTZs28corr/RYn9///vfccMMN1NfX09DQwEc/+lF+97vfAcPTHXJf+7q5FXgszE92991hfg9Q7DVoGrCjZJ/mULabIZBTV8UiZ7cerryH0k033cTKlSvZs2cPt9xyCz/+8Y/Zv38/69atI51OM3PmzLLdE/fm73//O9/85jd58cUXaWxs5LbbbuvXcYqGozvkiq/ozawOuA74t67r3N2BPjWSm9kSM1trZmv379/fl107UVfFIlLOLbfcwuOPP87KlSu56aabOHz4MG95y1tIp9OsWbOG7du397j/FVdc0d4x2saNG9mwYQMAR44cob6+nrFjx7J3795OHaR11z3y5Zdfzs9//nNOnDjB8ePHefLJJ7n88ssH8Wx71pcr+muAP7t78W+dvWY2xd13h6aZfaF8JzCjZL/poawTd18GLANoamrq951UdVUsIuXMmTOHo0ePMm3aNKZMmcInPvEJPvKRj3DBBRfQ1NTE+eef3+P+d955J7fffjuzZ89m9uzZXHLJJQBcdNFFzJs3j/PPP58ZM2Zw2WWXte+zZMkSrr76aqZOncqaNWvay+fPn89tt93GpZdeCsCnP/1p5s2bN2yjVll0MV7BhmaPA0+7+6Nh+RvAwZKbsePd/V/M7MPAZ4FriW7CPuTul/Z07KamJu/t+dTubN13jA/c/xse+vg8rrtoar+OISKDa/PmzcyePXukqxEr5X6nZrbO3Zt627eiK3ozqwc+CPxjSfF9wAozuwPYDtwcyp8iCvmtwAng9kp+Rn/lNEC4iEiPKgp6dz8OTOhSdpDoKZyu2zpw16DUrgJquhER6VnVfzN2dF2ShOmpG5GzTaXNwtK7gf4uqz7ozUxdFYucZbLZLAcPHlTYDwJ35+DBg2Sz2X4fo6rHjC3KZdMKepGzyPTp02lubmYgj05Lh2w2y/Tp0/u9f0yCPsUxfTNW5KyRTqeZNWvWSFdDgqpvugGNMiUi0pN4BL3GjRUR6VYsgj6XTevxShGRbsQi6BsyKY4o6EVEyopF0OtmrIhI9+IR9JkUp1oLtLYVRroqIiJnnVgEfekoUyIi0lk8gj6jwUdERLoTi6DPZaOxYvUsvYjImWIS9OqqWESkO7EIejXdiIh0LxZBrwHCRUS6F4ugLz51oy9NiYicqaKgN7NxZrbSzF41s81m9h4zG29mz5jZljBtDNuamT1kZlvNbIOZzR/aU4BcJroZq8crRUTOVOkV/beB/3D384GLgM3AUuBZdz8PeDYsA1wDnBdeS4CHB7XGZWTTCVIJ07djRUTK6DXozWwscAXwCIC7t7j7IWARsDxsthy4PswvAn7kkeeBcWY2ZdBr3rmONGTVVbGISDmVXNHPAvYDj5rZX8zsB2ZWD0x2991hmz3A5DA/DdhRsn9zKOvEzJaY2VozWzsYo9A0ZFJquhERKaOSoE8B84GH3X0ecJyOZhoAPBoYsk+DQ7r7MndvcvemSZMm9WXXsnLZNEf11I2IyBkqCfpmoNndXwjLK4mCf2+xSSZM94X1O4EZJftPD2VDKpdJ6QtTIiJl9Br07r4H2GFm7whFC4FXgNXA4lC2GFgV5lcDnwpP3ywADpc08QwZjTIlIlJepYOD/3fgx2ZWB7wO3E70IbHCzO4AtgM3h22fAq4FtgInwrZDLpdN8fp+Bb2ISFcVBb27rweayqxaWGZbB+4aYL36TAOEi4iUF4tvxkLUdKObsSIiZ4pN0I/JpmnJFzidbxvpqoiInFViE/TtPViq+UZEpJP4Bb2ab0REOolN0HcMPqKgFxEpFZugb1DQi4iUFZugb++qWE03IiKdxCfo20eZUjcIIiKlYhP0aroRESkvPkGfUdCLiJQTm6DPppPUJRNqoxcR6SI2QQ+hGwR1VSwi0km8gl6jTImInCFWQZ/TuLEiImeIVdA3ZNSDpYhIV7EK+lxWTTciIl1VFPRmts3MXjaz9Wa2NpSNN7NnzGxLmDaGcjOzh8xsq5ltMLP5Q3kCpaIBwnUzVkSkVF+u6K9y94vdvTjS1FLgWXc/D3g2LANcA5wXXkuAhwersr3RzVgRkTMNpOlmEbA8zC8Hri8p/5FHngfGmdmUAfycihUHCI9GMxQREag86B34pZmtM7MloWyyu+8O83uAyWF+GrCjZN/mUDbkctkUrW3O6XxhOH6ciEhVqGhwcOC97r7TzN4CPGNmr5audHc3sz5dRocPjCUA55xzTl927VaupBuEbDo5KMcUEal2FV3Ru/vOMN0HPAlcCuwtNsmE6b6w+U5gRsnu00NZ12Muc/cmd2+aNGlS/8+gRENWo0yJiHTVa9CbWb2Z5YrzwIeAjcBqYHHYbDGwKsyvBj4Vnr5ZABwuaeIZUsU+6dUNgohIh0qabiYDT5pZcfufuPt/mNmLwAozuwPYDtwctn8KuBbYCpwAbh/0Wnej/YpeT96IiLTrNejd/XXgojLlB4GFZcoduGtQatdH7V0Vq+lGRKRdrL4ZOyZbbLpR0IuIFMUq6DuabtRGLyJSFK+gz+ipGxGRrmIV9HWpBJlUQk03IiIlYhX0EPqk1xW9iEi72AW9OjYTEeksdkGfy6b1hSkRkRKxC/qGTEo3Y0VESsQv6DVurIhIJ7ELeg0QLiLSWfyCXk03IiKdxC7oNcqUiEhnsQv6XDZNW8E52do20lURETkrxC7o27tBUDu9iAgQw6DPZdVVsYhIqfgGva7oRUSAPgS9mSXN7C9m9u9heZaZvWBmW83sp2ZWF8ozYXlrWD9zaKpeXkMYTlBNNyIikb5c0d8DbC5Z/jrwgLufC7wJ3BHK7wDeDOUPhO2GTUdXxeoGQUQEKgx6M5sOfBj4QVg24P3AyrDJcuD6ML8oLBPWLwzbD4ti080RXdGLiACVX9E/CPwLUAjLE4BD7l5M02ZgWpifBuwACOsPh+2HRU4DhIuIdNJr0JvZfwP2ufu6wfzBZrbEzNaa2dr9+/cP2nHrNcqUiEgnlVzRXwZcZ2bbgMeJmmy+DYwzs1TYZjqwM8zvBGYAhPVjgYNdD+ruy9y9yd2bJk2aNKCTKJVOJhiVTqqrYhGRoNegd/f/6e7T3X0mcCvwa3f/BLAGuDFsthhYFeZXh2XC+l/7MPdHUOwGQUREBvYc/ReBz5vZVqI2+EdC+SPAhFD+eWDpwKrYd7mMerAUESlK9b5JB3d/DnguzL8OXFpmm1PATYNQt35TV8UiIh1i981YUNONiEipeAa9BggXEWkXy6DXAOEiIh1iGfQNmZR6rxQRCWIZ9DmNMiUi0i6WQd+QSeEOx1s0ypSISCyDPpdVV8UiIkWxDPqGrLoqFhEpimXQ5zLqqlhEpCieQa+uikVE2sUy6DuabhT0IiLxDPpMcYBwtdGLiMQy6ItP3ahjMxGRmAZ9g0aZEhFpF8ugTyaM0XVJXdGLiBDToIfQDYKCXkSkosHBs2b2JzN7ycw2mdlXQvksM3vBzLaa2U/NrC6UZ8Ly1rB+5tCeQnkNGfVJLyIClV3Rnwbe7+4XARcDV5vZAuDrwAPufi7wJnBH2P4O4M1Q/kDYbtg1ZNMc0VM3IiIVDQ7u7n4sLKbDy4H3AytD+XLg+jC/KCwT1i80Mxu0GldojEaZEhEBKmyjN7Okma0H9gHPAK8Bh9y9mKTNwLQwPw3YARDWHyYaPHxYaZQpEZFIRUHv7m3ufjEwnWhA8PMH+oPNbImZrTWztfv37x/o4c7QkNEA4SIi0Menbtz9ELAGeA8wzsxSYdV0YGeY3wnMAAjrxwIHyxxrmbs3uXvTpEmT+ln97uWyaTXdiIhQ2VM3k8xsXJgfBXwQ2EwU+DeGzRYDq8L86rBMWP9rH4GhnhpCG32hoFGmRKS2pXrfhCnAcjNLEn0wrHD3fzezV4DHzeyrwF+AR8L2jwD/x8y2Am8Atw5BvXtV7Kr4WEueMaFLBBGRWtRr0Lv7BmBemfLXidrru5afAm4alNoNQGlXxQp6Eallsf1mrLoqFhGJxDfo1VWxiAgQ46BXV8UiIpEYB72abkREIMZB39F0o6AXkdoW26DXAOEiIpHYBn19XbiiV9ONiNS42AZ9ImGhvxs9dSMitS22QQ8aZUpEBGIe9BplSkQk7kGfVVfFIiKxDvpcNq2bsSJS8+Id9JkUx3QzVkRqXKyDXqNMiYjEPOhzGiBcRCTeQd+QTXGipY02jTIlIjUs3kGfUTcIIiKVjBk7w8zWmNkrZrbJzO4J5ePN7Bkz2xKmjaHczOwhM9tqZhvMbP5Qn0R3iiNLHT2tG7IiUrsquaLPA//D3d8JLADuMrN3AkuBZ939PODZsAxwDXBeeC0BHh70WldIo0yJiFQQ9O6+293/HOaPApuBacAiYHnYbDlwfZhfBPzII88D48xsyqDXvALqqlhEpI9t9GY2k2ig8BeAye6+O6zaA0wO89OAHSW7NYeyrsdaYmZrzWzt/v37+1jtyqirYhGRPgS9mTUATwCfc/cjpevc3YE+Pdri7svcvcndmyZNmtSXXStWDHp9O1ZEallFQW9maaKQ/7G7/ywU7y02yYTpvlC+E5hRsvv0UDbsGjLFcWN1M1ZEalclT90Y8Aiw2d3vL1m1Glgc5hcDq0rKPxWevlkAHC5p4hlWaroREYFUBdtcBvwD8LKZrQ9lXwLuA1aY2R3AduDmsO4p4FpgK3ACuH1Qa9wHo+uSmOmpGxGpbb0Gvbv/HrBuVi8ss70Ddw2wXoPCzNTfjYjUvFh/MxaiL00p6EWklsU+6KNRpnQzVkRqV/yDXqNMiUiNi33Qq6tiEal1sQ/6hkxKj1eKSE2LfdDnsimOKOhFpIbVQNCndTNWRGpa7IO+IZPiVGuB1rbCSFdFRGRE1ETQg7pBEJHaFfugz2nwERGpcTUT9HqWXkRqVeyDXl0Vi0iti33Qq+lGRGpd7INeA4SLSK2LfdDnwlM3+tKUiNSq+Ad9Nmqj1+OVIlKrKhlK8Idmts/MNpaUjTezZ8xsS5g2hnIzs4fMbKuZbTCz+UNZ+Upk0wmSCdO3Y0WkZlVyRf+vwNVdypYCz7r7ecCzYRngGuC88FoCPDw41ew/jTIlIrWu16B3998Cb3QpXgQsD/PLgetLyn/kkeeBcWY2ZbAq21+5rHqwFJHa1d82+snuvjvM7wEmh/lpwI6S7ZpD2YhqyKQ4qqduRKRGDfhmbBgM3Pu6n5ktMbO1ZrZ2//79A61Gj3LZlL4wJSI1q79Bv7fYJBOm+0L5TmBGyXbTQ9kZ3H2Zuze5e9OkSZP6WY3KRF0V64peRGpTf4N+NbA4zC8GVpWUfyo8fbMAOFzSxDNiNMqUiNSyVG8bmNljwJXARDNrBu4F7gNWmNkdwHbg5rD5U8C1wFbgBHD7ENS5zzRAuIjUsl6D3t0/3s2qhWW2deCugVZqsOWyuhkrIrUr9t+MhagbhJZ8gdP5tpGuiojIsKuJoNcoUyJSy2oi6Nv7u1HzjYjUoJoI+gaNMiUiNawmgr7YVbGCXkRqUW0EvZpuRKSG1UTQd4wypW4QRKT21EbQq+lGRGpYTQR9TjdjRaSG1UTQZ1IJRqWTrF6/ixdePzjS1RERGVY1EfRmxrduvojDJ1u5Zdnz3P7on9i8+8hIV0tEZFhUd9Dv2Qh/+j4c3dvrptdeMIXn/vlKll5zPuu2v8m1D/2Of/rpena8cWIYKioiMnIs6odsZDU1NfnatWv7vuNvvgFrvgqWgLddBnNugNnXQUPP/dsfPtHKw795jUf/398puPPJBW/js1edy4SGTD/PQERk+JnZOndv6nW7qg56gH2bYdOTsPFncHBLFPqzrugI/dHju911z+FTfPvZv7FibTPZVILPXPF2Pn3529uf0hEROZvVTtAXucPeTVHob/oZvPE6WBLefiXM/Sic/2EY1Vh21637jvGtX/6VX2zcw4T6Oj5zxdt518zxzJ6SY3SdQl9Ezk61F/Sl3GHPhugqf9OTcGg7JNIw8b9C49tg3NvOnGYaWL/jEF//xav8MTyZkzCYNbGeOVPHMmfqGOZMHcs7p45hfH3d4NVVRKSfRjTozexq4NtAEviBu9/X0/aDHvSl3GHXX2Dz6qiZ583tUfC3drkJO3pCe/Afy76VXa05tp3Msvlolg1vpPjr0QwHPcdJskwZm2XO1DG8c+pYpo3LMnZUHeNGp6NXmM+mk5Bvgdbj0HIcCm0wahxkxoDZ0JyriNSUEQt6M0sCfwM+CDQDLwIfd/dXuttnSIO+HHc4cTCE/raO8C9OD++EttNld80nshxOjOVAIcfufD3mzig7zWhOM4rT1Nup9vm0nTnQSRtJTqXGcCo9lpb0WFrrxpHPjKOQHYePGo+NHk8yM4pUOksqlSZVlyWdriNVlyFdlyGdzpBIpSFZB8l0eNVFf7F0Wk7pA0Uk5ioN+qFogL4U2Orur4eKPA4sAroN+mFnBvUTo9f0S85c7w4tx+D4geh1omOaOn6ACScOMuH4Ac47foB8AVqSE2ixLKcsyzGy7PU6jhUyHG1Lc7itjjfzaU60OJn8EUblj9DQeoSG00cYy3EabRvj7CjjOM5oK//h0l+tpMiTJE+KAgncDMcokACMAtGyWyKahnVtliJvKQqdpmkKlqKtfZqCRALDMOv6SkCYTxi4JSkk6vBkOkwzeKKOQrIOknUUkpn2Dygzw4BEwkMtIUEBM4hqGb19hmHJFCTTWCINqWhqqejDLpGsw5JpSKZIJtMkk0YikSKRSJJMJUkkkiQSCZLJJMlEikQqScLACgXM8+Bt4AWskMc8lBXasEJb9JuzBJZMYck6SCSj+idS0Qdup+VU9IBA+J2IjIShCPppwI6S5Wbg3UPwc4aOGWRy0Wv8rG43SwB14dVX7s7pfIGTLW2caG1jV0ueUydO0HLsIK2nT9Dacpp8awttraXTVgr507TlWyjkWyDfgnmeRCFPwltJhmnC8yQKrSS9o9woRMEF0dQLQAFzoikFcKJ9PU8yvBKeJ1XIU+fHSXkrSdpIeSspz2M44MUTal82vL0YnAQF6miljjx15EnYyN8XGgkF7/hwLZCgYGGK4SRwiD6Qix/A7dsalEyjD+czPzSij8Ez54us01/v5d+D8FHavkWxDmWPa8VJ+WOV1LbTdqVT66ZFofP5dT4Xb5/2cr5hTdffaKLkzBJl6u5nzHQ+asGMru9Rx7/+RPt88fxLf4sdyx2/mx3z/5lLPnJn2d/DYBmxR0rMbAmwBOCcc84ZqWqMGDMjm06STSfpeBYoB0weuUoNIXenreC0Fgq05Vtpaz2Ft7bQlj9NW+tJPHyIeaEQhZtD9FEU/lu6Rcse/TdpKxTwtjwUWvB8PvrgK7TibdGLtlYo5KP5QhuFQgEPUwptFLyAl8xTaIvqaMnoo8kSFCwV5pMhmJMUSIa/iAokCnnM81ghT8LbsEIr5m0kSss8D+4k2j9YveND1wsdf1d5W/SXJB7KPSxH+5SuK+ocsB3zHeFZjBtKsvLMcCwfl6XxFR2q44P9zGgv/YmdVlqnWG//qyYqS7SvK3+0rvMd51Y++ktr5l1iPUwt0b5V8S/bLtXt/PNKflKn0PaSeTxcRIWz8gLtvyfr+KDsOP/OH3/1jdMZakMR9DuBGSXL00NZJ+6+DFgGURv9ENRDziJmRipppJIJSKdg1KiRrpJIzRiKLhBeBM4zs1lmVgfcCqwegp8jIiIVGPQrenfPm9lngaeJHq/8obtvGuyfIyIilRmSNnp3fwp4aiiOLSIifVPdvVeKiEivFPQiIjGnoBcRiTkFvYhIzCnoRURi7qzoptjM9gPb+7n7RODAIFbnbBC3c4rb+UD8zilu5wPxO6dy5/M2d+95SD3OkqAfCDNbW0nvbdUkbucUt/OB+J1T3M4H4ndOAzkfNd2IiMScgl5EJObiEPTLRroCQyBu5xS384H4nVPczgfid079Pp+qb6MXEZGexeGKXkREelDVQW9mV5vZX81sq5ktHen6DJSZbTOzl81svZkN4yC6g8fMfmhm+8xsY0nZeDN7xsy2hGljT8c4m3RzPl82s53hfVpvZteOZB37ysxmmNkaM3vFzDaZ2T2hvCrfpx7Op2rfJzPLmtmfzOylcE5fCeWzzOyFkHk/DV3B9368am266c8g5Gc7M9sGNLl71T77a2ZXAMeAH7n73FD2v4E33P2+8IHc6O5fHMl6Vqqb8/kycMzdvzmSdesvM5sCTHH3P5tZDlgHXA/cRhW+Tz2cz81U6ftkZgbUu/sxM0sDvwfuAT4P/MzdHzez7wEvufvDvR2vmq/o2wchd/cWoDgIuYwgd/8t8EaX4kXA8jC/nOg/YVXo5nyqmrvvdvc/h/mjwGaisZ6r8n3q4XyqlkeOhcV0eDnwfmBlKK/4ParmoC83CHlVv7lEb+QvzWxH+n+fAAAB5klEQVRdGFM3Lia7++4wv4d4DIz7WTPbEJp2qqKJoxwzmwnMA14gBu9Tl/OBKn6fzCxpZuuBfcAzwGvAIXfPh00qzrxqDvo4eq+7zweuAe4KzQax4u7FEaar2cPAfwEuBnYD3xrZ6vSPmTUATwCfc/cjpeuq8X0qcz5V/T65e5u7X0w07valwPn9PVY1B31Fg5BXE3ffGab7gCeJ3tw42BvaUYvtqftGuD4D4u57w3/CAvB9qvB9Cu2+TwA/dvefheKqfZ/KnU8c3icAdz8ErAHeA4wzs+LIgBVnXjUHfawGITez+nAjCTOrBz4EbOx5r6qxGlgc5hcDq0awLgNWDMPgBqrsfQo3+h4BNrv7/SWrqvJ96u58qvl9MrNJZjYuzI8ieuhkM1Hg3xg2q/g9qtqnbgDC41IP0jEI+ddGuEr9ZmZvJ7qKh2gs359U4/mY2WPAlUQ97e0F7gV+DqwAziHqpfRmd6+KG5zdnM+VRM0BDmwD/rGkbfusZ2bvBX4HvAwUQvGXiNq1q+596uF8Pk6Vvk9mdiHRzdYk0QX5Cnf/XyEnHgfGA38BPunup3s9XjUHvYiI9K6am25ERKQCCnoRkZhT0IuIxJyCXkQk5hT0IiIxp6AXEYk5Bb2ISMwp6EVEYu7/A4xt0gYwdywDAAAAAElFTkSuQmCC\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(tr_loss_hist, label = 'train')\n", "plt.plot(val_loss_hist, label = 'validation')\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "test acc: 94.63%\n" ] } ], "source": [ "yhat = cnn.predict(sess = sess, x_data = x_tst)\n", "print('test acc: {:.2%}'.format(np.mean(yhat == y_tst)))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }