{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 03 : Linear and Logistic Regression\n",
    "### Logistic Regression with tf.data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.12.0\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "%matplotlib inline\n",
    "\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load and Pre-process data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "(x_train, y_train), (x_tst, y_tst) = tf.keras.datasets.mnist.load_data()\n",
    "x_train = (x_train  / 255)\n",
    "x_train = x_train.reshape(-1, 784)\n",
    "x_tst = (x_tst / 255)\n",
    "x_tst = x_tst.reshape(-1, 784)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(55000, 784) (55000,)\n",
      "(5000, 784) (5000,)\n"
     ]
    }
   ],
   "source": [
    "tr_indices = np.random.choice(range(x_train.shape[0]), size = 55000, replace = False)\n",
    "\n",
    "x_tr = x_train[tr_indices]\n",
    "y_tr = y_train[tr_indices]\n",
    "\n",
    "x_val = np.delete(arr = x_train, obj = tr_indices, axis = 0)\n",
    "y_val = np.delete(arr = y_train, obj = tr_indices, axis = 0)\n",
    "\n",
    "print(x_tr.shape, y_tr.shape)\n",
    "print(x_val.shape, y_val.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the graph of Softmax Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hyper-par setting\n",
    "epochs = 30\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?, 784), (?,)), types: (tf.float64, tf.uint8)>\n",
      "<BatchDataset shapes: ((?, 784), (?,)), types: (tf.float64, tf.uint8)>\n"
     ]
    }
   ],
   "source": [
    "# for train\n",
    "tr_dataset = tf.data.Dataset.from_tensor_slices((x_tr, y_tr))\n",
    "tr_dataset = tr_dataset.shuffle(buffer_size = 10000)\n",
    "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n",
    "tr_iterator = tr_dataset.make_initializable_iterator()\n",
    "print(tr_dataset)\n",
    "\n",
    "# for validation\n",
    "val_dataset = tf.data.Dataset.from_tensor_slices((x_val,y_val))\n",
    "val_dataset = val_dataset.batch(batch_size = batch_size)\n",
    "val_iterator = val_dataset.make_initializable_iterator()\n",
    "print(val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define Iterator\n",
    "handle = tf.placeholder(dtype = tf.string)\n",
    "iterator = tf.data.Iterator.from_string_handle(string_handle = handle,\n",
    "                                               output_types = tr_iterator.output_types)\n",
    "X, Y = iterator.get_next()\n",
    "X = tf.cast(X, dtype = tf.float32)\n",
    "Y = tf.cast(Y, dtype = tf.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create weight and bias, initialized to 0 \n",
    "w = tf.get_variable(name = 'weights', shape = [784, 10], dtype = tf.float32,\n",
    "                    initializer = tf.contrib.layers.xavier_initializer())\n",
    "b = tf.get_variable(name = 'bias', shape = [10], dtype = tf.float32,\n",
    "                    initializer = tf.zeros_initializer())\n",
    "# construct model\n",
    "score = tf.matmul(X, w) + b\n",
    "\n",
    "# use the cross entropy as loss function\n",
    "ce_loss = tf.losses.sparse_softmax_cross_entropy(labels = Y, logits = score)\n",
    "ce_loss_summ = tf.summary.scalar(name = 'ce_loss', tensor = ce_loss) # for tensorboard\n",
    "\n",
    "# using gradient descent with learning rate of 0.01 to minimize loss\n",
    "opt = tf.train.GradientDescentOptimizer(learning_rate=.01)\n",
    "training_op = opt.minimize(ce_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_writer = tf.summary.FileWriter(logdir = '../graphs/lecture03/logreg_tf_data/train',\n",
    "                                     graph = tf.get_default_graph())\n",
    "val_writer = tf.summary.FileWriter(logdir = '../graphs/lecture03/logreg_tf_data/val',\n",
    "                                     graph = tf.get_default_graph())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   5, tr_loss : 0.416, val_loss : 0.432\n",
      "epoch :  10, tr_loss : 0.358, val_loss : 0.383\n",
      "epoch :  15, tr_loss : 0.334, val_loss : 0.362\n",
      "epoch :  20, tr_loss : 0.320, val_loss : 0.349\n",
      "epoch :  25, tr_loss : 0.311, val_loss : 0.341\n",
      "epoch :  30, tr_loss : 0.304, val_loss : 0.335\n"
     ]
    }
   ],
   "source": [
    "#epochs = 30\n",
    "#batch_size = 64\n",
    "#total_step = int(x_tr.shape[0] / batch_size)\n",
    "\n",
    "sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))\n",
    "sess = tf.Session(config = sess_config)\n",
    "sess.run(tf.global_variables_initializer())\n",
    "tr_handle, val_handle = sess.run(fetches = [tr_iterator.string_handle(), val_iterator.string_handle()])\n",
    "\n",
    "tr_loss_hist = []\n",
    "val_loss_hist = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "\n",
    "    avg_tr_loss = 0\n",
    "    avg_val_loss = 0\n",
    "    tr_step = 0\n",
    "    val_step = 0\n",
    "    \n",
    "    # for mini-batch training\n",
    "    sess.run([tr_iterator.initializer])\n",
    "    try:\n",
    "        while True:\n",
    "            _, tr_loss,tr_loss_summ = sess.run(fetches = [training_op, ce_loss, ce_loss_summ],\n",
    "                                               feed_dict = {handle : tr_handle})\n",
    "            avg_tr_loss += tr_loss\n",
    "            tr_step += 1\n",
    "            \n",
    "    except tf.errors.OutOfRangeError:\n",
    "        pass\n",
    "    \n",
    "    # for validation\n",
    "    sess.run([val_iterator.initializer])\n",
    "    try:\n",
    "        while True:\n",
    "            val_loss, val_loss_summ = sess.run(fetches = [ce_loss, ce_loss_summ],\n",
    "                                                          feed_dict = {handle : val_handle})\n",
    "            avg_val_loss += val_loss\n",
    "            val_step += 1\n",
    "            \n",
    "    except tf.errors.OutOfRangeError:\n",
    "        pass\n",
    "    \n",
    "    train_writer.add_summary(tr_loss_summ, global_step = epoch)\n",
    "    val_writer.add_summary(val_loss_summ, global_step = epoch)\n",
    "\n",
    "    avg_tr_loss /= tr_step\n",
    "    avg_val_loss /= val_step\n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "    val_loss_hist.append(avg_val_loss)\n",
    "    \n",
    "    if (epoch + 1) % 5 == 0:\n",
    "        print('epoch : {:3}, tr_loss : {:.3f}, val_loss : {:.3f}'.format(epoch + 1, avg_tr_loss, avg_val_loss))\n",
    "\n",
    "train_writer.close()\n",
    "val_writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x7f0ba8096048>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3XmcXFWd9/HPr9burl7SW9ZOSIBAmoRsNAEHwyKKETSIiIA6igvMqLg8M/p60JkR5NFnmGeQYRwRXsDg6AzIMLiAM1EUDSKokESSmAWSAIF09nR636qX8/xxb3dXd6q7K0l1V1f19/161etup6rPTcH33jr33HPNOYeIiOSWQKYrICIi6adwFxHJQQp3EZEcpHAXEclBCncRkRykcBcRyUEKdxGRHKRwFxHJQQp3EZEcFMrUH66oqHBz587N1J8XEclKGzZsOOKcqxytXMbCfe7cuaxfvz5Tf15EJCuZ2RuplFOzjIhIDlK4i4jkIIW7iEgOylibu4jklq6uLmpra+no6Mh0VXJCXl4eVVVVhMPhE3q/wl1E0qK2tpaioiLmzp2LmWW6OlnNOUddXR21tbXMmzfvhD5DzTIikhYdHR2Ul5cr2NPAzCgvLz+pX0EKdxFJGwV7+pzsv2XWhfu63Uf5h5+/jB4PKCIyvKwL9017Grj3mVdpau/OdFVEZAJpaGjgO9/5znG/7/LLL6ehoWEMapRZWRfuFYVRAOpaOzNcExGZSIYL9+7ukU8E16xZw5QpU8aqWhmTdeFeFosAUNcaz3BNRGQiueWWW3j11VdZunQp5557LitXrmT16tWcddZZALz3ve/lnHPOYeHChdx///3975s7dy5Hjhxh9+7dVFdXc+ONN7Jw4UIuu+wy2tvbM7U7Jy3rukKWF/rh3qJwF5movvbTrWzb15TWzzxrZjG3vmfhsNvvuOMOtmzZwsaNG3nmmWe44oor2LJlS39XwoceeoiysjLa29s599xzufrqqykvLx/0GTt37uQHP/gBDzzwAB/4wAf44Q9/yIc//OG07sd4GfXM3cweMrNDZrZlmO1mZt8ys11mttnMlqe/mgPKY2qWEZHRrVixYlAf8W9961ssWbKE888/nz179rBz585j3jNv3jyWLl0KwDnnnMPu3bvHq7ppl8qZ+78B3wa+P8z2dwHz/dd5wL3+dEz0Ncsc1Zm7yIQ10hn2eInFYv3zzzzzDE8//TS///3vKSgo4OKLL07ahzwajfbPB4PBrG6WGfXM3Tn3LHB0hCJXAt93nj8AU8xsRroqOFQkFKAoL6Q2dxEZpKioiObm5qTbGhsbKS0tpaCggJdffpk//OEP41y78ZeONvdZwJ6E5Vp/3f40fHZSFYVRhbuIDFJeXs4FF1zAokWLyM/PZ9q0af3bVq1axX333Ud1dTVnnnkm559/fgZrOj7G9YKqmd0E3AQwZ86cE/6csliEo2pzF5EhHnnkkaTro9EoP/vZz5Ju62tXr6ioYMuWgUuLX/ziF9Nev/GUjq6Qe4HZCctV/rpjOOfud87VOOdqKitHfUrUsMpjEfWWEREZQTrC/UngI36vmfOBRufcmDXJgNcdUs0yIiLDG7VZxsx+AFwMVJhZLXArEAZwzt0HrAEuB3YBbcDHxqqyfcpjUY62xuntdQQCGqhIRGSoUcPdOXf9KNsd8Jm01SgFZbEIPb2Opo4uphRExvNPi4hkhawbfgAG7lI9onZ3EZGksjPc/btUj6rdXUQkqewM9/7xZdQdUkROTGFhIQD79u3j/e9/f9IyF198MevXrx/xc+6++27a2tr6lyfKEMLZGe4aGVJE0mTmzJk8/vjjJ/z+oeE+UYYQzspwL41pZEgRGeyWW27hnnvu6V++7bbb+PrXv86ll17K8uXLOfvss3niiSeOed/u3btZtGgRAO3t7Vx33XVUV1dz1VVXDRpb5lOf+hQ1NTUsXLiQW2+9FfAGI9u3bx+XXHIJl1xyCTAwhDDAXXfdxaJFi1i0aBF33313/98bj6GFs27IX4BwMEBJflh3qYpMVD+7BQ78Kb2fOf1seNcdw26+9tpr+cIXvsBnPuN13nvsscd46qmn+NznPkdxcTFHjhzh/PPPZ/Xq1cM+n/Tee++loKCA7du3s3nzZpYvHxjk9hvf+AZlZWX09PRw6aWXsnnzZj73uc9x1113sXbtWioqKgZ91oYNG/jud7/LCy+8gHOO8847j4suuojS0tJxGVo4K8/cwWuaOaJmGRHxLVu2jEOHDrFv3z42bdpEaWkp06dP5ytf+QqLFy/m7W9/O3v37uXgwYPDfsazzz7bH7KLFy9m8eLF/dsee+wxli9fzrJly9i6dSvbtm0bsT7PPfccV111FbFYjMLCQt73vvfx29/+FhifoYWz8swdvIuqGvZXZIIa4Qx7LF1zzTU8/vjjHDhwgGuvvZaHH36Yw4cPs2HDBsLhMHPnzk061O9oXn/9de68807WrVtHaWkpN9xwwwl9Tp/xGFo4a8/cy2IRPbBDRAa59tprefTRR3n88ce55ppraGxsZOrUqYTDYdauXcsbb7wx4vsvvPDC/sHHtmzZwubNmwFoamoiFotRUlLCwYMHBw1CNtxQwytXruQnP/kJbW1ttLa28uMf/5iVK1emcW9HlsVn7lE2vFGf6WqIyASycOFCmpubmTVrFjNmzOBDH/oQ73nPezj77LOpqalhwYIFI77/U5/6FB/72Meorq6murqac845B4AlS5awbNkyFixYwOzZs7ngggv633PTTTexatUqZs6cydq1a/vXL1++nBtuuIEVK1YA8MlPfpJly5aN29OdzBs9YPzV1NS40fqPjuSbv3iFe9buYtc3Ltf4MiITwPbt26murs50NXJKsn9TM9vgnKsZ7b1Z2yxTHovQ66ChvSvTVRERmXCyNtzLCvuGIFC7u4jIUFkb7hUxDR4mMtFkqpk3F53sv2XWhnuZP76MBg8TmRjy8vKoq6tTwKeBc466ujry8vJO+DOyt7eMPzKkBg8TmRiqqqqora3l8OHDma5KTsjLy6OqquqE35+14V5aEAY0eJjIRBEOh5k3b16mqyG+rG2WCQUDlBaENXiYiEgSWRvu4N2lqjZ3EZFjpRTuZrbKzF4xs11mdkuS7aeY2a/MbLOZPWNmJ95QdBzKC6McUZu7iMgxRg13MwsC9wDvAs4Crjezs4YUuxP4vnNuMXA78Pfprmgy5TpzFxFJKpUz9xXALufca865OPAocOWQMmcBv/bn1ybZPia8wcMU7iIiQ6US7rOAPQnLtf66RJuA9/nzVwFFZlY+9IPM7CYzW29m69PRXaq8MEp9W5yeXvWrFRFJlK4Lql8ELjKzl4CLgL1Az9BCzrn7nXM1zrmaysrKk/6j5bEIzkF9m87eRUQSpdLPfS8wO2G5yl/Xzzm3D//M3cwKgaudc2P++O/yhLtUKwqjo5QWEZk8UjlzXwfMN7N5ZhYBrgOeTCxgZhVm1vdZXwYeSm81kyvTg7JFRJIaNdydc93AzcBTwHbgMefcVjO73cxW+8UuBl4xsx3ANOAbY1TfQfrO1vVEJhGRwVIafsA5twZYM2TdVxPmHwceT2/VRtd35q7ukCIig2X1HaqlBRHMNOyviMhQWR3uwYBRWhDRAztERIbI6nAHrzukLqiKiAyW9eGuu1RFRI6V9eFeURjVAztERIbI+nDXsL8iIsfK+nAvL4xQ39ZFd09vpqsiIjJhZH+4+33d69u6MlwTEZGJI/vDXXepiogcI+vDvf8uVXWHFBHpl/Xh3tcsc0QXVUVE+mV/uPvNMkfVHVJEpF/Wh/uU/DABQzcyiYgkyPpwDwRMd6mKiAyR9eEO/o1MuqAqItIvJ8K9PBZVV0gRkQQ5Ee5lhWqWERFJlBPhXqFhf0VEBkkp3M1slZm9Yma7zOyWJNvnmNlaM3vJzDab2eXpr+rwymJRGtu76NL4MiIiQArhbmZB4B7gXcBZwPVmdtaQYn+L9+DsZcB1wHfSXdGRlBf648uoaUZEBEjtzH0FsMs595pzLg48Clw5pIwDiv35EmBf+qo4ur67VNXuLiLiSSXcZwF7EpZr/XWJbgM+bGa1wBrgs8k+yMxuMrP1Zrb+8OHDJ1Dd5PoHD1O7u4gIkL4LqtcD/+acqwIuB/7dzI75bOfc/c65GudcTWVlZZr+9MDgYeoOKSLiSSXc9wKzE5ar/HWJPgE8BuCc+z2QB1Sko4KpqPDb3HXmLiLiSSXc1wHzzWyemUXwLpg+OaTMm8ClAGZWjRfu6Wt3GUVxXphgwPS4PRER36jh7pzrBm4GngK24/WK2Wpmt5vZar/YXwM3mtkm4AfADc45N1aVHmpgfBk1y4iIAIRSKeScW4N3oTRx3VcT5rcBF6S3asenXDcyiYj0y4k7VAGNDCkikiBnwr28MKo2dxERX+6EeyzCET2NSUQEyLFwb+7oJt6t8WVERHIm3Mv6xpdpU9OMiEjOhHt5zBuCQE0zIiK5FO7+mbsuqoqI5FK4xzQEgYhInxwKd39kSJ25i4jkTrgX54cIBYw6tbmLiOROuJt548uozV1EJIfCHby7VI+ozV1EJMfCPRbhqEaGFBHJsXAv1OBhIiKQY+FeFotwVM0yIiK5Fe4VhVGaO7vp7O7JdFVERDIqp8K970HZ6jEjIpNdToa77lIVkckupXA3s1Vm9oqZ7TKzW5Js/ycz2+i/dphZQ/qrOroKf3wZXVQVkclu1GeomlkQuAd4B1ALrDOzJ/3npgLgnPtfCeU/Cywbg7qOqswfgkDdIUVkskvlzH0FsMs595pzLg48Clw5QvnrgR+ko3LHq29kSDXLiMhkl0q4zwL2JCzX+uuOYWanAPOAX5981Y5fUTREOGhqlhGRSS/dF1SvAx53ziXti2hmN5nZejNbf/jw4TT/aW98mfJYVIOHicikl0q47wVmJyxX+euSuY4RmmScc/c752qcczWVlZWp1/I4aPAwEZHUwn0dMN/M5plZBC/AnxxayMwWAKXA79NbxeNTXhjR4GEiMumNGu7OuW7gZuApYDvwmHNuq5ndbmarE4peBzzqnHNjU9XUlOvMXURk9K6QAM65NcCaIeu+OmT5tvRV68SVF6rNXUQkp+5QBa/NvTXeQ0eXxpcRkckr+8J9xy/gvz4Gw7T+6C5VEZFsDPfWQ7D1R3BwS9LN/Xep6qKqiExi2Rfup7/dm+78ZdLNfXepHtEQBCIyiWVfuBdNh+lnw66nk24u7xv2V2fuIjKJZV+4A5z+DnjzD9DReMym8kKvWaZOZ+4iMollZ7jPfwe4HnjtmWM2xSJBIqGALqiKyKSWneFetQKiJbDzF8ds8saXiWhkSBGZ1LIz3IMhOO0S2PWrpF0iywt1l6qITG7ZGe7gNc0070/aJbIsFlWzjIhMatkb7iN0iayIRTQEgYhMatkb7iN0idSwvyIy2WVvuMNAl8j2wc/jLi+M0hbvoT2u8WVEZHLK7nAfpktk341M6usuIpNVdod7X5fIXYPb3fWgbBGZ7LI73IfpElnWNwSB2t1FZJLK7nCHpF0iK/whCI6ox4yITFLZH+5JukTqzF1EJruUwt3MVpnZK2a2y8xuGabMB8xsm5ltNbNH0lvNEfR1iUwI94JIkLywxpcRkclr1HA3syBwD/Au4CzgejM7a0iZ+cCXgQuccwuBL4xBXYd3+jtgzwv9XSK98WWiuqAqIpNWKmfuK4BdzrnXnHNx4FHgyiFlbgTucc7VAzjnDqW3mqOYf9kxXSLLCyPqCikik1Yq4T4L2JOwXOuvS3QGcIaZPW9mfzCzVemqYEqqzoW8wV0idZeqiExm6bqgGgLmAxcD1wMPmNmUoYXM7CYzW29m6w8fPpymP43XJfLUS2Dn0/1dItUsIyKTWSrhvheYnbBc5a9LVAs86Zzrcs69DuzAC/tBnHP3O+dqnHM1lZWVJ1rn5Oa/A1oOwIE/AWqWEZHJLZVwXwfMN7N5ZhYBrgOeHFLmJ3hn7ZhZBV4zzWtprOfo+rpE+k0zZbEIHV29tMW7x7UaIiITwajh7pzrBm4GngK2A48557aa2e1mttov9hRQZ2bbgLXAl5xzdWNV6aT6u0R6o0T2jy+jphkRmYRCqRRyzq0B1gxZ99WEeQf8lf/KnPmXwXN3Q3vDwPgyrXFmlxVktFoiIuMt++9QTXT6wCiR5TFvCAI9tENEJqPcCveELpFlsYEzdxGRySa3wj2hS2R5LAyozV1EJqfcCnfo7xJZcHQ7ZbEIf9rbMPp7RERyTO6Fe0KXyGtqqvj5lgPsOdqW2TqJiIyz3Av3hC6RN/zZXAJmfO93uzNdKxGRcZV74Q5el8g9LzAj0sm7F8/g0XV7aOroynStRETGTW6Ge3+XyLV84q2n0tLZzWPr9oz+PhGRHJGb4d7XJXLn05xdVcJ588r47vO76e7pzXTNRETGRW6Ge1+XyF3eKJE3rjyVvQ3t/GzLgUzXTERkXORmuMOgUSLftmAqp1bEePC3r+H8IYFFRHJZ7oZ7X5fIrT8iEDA+/tZ5bKptZP0b9Zmtl4jIOMjdcC+aDgvfB7/7F9i3kauXVzGlIMyDvx3fkYhFRDIhd8Md4IpvQqwSfnQj+XTy4fNO4RfbDrL7SGumayYiMqZyO9wLyuC998KRHfD0rXzkz04hHAjw3edfz3TNRETGVG6HO8Bpl8D5n4YX72fqgedYvXQmj62vpaFNA4qJSO7K/XAHuPRWqKyGJz7NX9SU0N7VwyMvvpnpWomIjJnJEe7hPLj6AWivZ/4Lf8PK08v53u92E+/WTU0ikpsmR7iDN5jY2/4OXv5v/mbmBg42dfLfm/dlulYiImMipXA3s1Vm9oqZ7TKzW5Jsv8HMDpvZRv/1yfRXNQ3ecjPMXcmZG/8vF1Y088BvX9dNTSKSk0YNdzMLAvcA7wLOAq43s7OSFP1P59xS//VgmuuZHoEAvPdezILcFfoOO/bX8/tX6zJdKxGRtEvlzH0FsMs595pzLg48Clw5ttUaQ1Nmw7vvoqJhE18s+B8efE7dIkUk96QS7rOAxPFya/11Q11tZpvN7HEzm53sg8zsJjNbb2brDx8+fALVTZOz3w+L3s9f9P4Xda/8jl2HmjNXFxGRMZCuC6o/BeY65xYDvwS+l6yQc+5+51yNc66msrIyTX/6BF1xJ65oOndHvsP3n92e2bqIiKRZKuG+F0g8E6/y1/VzztU55zr9xQeBc9JTvTGUX0rwffdxih2kevM/UNfSOfp7RESyRCrhvg6Yb2bzzCwCXAc8mVjAzGYkLK4GsuNUeN6FNC69iesDT/P8moczXRsRkbQZNdydc93AzcBTeKH9mHNuq5ndbmar/WKfM7OtZrYJ+Bxww1hVON1K3/1/eDN8Kiu33crRl5/LdHVERNLCMtXPu6amxq1fvz4jf3uoXds3EXn0GqZbHd2rvknB+TdkukoiIkmZ2QbnXM1o5SbPHaojOL16CfuuWcOLvdUU/PzzdP/0i9DTlelqiYicMIW77/xFp9N89Q94oPsKQhseoPf774VW3eAkItlJ4Z7gXUtmE3vPHXwh/ml63nwBd/9FsH9zpqslInLcFO5DfPC8Ocx/xyd4X8etNLV14v71Mtjyw0xXS0TkuCjck/j0xadx3gWXcmnz19hfcAY8/nF4+jbo7cl01UREUqJwT8LM+Mrl1Vy4/CwuOvTX7Jh9DTz3T/DItdDekOnqiYiMSuE+jEDA+IerF3NR9UzeuesqNi+9FV5bCw9eCvs2Zrp6IiIjUriPIBwM8O0PLufcU8q4et0CNr7t36GjEe6/CP7rY1D3aqarKCKSlMJ9FHnhIA/eUMPpU4u4/qkAG6/6FVz4JdjxFHz7XPjp56FJT3QSkYlF4Z6C4rww3/v4uUwtjvLRh19hy5mfhc9vhHM/CS89DN9aBr/4O2g7mumqiogACveUTS3K4z8+cR754SBXfed5vv1iI93vvAM+ux4WXgW/+xf456Xw7D9CvDXT1RWRSU7hfhxmlxWw5vMrWbVoBnf+Ygfvu/d37IyXw1X3wad+B3PfCr/+uhfyL9wP3fFMV1lEJikNHHaC/mfzfv7uiS20dHTzV5edwY0rTyUYMNjzIjz9NXjjOSiugnM+Cks/BCXJHl4lInJ8Uh04TOF+Eo60dPK3P97Cz7ceYNmcKdx5zRJOqywE5+DVX8Hz34LXfwMWgNPfDss/AmesgmA401UXkSylcB8nzjme3LSPrz6xlY6uHr70zjP5+AXzCATMK3D0dXjpP2Djw9C8H2JTYen1sPyjUH5aZisvIllH4T7ODjV18OUf/YlfvXyIFXPL+MdrFnNKeWygQE837Hoa/vh92PFzcD1wygXe2Xz1aogUZK7yIpI1FO4Z4Jzjh3/cy9d+upXuHseX3nkmHzp/DtFQcHDB5gOw8REv6Otfh2gJLLgc5l8Gp18KeSWZ2QERmfAU7hm0v7GdW374J36z4zBTi6J8cuU8PnjeKRRGQ4ML9vbCG897zTY7n4L2egiEYM5bvKA/451QcQaYZWZHRGTCSWu4m9kq4J+BIPCgc+6OYcpdDTwOnOucGzG5czncwTuLf35XHff+ZhfP76qjOC/ER94ylxsumEtFYfTYN/R0Q+06L+R3/AIObfXWl86F+e+EMy6DU94K4bxx3Q8RmVjSFu5mFgR2AO8AaoF1wPXOuW1DyhUB/wNEgJsne7gn2rSngft+8yo/33qASDDAtefO5saVpzK7bIR29oY9A0H/+m+guwPCBTDvQphzPlStgJnL1FYvMsmkM9zfAtzmnHunv/xlAOfc3w8pdzfwS+BLwBcV7sd69XAL9//mNX70Ui29Dt69eAZ/edFpVM8oHvmNXe3w+m+9sH91LRz1BywLhGD6Ypi9wn+dByVVY78jIpIx6Qz39wOrnHOf9Jf/HDjPOXdzQpnlwN845642s2dQuI/oQGMH//rcazzywpu0xnu45MxK/vwtp7ByfiXhYAo3DbfWQe2L3g1Te16EvRugu93bVjRzIOxn1cD0s3V2L5JDxi3czSwA/Bq4wTm3e6RwN7ObgJsA5syZc84bb7xxfHuVYxrbuvj3P+zmu8/vpq41TmlBmCsWz2D1klnUnFI60Fd+ND1dcHCLH/YvwJ510Pimt80CUFntNeHMXAozl8O0hWq7F8lS49YsY2YlwKtAi/+W6cBRYPVIZ++T+cx9qHh3L8/uOMwTm/bxy20H6OjqZWZJHu9ZOpMrl8yiekYRdrw9Zpr2w76XBr/ajnjbAiGY2hf4y2DGEqhcAJHYyJ8pIhmXznAP4V1QvRTYi3dB9YPOua3DlH8GNcucsNbObp7efpAnNu7j2R2H6e51zJ9ayJVLZ7J6ySzmlJ9gE4tz0FgL+zcODvz2er+AeT1zpp4FUxf402oonw+hSLp2T0ROUrq7Ql4O3I3XFfIh59w3zOx2YL1z7skhZZ9B4Z4WR1vjrPnTfp7cuI8Xd3tjxS+pKuGiM6dy0RkVLKmaQiiVNvrhOAcNb8D+zXD4ZTi0DQ5thyM7vTtowTvLLz/dC/rKaqg4HcpOg7JTIW+UC8Eikna6iSnH7G1o57837ePnWw+waU8DvQ6K8kJccFoFF55RyYVnVFBVmqYLp92dULfLC/r+1zao3w0k/PcSq/RCvi/sy08dWFbwi4wJhXsOa2zr4vlXj/DsjsM8u+Mw+xo7ADi1MsaF872gP//UcgoioVE+6TjF27zhEupehaOveV0yj77uzTftHVy2oByKZ3mvkr5plb9upvcKJbmZS0RGpHCfJJxzvHq4hd/s8ML+hdfr6OjqJRIMcHZVCctmT2HZnFKWzZnCjJK8478wm6p4m3dmf9QP/vrd0LjXC/3GWuhoOPY9sakDwd8X+iVVfvjPgqIZau8XGULhPkl1dPWwbvdRntt5hA1v1LN5byPx7l4AphVHWTa7lOWneIG/aGYJ+ZHgKJ+YJvFW70HijbV+4O+Fplp/us97dTYe+75BB4CZUDTd68tfNN0L/6Lp3kBrGn9HJgmFuwBeN8uXDzTx0psNvPRmPS/taeCNujYAQgGjekYxS2aXcNaMEhbMKGLB9KL0N+ekqrM54QCwzzsI9B8IRjgAhPIHh33RDCic6l0TiFVCrGJgGs4f//0SSSOFuwyrrqWTjXsavMDfU8/mPY00d3YD3gnw3PIY1TOKqJ5eTPWMYqpnFjNzLJt0jke81RsyufmA9/CTQdO++f3Q1Zb8/ZHChLCv9K4NxCq8aYE/jZX7y+Ve+Ymw3yI+hbukzDlHbX072/c3sX1/szc90NR/hg9QnBdiwYxiFkwvYv60Is6YWsgZ04oojU3QNvF4K7QehtYj/rRvfujyYWirg96u5J8TjA4Efn4p5E3xpvmlkD9lyLqE5UghBPT8eUk/hbuctJbObl450MS2/c28vL+J7fub2HGwhRb/LB+gsijKGdMKmT+1iDOmFXnz04ooyc+i58Q6B51NXsi31nnTtjrvjt7+dUegvcG76avvNdwBAbxhH6LF3vWAvGIv8PNKBl592/KnDGzrm8+f4o0Aql8MkoTCXcaEc479jR3sONjMzoMtvHKwmZ0Hm9l5qIW2eE9/ucqiKHPKCphdms+csgKqygqYXVrA7LJ8ZpTkE0x13JyJyjmv6Scx8DsS55u8A0ZHY5JXE8SbR/78YMQ/EEwZCP1oIUSLIFLkTfuXC72DRd9yYhn1Nso5CncZV729jr0N7ew81MyOgy28eqiFPfVt7Dnazv7GdnoT/jMLBYxZpfn9YV9VWkBVab7/KqCyMJr6oGnZqqfbD/8G7wDRP21Mss5f39kC8RbvwnO8ZfS/Ad5BItkBIJJwIEi2PHRdOB9Cefo1MQEo3GXC6OrpZV9DO3uOtvuB38abR9vYU99O7dE26lrjg8pHggFm9Yd9PrOmeKE/qzSf6cV5TC2OHvtc2smmt9cL+L6w72zxDhZ9wd/ZPPAaWmbQ9hboak397waj3oiioXzvJrS+0A/leevDBd4AdJHCgWm0b77In/rrQxEIhCHov/rmA6HByzqgDJJquGeoz5tMJuFggFPKY5xSnnzUybZ4N3vr26mtb6e2oZ3a+jZvvr6dX247yJGW+DHvKYsqDUaoAAAKH0lEQVRFmF6cx/SSPKYV5/nzUaaXeAeAacVRSvLDE6OHz1gIBPy2/DQM89DbM+QA0HzsQaCr3XsaWHcHdHUMzPcvt3tlWuu8g0Vni3dR+3gOHMMJRhIOEn5zVP/ykPlwwcDBZ7Rp38EpkJsnCgp3ybiCSIj507xeOMm0x3vY2+AF/sGmDg40dnKgqcOf72DTnoZjzv4BIqEAU4uiTC2KMq04z5tPmE4rjjK1KI8p+eHcbwYaSSA4cKE33Xp7vGsTfWEfb/amnS3QE/cuSvf4r94ur7mqN3G5yzuA9L2n74DT0eDdDxFv8dc3g+s9sToGQkN+iUQHLwfD3i+WYNg70IQS5ge9hvwCGW4+EPKG2C6Zld5/6yEU7jLh5UeCnD61iNOnJg9/gM7uHg41eaF/oLGDQ82dHGrypgebOth5qIXndx2hqaP7mPcGzPslUBaLUB6LUlYYoSIWobwwSlksQkVhhLJYlLJYmJL8CFMKwqk9MUu8A0dfu/1Ycs775dDV7v+KSPxl0Z58OuyvkCHLXe3egaj/1eVNuzsH5kfqOZXMFXfBuZ8Ym38Ln8JdckI0FGR2WcHIDx3HG57hUFMnB5s7ONTUyaHmDo62xqlrjVPX0snR1jjb9zVxpKUz6YGgT2E0xJSCMKUFXthPKYgwJT9MqT9fFotQGotQVhChNBamLBYhPxzM3WaiTDPzHieZqUdK9vYO+cXRnfDLpO/XSHxgvnTumFdJ4S6TSl44yJzygpQeehLv7qW+Lc6Rlk7qWuI0tHfR0Banoa2Len/a0Banvq2L2vp26tviNLZ3MVwfhWgoQGmBH/qxgQNDSf7gV3F+mCn5EUr8bbGIDgoTXiAAgeiEGulU4S4yjEgowLRi74Jtqnp6HU3tXvjXt8U52tpFfWuco21xb9oa97d1sXVfE43tXTS2d9HTO3yvtVDAKM4PU5QXoigvRHFe33x40HLftDAvREEkRCwaJBYJEYuGKIgEiYYCOkhMIgp3kTQKBoxSv0kmVc45Wjq7+4O+sb2LpoT5hrYumjq6aO7opqndm+4+0kazv665c/jmo0ShgFEQCfaHfWFemOK8UP+vhUG/HvISf0mEKIx6BwkdILKHwl0kw8zMPwsPU1V6/O/v6fUODs0dXTS1d9MW76Y13kNrZ/fAK97jre/018e7vYNFRze19e39B5TuEX5BgHfxuSASGnSQiEVCFPi/EvIjQQoiQfLDQfLC/nwkYT7sLeeHj31/JKgDRzqlFO5mtgr4Z7xnqD7onLtjyPa/BD4D9AAtwE3OuW1prquIJBEMWP9ZNidwcOjjnKMt3uMFfUcXjW0Dvx7a4j3+yztA9B1A2v3l+tY4tfXttHV2097VQ3tXDx1dx9c1Mdj3yyLhYNF3EMmPBIlFgv0HloLE+WiIgnCQgujgg0ee/8oPBwkHbdIdOEYNdzMLAvcA7wBqgXVm9uSQ8H7EOXefX341cBewagzqKyJjxMyI+c0vMzn5ce97ex0d3T20+weGDj/02+ID61rj3bQN+WXRd+DoW3+oucM7sPjb2uI9o/7CGCpg9Ad/Xn/wB8gLeeuiIX853HdgCPSXi4YCREIBIsEA0XCASDDoLYcCg7blhQNEQ175qP++TDZjpXLmvgLY5Zx7DcDMHgWuBPrD3TnXlFA+xqCnKIvIZBQImH92HaI8zZ8d7+6lve/gEB8I/fauHjriA78cvKl3MOmf96cdXd5ntHR2c6QlTueQbR3dPcP2fDoefQeBvgNINBTgC28/g/csmXnyHz6CVMJ9FrAnYbkWOG9oITP7DPBXQAR4W1pqJyKSRN+Zc0nB2A0t7Zyjs7uXeE8v8e5ebz7h1dnd400Ttnd29XhTf3unf5Do7EpY193LlDGsd5+0XVB1zt0D3GNmHwT+Fvjo0DJmdhNwE8CcOXPS9adFRNLOzPqbZrJRKvdQ7wVmJyxX+euG8yjw3mQbnHP3O+dqnHM1lZWVqddSRESOSyrhvg6Yb2bzzCwCXAc8mVjAzOYnLF4B7ExfFUVE5HiN2izjnOs2s5uBp/C6Qj7knNtqZrcD651zTwI3m9nbgS6gniRNMiIiMn5SanN3zq0B1gxZ99WE+c+nuV4iInISNG6piEgOUriLiOQghbuISA5SuIuI5CBz6bi/9kT+sNlh4I0TfHsFcCSN1ZkIcm2fcm1/IPf2Kdf2B3Jvn5LtzynOuVFvFMpYuJ8MM1vvnKvJdD3SKdf2Kdf2B3Jvn3JtfyD39ulk9kfNMiIiOUjhLiKSg7I13O/PdAXGQK7tU67tD+TePuXa/kDu7dMJ709WtrmLiMjIsvXMXURERpB14W5mq8zsFTPbZWa3ZLo+J8vMdpvZn8xso5mtz3R9ToSZPWRmh8xsS8K6MjP7pZnt9Kcn8XTP8TXM/txmZnv972mjmV2eyToeLzObbWZrzWybmW01s8/767Pyexphf7L2ezKzPDN70cw2+fv0NX/9PDN7wc+8//RH5x3987KpWcZ/nusOEp7nClyfzQ/jNrPdQI1zLmv75prZhXgPRv++c26Rv+7/AUedc3f4B+FS59z/zmQ9UzXM/twGtDjn7sxk3U6Umc0AZjjn/mhmRcAGvOcu3EAWfk8j7M8HyNLvybyHrcaccy1mFgaeAz6P94S7HznnHjWz+4BNzrl7R/u8bDtz73+eq3MujvdgkCszXKdJzzn3LHB0yOorge/5899jmAe4TETD7E9Wc87td8790Z9vBrbjPUIzK7+nEfYnazlPi78Y9l8O77Glj/vrU/6Osi3ckz3PNau/ULwv7xdmtsF/DGGumOac2+/PHwCmZbIyaXKzmW32m22yovkiGTObCywDXiAHvqch+wNZ/D2ZWdDMNgKHgF8CrwINzrluv0jKmZdt4Z6L3uqcWw68C/iM3ySQU5zX9pc97X/J3QucBiwF9gPfzGx1ToyZFQI/BL7gnGtK3JaN31OS/cnq78k51+OcW4r3ONMVwIIT/axsC/fjfZ7rhOec2+tPDwE/xvtCc8FBv120r330UIbrc1Kccwf9//F6gQfIwu/Jb8f9IfCwc+5H/uqs/Z6S7U8ufE8AzrkGYC3wFmCKmfU9WCnlzMu2cB/1ea7ZxMxi/sUgzCwGXAZsGfldWeNJBh63+FHgiQzW5aT1BaDvKrLse/Iv1v0rsN05d1fCpqz8nobbn2z+nsys0sym+PP5eB1HtuOF/Pv9Yil/R1nVWwbA79p0NwPPc/1Ghqt0wszsVLyzdfAeefhINu6Pmf0AuBhvBLuDwK3AT4DHgDl4o39+wDmXFRcph9mfi/F+6jtgN/AXCW3VE56ZvRX4LfAnoNdf/RW8duqs+55G2J/rydLvycwW410wDeKdeD/mnLvdz4lHgTLgJeDDzrnOUT8v28JdRERGl23NMiIikgKFu4hIDlK4i4jkIIW7iEgOUriLiOQghbuISA5SuIuI5CCFu4hIDvr/H/wM75ekh20AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(tr_loss_hist, label = 'train')\n",
    "plt.plot(val_loss_hist, label = 'validation')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc : 91.83%\n"
     ]
    }
   ],
   "source": [
    "yhat = np.argmax(sess.run(score, feed_dict = {X : x_tst}), axis = 1)\n",
    "print('acc : {:.2%}'.format(np.mean(yhat == y_tst)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}