{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 05 : Variable sharing and managing experiments\n",
    "### Applied example with tf.placeholder\n",
    "Ref : [Toward Best Practices of TensorFlow Code Patterns](https://wookayin.github.io/TensorFlowKR-2017-talk-bestpractice/ko/#1) by Jongwook Choi, Beomjun Shin  \n",
    "  \n",
    "- Using **low-level api**\n",
    "- Creating the **input pipeline** with `tf.placeholder`\n",
    "- Creating the model as **Class**\n",
    "- Training the model with **learning rate scheduling** by exponential decay learning rate\n",
    "- Saving the model and Restoring the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.12.0\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "import os, sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "%matplotlib inline\n",
    "\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load and Pre-process data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "(x_train, y_train), (x_tst, y_tst) = tf.keras.datasets.mnist.load_data()\n",
    "x_train = x_train  / 255\n",
    "x_train = x_train.reshape(-1, 784)\n",
    "x_tst = x_tst / 255\n",
    "x_tst = x_tst.reshape(-1, 784)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(55000, 784) (55000,)\n",
      "(5000, 784) (5000,)\n"
     ]
    }
   ],
   "source": [
    "tr_indices = np.random.choice(range(x_train.shape[0]), size = 55000, replace = False)\n",
    "\n",
    "x_tr = x_train[tr_indices]\n",
    "y_tr = y_train[tr_indices]\n",
    "\n",
    "x_val = np.delete(arr = x_train, obj = tr_indices, axis = 0)\n",
    "y_val = np.delete(arr = y_train, obj = tr_indices, axis = 0)\n",
    "\n",
    "print(x_tr.shape, y_tr.shape)\n",
    "print(x_val.shape, y_val.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define DNN Classifier with two hidden layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DNNClassifier:\n",
    "    def __init__(self, X, y, n_of_classes, hidden_dims = [100, 50], name = 'DNN'):\n",
    "        \n",
    "        with tf.variable_scope(name):\n",
    "            with tf.variable_scope('input_layer'):\n",
    "                self.X = X\n",
    "                self.y = y\n",
    "        \n",
    "            h = self.X\n",
    "        \n",
    "            for layer, h_dim in enumerate(hidden_dims):\n",
    "                with tf.variable_scope('hidden_layer_{}'.format(layer + 1)):\n",
    "                    h = tf.nn.tanh(self.__fully_connected(X = h, output_dim = h_dim))\n",
    "        \n",
    "            with tf.variable_scope('output_layer'):\n",
    "                score = self.__fully_connected(X = h, output_dim = n_of_classes)\n",
    "        \n",
    "            with tf.variable_scope('ce_loss'):\n",
    "                self.loss = self.__loss(score = score, y = self.y)\n",
    "                \n",
    "            with tf.variable_scope('prediction'):\n",
    "                self.__prediction = tf.argmax(input = score, axis = 1)\n",
    "        \n",
    "    def __fully_connected(self, X, output_dim):\n",
    "        w = tf.get_variable(name = 'weights',\n",
    "                            shape = [X.shape[1], output_dim],\n",
    "                            initializer = tf.random_normal_initializer())\n",
    "        b = tf.get_variable(name = 'biases',\n",
    "                            shape = [output_dim],\n",
    "                            initializer = tf.constant_initializer(0.0))\n",
    "        return tf.matmul(X, w) + b\n",
    "    \n",
    "    def __loss(self, score, y):\n",
    "        loss = tf.losses.sparse_softmax_cross_entropy(labels = y, logits = score)\n",
    "        return loss\n",
    "        \n",
    "    def predict(self, sess, X):\n",
    "        feed_predict = {self.X : X}\n",
    "        return sess.run(fetches = self.__prediction, feed_dict = feed_predict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a model of DNN Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "## create placeholders for x_data and y_data\n",
    "x_data = tf.placeholder(dtype = tf.float32, shape = [None, 784])\n",
    "y_data = tf.placeholder(dtype = tf.int32, shape = [None])\n",
    "\n",
    "dnn = DNNClassifier(X = x_data, y = y_data, n_of_classes = 10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create training op and train model\n",
    "Applying exponential decay learning rate to train DNN model  \n",
    "```python\n",
    "decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)\n",
    "\n",
    "```\n",
    "Ref : https://www.tensorflow.org/api_docs/python/tf/train/exponential_decay"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "859\n"
     ]
    }
   ],
   "source": [
    "# hyper-parameter\n",
    "epochs = 15\n",
    "batch_size = 64\n",
    "learning_rate = .005\n",
    "total_step = int(x_tr.shape[0] / batch_size)\n",
    "print(total_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Applying exponential decay learning rate to train dnn model\n",
    "global_step = tf.Variable(initial_value = 0 , trainable = False)\n",
    "exp_decayed_lr = tf.train.exponential_decay(learning_rate = learning_rate,\n",
    "                                            global_step = global_step,\n",
    "                                            decay_steps = total_step * 5,\n",
    "                                            decay_rate = .9,\n",
    "                                            staircase = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create training op\n",
    "opt = tf.train.AdamOptimizer(learning_rate = exp_decayed_lr)\n",
    "\n",
    "# equal to 'var_list = None'\n",
    "training_op = opt.minimize(loss = dnn.loss,\n",
    "                           var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),\n",
    "                           global_step = global_step) \n",
    "\n",
    "# create summary op for tensorboard\n",
    "loss_summ = tf.summary.scalar(name = 'loss', tensor = dnn.loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_writer = tf.summary.FileWriter(logdir = '../graphs/lecture05/applied_example_wp/train',\n",
    "                                     graph = tf.get_default_graph())\n",
    "val_writer = tf.summary.FileWriter(logdir = '../graphs/lecture05/applied_example_wp/val',\n",
    "                                     graph = tf.get_default_graph())\n",
    "saver = tf.train.Saver()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   5, tr_loss : 0.25, val_loss : 0.30\n",
      "epoch :  10, tr_loss : 0.16, val_loss : 0.24\n",
      "epoch :  15, tr_loss : 0.12, val_loss : 0.20\n"
     ]
    }
   ],
   "source": [
    "sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))\n",
    "sess = tf.Session(config = sess_config)\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "tr_loss_hist = []\n",
    "val_loss_hist = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    avg_tr_loss = 0\n",
    "    avg_val_loss = 0\n",
    "    \n",
    "    for step in range(total_step):\n",
    "        \n",
    "        batch_indices = np.random.choice(range(x_tr.shape[0]), size = batch_size, replace = False)\n",
    "        val_indices = np.random.choice(range(x_val.shape[0]), size = batch_size, replace = False)\n",
    "        \n",
    "        batch_xs = x_tr[batch_indices] \n",
    "        batch_ys = y_tr[batch_indices]\n",
    "        val_xs = x_val[val_indices]\n",
    "        val_ys = y_val[val_indices]\n",
    "        \n",
    "        _, tr_loss, tr_loss_summ = sess.run(fetches = [training_op, dnn.loss, loss_summ],\n",
    "                                   feed_dict = {x_data : batch_xs, y_data : batch_ys})\n",
    "\n",
    "        val_loss, val_loss_summ = sess.run(fetches = [dnn.loss, loss_summ],\n",
    "                                           feed_dict = {x_data : val_xs, y_data : val_ys})\n",
    "        avg_tr_loss += tr_loss / total_step\n",
    "        avg_val_loss += val_loss / total_step\n",
    "        \n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "    val_loss_hist.append(avg_val_loss)\n",
    "    \n",
    "    train_writer.add_summary(summary = tr_loss_summ, global_step = (epoch + 1))\n",
    "    val_writer.add_summary(summary = val_loss_summ, global_step = (epoch + 1))\n",
    "    \n",
    "    if (epoch + 1) % 5 == 0:\n",
    "        print('epoch : {:3}, tr_loss : {:.2f}, val_loss : {:.2f}'.format(epoch + 1, avg_tr_loss, avg_val_loss))\n",
    "        saver.save(sess = sess, save_path = '../graphs/lecture05/applied_example_wp/dnn', global_step = (epoch + 1))\n",
    "\n",
    "train_writer.close()\n",
    "val_writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x7fe49c7c84a8>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl8ZGWd7/HPr5ZUlkon6Sy9Jel0d3qDpptuIuK0IIhLowOIjAquiNojo4PO6Iw4My91HL2jd7xc5YoyqIA6CBdwAR0QXPCibNLN0nTTK/SS9JY9nT2p1HP/OJWkOp2NdCWVqvq+X6965dQ5p+r8uqG/58lznvMcc84hIiLpxZfsAkREJPEU7iIiaUjhLiKShhTuIiJpSOEuIpKGFO4iImlI4S4ikoYU7iIiaUjhLiKShgLJOnBJSYmrqqpK1uFFRFLS1q1bG51zpRPtl7Rwr6qqYsuWLck6vIhISjKzg5PZT90yIiJpSOEuIpKGFO4iImkoaX3uIpJe+vv7qauro6enJ9mlpIXs7GzKy8sJBoNT+rzCXUQSoq6ujvz8fKqqqjCzZJeT0pxzNDU1UVdXx5IlS6b0HeqWEZGE6Onpobi4WMGeAGZGcXHxaf0WpHAXkYRRsCfO6f5dpl64H98Bv/kC9JxIdiUiIrNW6oV7y0F4/FvQsDvZlYjILNLa2sp3vvOdV/25t73tbbS2tk5DRcmVeuFeutL72ahwF5FhY4V7JBIZ93MPPvgghYWF01VW0qTeaJnCxeDPUstdRE5yww038PLLL3P22WcTDAbJzs6mqKiIXbt2sWfPHt7xjndQW1tLT08Pn/rUp9i8eTMwPBVKR0cHl1xyCa9//et54oknWLRoEffffz85OTlJ/pNNTeqFuz8AxdXQuCfZlYjIGP71lzt46Uhir4udsXAOX7z0zDG3f+1rX2P79u08//zz/OEPf+Dtb38727dvHxpKeNtttzF37ly6u7t5zWtew5VXXklxcfFJ37F3717uuusuvve97/Hud7+bn/70p7z//e9P6J9jpkzYLWNmt5lZvZltn2C/15hZxMz+KnHljaFkhVruIjKuc88996Qx4jfddBPr1q3jvPPOo7a2lr17957ymSVLlnD22WcDcM4553DgwIGZKjfhJtNyvwP4NvCjsXYwMz/wdeCRxJQ1gdKVsPMB6O+BYPaMHFJEJm+8FvZMycvLG1r+wx/+wG9/+1uefPJJcnNzufDCC0cdQx4KhYaW/X4/3d3dM1LrdJiw5e6cewxonmC3vwV+CtQnoqgJlawAF4WmfTNyOBGZ/fLz82lvbx91W1tbG0VFReTm5rJr1y6eeuqpGa5u5p12n7uZLQKuAC4CXnPaFU1G/IiZ+Wtm5JAiMrsVFxezceNG1qxZQ05ODvPmzRvatmnTJm655RZWr17NypUrOe+885JY6cxIxAXVbwKfc85FJ7qjysw2A5sBKisrp37E4mrAoEEXVUVk2E9+8pNR14dCIR566KFRtw32q5eUlLB9+/Clxc9+9rMJr28mJSLca4C7Y8FeArzNzCLOuV+M3NE5dytwK0BNTY2b8hGDOVC0WGPdRUTGcNrh7pwbuhxtZncAvxot2BOpubOPopKVmFruIiKjmsxQyLuAJ4GVZlZnZh8xs4+b2cenv7xT/fy5Ojb8229oCy/xLqhGB5JRhojIrDZhy905d/Vkv8w5d81pVTMJVcXe8KbD/koKB3qh5QAUL5vuw4qIpJSUm1tmWVkYgN0DC7wVulNVROQUKRfuc7KDzJsTYktXqbdCd6qKiJwi5cIdoLoszI4mg/B8tdxFZErCYa8X4MiRI/zVX40+a8qFF17Ili1bxv2eb37zm3R1dQ29ny1TCKdmuJeGebmhE1eqOWZE5PQsXLiQ++67b8qfHxnus2UK4dQM97IwHb0RuuYs81rubupD5kUkPdxwww3cfPPNQ++/9KUv8ZWvfIWLL76YDRs2cNZZZ3H//fef8rkDBw6wZo13p3t3dzdXXXUVq1ev5oorrjhpbpnrrruOmpoazjzzTL74xS8C3mRkR44c4aKLLuKiiy4CvCmEGxsbAbjxxhtZs2YNa9as4Zvf/ObQ8VavXs3HPvYxzjzzTN7ylrdMyxw2qTflL1Bdlg/A0WAl1b0noP0YzFmQ5KpEZMhDN8CxFxP7nfPPgku+Nubm97znPXz605/mE5/4BAD33HMPDz/8MNdffz1z5syhsbGR8847j8suu2zM55N+97vfJTc3l507d7Jt2zY2bNgwtO2rX/0qc+fOZWBggIsvvpht27Zx/fXXc+ONN/Loo49SUlJy0ndt3bqV22+/naeffhrnHK997Wt5wxveQFFR0YxMLZyyLXeAfW6Rt0J3qopkvPXr11NfX8+RI0d44YUXKCoqYv78+fzTP/0Ta9eu5U1vehOHDx/m+PHjY37HY489NhSya9euZe3atUPb7rnnHjZs2MD69evZsWMHL7300rj1/OlPf+KKK64gLy+PcDjMO9/5Tv74xz8CMzO1cEq23EvCWRTkBHmuu4xN4M0xs/TC5BYlIsPGaWFPp3e9613cd999HDt2jPe85z3ceeedNDQ0sHXrVoLBIFVVVaNO9TuR/fv3841vfINnnnmGoqIirrnmmil9z6CZmFo4JVvuZkZ1WZjnmkMQKlDLXUQAr2vm7rvv5r777uNd73oXbW1tlJWVEQwGefTRRzl48OC4n7/ggguGJh/bvn0727ZtA+DEiRPk5eVRUFDA8ePHT5qEbKyphs8//3x+8Ytf0NXVRWdnJz//+c85//zzE/inHV9KttzBGzHz253HYYFGzIiI58wzz6S9vZ1FixaxYMEC3ve+93HppZdy1llnUVNTw6pVq8b9/HXXXceHP/xhVq9ezerVqznnnHMAWLduHevXr2fVqlVUVFSwcePGoc9s3ryZTZs2sXDhQh599NGh9Rs2bOCaa67h3HPPBeCjH/0o69evn7GnO5lL0kiTmpoaN9H40fF877FX+OqDO9ld80tCB34Pn9V4d5Fk2rlzJ6tXr052GWlltL9TM9vqnKuZ6LMp2S0DUD3Pu6han70YOo5Dd0uSKxIRmT1SN9xLvXB/xZV7KzT9r4jIkJQN90WFOeQE/WzrjT1KSxdVRZIuWd286eh0/y5TNtx9PmNpaR5b2vLBH9JFVZEky87OpqmpSQGfAM45mpqayM7OnvJ3pOxoGfBuZtpyoAVKlmsCMZEkKy8vp66ujoaGhmSXkhays7MpLy+f8udTO9xLw9z//BEiS5YTOPpssssRyWjBYJAlS5ZMvKPMiJTtloHhaQiacqqg9RD0J/4uLxGRVJTS4b48NhzyoK8ccNC4N7kFiYjMEikd7ouL8wj4jO19870V6ncXEQFSPNyDfh+Li3PZ0j4XzKcRMyIiMROGu5ndZmb1ZrZ9jO3vM7NtZvaimT1hZusSX+bYqsvC7Grsg6IqjXUXEYmZTMv9DvBm1h3DfuANzrmzgH8Dbk1AXZNWXRbmYFMX0eIVuktVRCRmwnB3zj0GNI+z/Qnn3ODELk8BUx+YOQXVZWEGoo7WvCXQtA8GIjN5eBGRWSnRfe4fAR6acK8EWh575F6dvwKi/dByYCYPLyIyKyUs3M3sIrxw/9w4+2w2sy1mtiVRd7EtLc0DYFdkobdC/e4iIokJdzNbC3wfuNw51zTWfs65W51zNc65mtLS0kQcmtysAIsKc9jSGXs4rUbMiIicfribWSXwM+ADzrmkXNGsLguzvQnIX6Cx7iIiTGJuGTO7C7gQKDGzOuCLQBDAOXcL8AWgGPiOmQFEJvOUkESqLgvz9P4mXPUKTC13EZGJw905d/UE2z8KfDRhFU1BdVmYnv4oHfnLyD98LzgH3olGRCQjpfQdqoOWxyYQOxKshL52OHEkyRWJiCRXWoT74OyQe6MaMSMiAmkS7oW5WZSEs3i2KzYCR3eqikiGS+mHdcRbVhrm+eYByC5Uy11EMl5atNzB65rZ19CJK12plruIZLy0CvcTPRF6Cpap5S4iGS+twh3geGgxdDZA15hznYmIpL20CffBCcRedou8FbpTVUQyWNqE+7w5IcKhAC/0zPNW6E5VEclgaTNaxsxYVhZma6tBIEctdxHJaGnTcgeoLg2zt6ELSqrVcheRjJZe4V4Wpr69l/65yzViRkQyWtqFO0BDdhW01kJfV3ILEhFJkrQK98EJxA5QDjho2pvcgkREkiStwr1ibi5ZAR/b++Z7K3SnqohkqLQKd7/PWFqSxzPtRWB+9buLSMZKq3AHWFYWZndDL8xdAg27kl2OiEhSpF24V5eGqW3pYqB4hbplRCRjpV+4l4VxDlpyl0DzyzDQn+ySRERmXNqF+/J53oiZQ/5yiEageX+SKxIRmXlpF+5LSvLwGezsW+Ct0EVVEclAE4a7md1mZvVmtn2M7WZmN5nZPjPbZmYbEl/m5IUCfirn5rKlq8RboWkIRCQDTablfgewaZztlwDLY6/NwHdPv6zTU10WZkdjFOYs0gRiIpKRJgx359xjwHhPvrgc+JHzPAUUmtmCRBU4FcvKwuxv7CRaskItdxHJSInoc18E1Ma9r4utO4WZbTazLWa2paGhIQGHHl11aZj+AUd7eCk07oVodNqOJSIyG83oBVXn3K3OuRrnXE1paem0HWdwArHDgQro74QTh6ftWCIis1Eiwv0wUBH3vjy2LmkGw333wEJvhUbMiEiGSUS4PwB8MDZq5jygzTl3NAHfO2X52UHmz8nmua4yb4XuVBWRDDPhY/bM7C7gQqDEzOqALwJBAOfcLcCDwNuAfUAX8OHpKvbVqC4L83xzH+QUqeUuIhlnwnB3zl09wXYHfCJhFSVIdVmYe7fU4havxNRyF5EMk3Z3qA5aVhams2+AroJlarmLSMZJ23CvLvUuqh4LVkJXE3Q2JbkiEZGZk7bhPjiB2D4XG3Kv1ruIZJC0DffivCwKc4M83zPPW6E7VUUkg0x4QTVVmRnVpWGebXEQzNUcMyKSUdK25Q7eiJm9jV1QslwtdxHJKGkf7s2dffQWLlfLXUQySlqH+7LYNAT12YuhrRZ6O5JckYjIzEjrcF8eC/f9g5NUNu1NYjUiIjMnrcN9YUEOOUE/L/bO91boTlURyRBpO1oGwOczlpXl8cwJH/gCGusuIhkjrVvu4N2puqehB+Yu1YgZEckY6R/uZWGOtPUQmasRMyKSOTIi3AGacqqg+RUY6E9uQSIiMyADwj0fgEO+CohGvIAXEUlzaR/ui4tzCfiMHf2DI2bU7y4i6S/twz3o91FVksczHSXeCo2YEZEMkPbhDt6ImZ2NA1BQobHuIpIRMiPcy8IcbO4iWrxCLXcRyQgZE+4DUUdbeAk07oVoNNkliYhMq0mFu5ltMrPdZrbPzG4YZXulmT1qZs+Z2TYze1viS526weGQdf4K6O+CE3VJrkhEZHpNGO5m5gduBi4BzgCuNrMzRuz2L8A9zrn1wFXAdxJd6OlYVhrGDHZFFnor1O8uImluMi33c4F9zrlXnHN9wN3A5SP2ccCc2HIBcCRxJZ6+nCw/iwpz2NpV6q1o2JXcgkREptlkJg5bBNTGva8DXjtiny8Bj5jZ3wJ5wJsSUl0CVZeFeaG5F3KLdVFVRNJeoi6oXg3c4ZwrB94G/NjMTvluM9tsZlvMbEtDQ0OCDj051aVhXmnowJWsULeMiKS9yYT7YaAi7n15bF28jwD3ADjnngSygZKRX+Scu9U5V+OcqyktLZ1axVNUXRamNxKlI3+Z13J3bkaPLyIykyYT7s8Ay81siZll4V0wfWDEPoeAiwHMbDVeuM9s03wCgyNmjmZVQncLdDYmuSIRkekzYbg75yLAJ4GHgZ14o2J2mNmXzeyy2G6fAT5mZi8AdwHXODe7msaD4b53IDZiRv3uIpLGJvUkJufcg8CDI9Z9IW75JWBjYktLrMLcLErCIZ7vnsfbwZtArOr1yS5LRGRapPVj9kaqLstjS+sABPP04A4RSWsZMf3AoOqyMPsaOnElyzX1r4iktcwK99Iw7T0Regur1XIXkbSWWeEeeyrT8dBiOHEYetuTXJGIyPTIqHBfPs8bMfOKW+StUOtdRNJURoV7WX6I/FCAbb2Dj9xTuItIesqo0TJmxrKyMM+0OfAFNNZdRNJWRrXcwRsxs7uxB4qr1XIXkbSVkeHe0N5Lf9FytdxFJG1lXriXehdVG3OqoHk/RPqSW5CIyDTIvHCPzTFzwBaBG4Dml5NckYhI4mVcuFfMzSUr4GNH3+CIGXXNiEj6ybhw9/uMpSV5PNNeDJjGuotIWsq4cAeva+alpggUVqjlLiJpKWPDva6lm4HiFRoxIyJpKWPD3TloyV0CjfsgGk12SSIiCZWx4Q5Q66+ASDe0HUpyRSIiiZWR4b6kJA+fwa7IAm+F7lQVkTSTkeEeCvhZXJzHls5Sb4X63UUkzWRkuAMsKw2zrckHeaUaMSMiaSdjw726LMyBpk5cyQqNdReRtDOpcDezTWa228z2mdkNY+zzbjN7ycx2mNlPEltm4lWXhekfcJwIL/Va7s4luyQRkYSZcD53M/MDNwNvBuqAZ8zsAefcS3H7LAc+D2x0zrWYWdl0FZwogyNmjgQqKehphc4GCM/6skVEJmUyLfdzgX3OuVecc33A3cDlI/b5GHCzc64FwDlXn9gyE28w3PdEF3or1O8uImlkMuG+CKiNe18XWxdvBbDCzB43s6fMbFOiCpwu4VCABQXZPNcda61rxIyIpJFEPWYvACwHLgTKgcfM7CznXGv8Tma2GdgMUFlZmaBDT111WZitzX2Qla+x7iKSVibTcj8MVMS9L4+ti1cHPOCc63fO7Qf24IX9SZxztzrnapxzNaWlpVOtOWGWlYZ5ubETV6KnMolIeplMuD8DLDezJWaWBVwFPDBin1/gtdoxsxK8bppXEljntKguC9PVN0B3wTK13EUkrUwY7s65CPBJ4GFgJ3CPc26HmX3ZzC6L7fYw0GRmLwGPAv/gnGuarqITZfCi6rGsxdB+BHraklyRiEhiTKrP3Tn3IPDgiHVfiFt2wN/HXiljeSzc97mFLAVo3AvlNUmtSUQkETL2DlWA4nCIotwgL/bO81ZoOKSIpIlEjZZJWdVlYf7cOgD+LF1UFZG0kdEtd/DCfU9DN8zVRVURSR8ZH+7LSsO0dPXTW1QNR56DruZklyQictoyPtwHR8y8svjd0N0Ct22CtrokVyUicnoU7rFw3+pfBx/4GbQfhR+8Bep3JbkyEZGpy/hwX1iQQ26Wn331HVD1evjwgxCNwG1vhdo/J7s8EZEpyfhw9/nMm4agocNbMf8suPZhyJ0LP7wM9jyc3AJFRKYg48MdvK6ZffUdwyvmLoFrH4HSlXDX1fD8XckrTkRkChTueOF+tK2Hjt7I8MpwKVzzK6+r5hcfh8e/lbwCRUReJYU73nBIgJfjW+8AoXx4371w5jvhN1+Ah/8ZotEkVCgi8uoo3BkeMbN3ZLgDBEJw5Q/g3M3w5LfhF9fBQP8MVygi8upk/PQDAIuLcwn67eR+93g+H1zyP71nrP7+K9DVBO/+IWTlzWyhIiKTpJY7EPT7qCrOGzvcAczggn+AS78FL//OG0mju1lFZJZSuMdUl4V58XArTR294+94zjXw7h/DsRe9sfCttePvLyKSBAr3mKvPraS1q5/Lvv04Lx05Mf7Oq/8SPvBzaD/uBXz9zpkpUkRkkhTuMResKOXej7+Ogajjyu8+wYMvHh3/A1Ub4+5m3QSHnp6ZQkVEJkHhHmdteSEPfHIjqxfk8zd3PsuNj+wmGnVjf2D+GvjII5BbDD+6XHezisisoXAfoWxONndtPo9315Rz0+/38df/tfXkm5tGKqryAr5sVexu1p/MWK0iImNRuI8iFPDz9SvX8sVLz+D3u+p553ce52BT59gfyCuBD/0SlpzvjYPX3awikmQK9zGYGR/euIQfXXsux0/0cvnNj/P4vsaxPxDKh/feC2uu1N2sIpJ0kwp3M9tkZrvNbJ+Z3TDOfleamTOzmsSVmFwbq0t44JMbKcsP8cHb/sztj+/HuTH64QNZ8M7vw7l/Hbub9eO6m1VEkmLCcDczP3AzcAlwBnC1mZ0xyn75wKeAtBs2srg4j5/9zUbeuKqMf/3lS3zup9vojQyMvrPPB5d8HS7+Amz7v3DXVdA3TpeOiMg0mEzL/Vxgn3PuFedcH3A3cPko+/0b8HWgJ4H1zRrhUID/fP85XP/Gau7ZUsfVtz5FffsYf1QzOP8zcOlN8PLv4Vtnw2+/BM2vzGjNIpK5JhPui4D42zDrYuuGmNkGoMI5998JrG3W8fmMv3/LSm5+7wZ2Hm3nsv/zONvqWsf+wDkfgg/9CspfA4/fBDet94ZM7vg5RPpmrnARyTinfUHVzHzAjcBnJrHvZjPbYmZbGhoaTvfQSfP2tQu477rX4fcZ77rlSe5//vDYO1dthKt/An+3HS76F2h6Ge69Bm5c7V14bXp5xuoWkcxhY14cHNzB7HXAl5xzb429/zyAc+7fY+8LgJeBwVm35gPNwGXOuS1jfW9NTY3bsmXMzSmhqaOX6+58lj/vb+av37CUf3zrKvw+G/9D0QF4+VHYejvsfgjcAFSd781Zs/pSb4phEZExmNlW59yEg1YmE+4BYA9wMXAYeAZ4r3Nuxxj7/wH47HjBDukR7gB9kSj/+ssd3Pn0IS5cWcq3rlpPQU5wch9uPwbP/Rc8+0NoPQQ5c+Hs93pBX7J8WusWkdQ02XCfsFvGORcBPgk8DOwE7nHO7TCzL5vZZadfamrLCvj46hVn8ZV3rOFPexu54juPDz9seyL58+GCz8L1L8D7f+Y90u/pW+DbNXD722HbvdCfltenRWSaTdhyny7p0nKP9/QrTVx357P0D0S56er1XLSy7NV/SUc9PH8nbL0DWg5AThGse693cbZ0ZaJLFpEUk7BumemSjuEOUNvcxeYfb2X3sRN8btMqNl+wFLMJ+uFHE43Cgce8kN/5K4j2Q+VfeCF/xuUQzEl47SIy+ynck6irL8I/3LuN/37xKFesX8S/v/MssoP+qX9hRwO88BMv6JtfgewCWHe19+Du8hrwncZ3i0hKUbgnmXOOb/9+H//rN3tYUJDNB19XxdXnVlCYm3U6XwoH/uiF/EsPeK353GJY/lZYeQkseyOEwgn7M4jI7KNwnyUe39fIzY/u44mXm8gO+njnhnKu3VhFdVn+6X1xdyvs+y3s+TXsfQR62sCf5Q2rXHkJrNgEhRWJ+UOIyKyhcJ9ldh49wR2PH+Dnzx+mLxLlghWlXLuxiguWl+KbaGz8RAb64dBTXtDvfgiaYzdGzTsLVm7ywn7Bem/eGxFJaQr3Waqpo5efPH2IHz91kPr2XpaW5vHhjUu4csMicrMCiTlI417Y/SDs/jXUPgUuCuF5sOKtsOISWHohZOUm5lgiMqMU7rNcXyTKgy8e5Qd/2s+Lh9uYkx3g6tdW8sHXVbGoMIEjYbqavW6b3Q/Bvt9BXzsEsr2AX7HJe81ZkLjjici0UrinCOccWw+2cNvj+/n19mOYGZvOnM+1r69iQ2XR1IZRjiXSBwf/5LXo9zzk3RULsHC916JfuQnmr/VmtRSRWUnhnoLqWrr48ZMHuevPhzjRE2FdeQHXvn4Jl6xZQFYgwf3lzkH9S16Lfs+voW4L4CA8H+adCaWrvJumSldB6QrvZioRSTqFewrr7I3ws2fruP2JA7zS0Mm8OaHYUMpK5uadxlDK8XTUw56HYf9j0LDL67ePdA9vD8/3Qv6k0F/lPT9WRGaMwj0NRKOO/7e3gdv+tJ8/7m0kFPBxxfpFfHjjElbOP82hlBMfHNoOQcNuL+zjf/bFzZ2TWzwc+CUrh4M/f766d0SmgcI9zew93s7tTxzgZ8/W0dMfZWN1MW9ePY8Ni4tYvWAOQf8MDXN0Dk4cjoV9fPDv9MbaDwoVxIJ+5ckt/YJyhb7IaVC4p6mWzj7ueuYQd/35ELXNXrdJdtDHuvJCzllcxDmLi1hfWTR93Tdjcc7r2mnYBY17Tm7td8Y9mCUrfHLYl66CslUwp1zj8EUmQeGeAY60dvPsoRa2Hmzh2YMt7DhygkjU+++5tCSPDbGwP2dxEdWl4dO/WWqqOpugcTfU7xxu5Tfsho7jw/sE84ZDvywu+AsqFPoicRTuGai7b4AXD7ex9WAs8A+10NzpPas1PzvA+soizqn0wn5dRQH52ZN8qMh06Wo+OewHw7/j2PA+wby4C7lxLf2CSoW+ZCSFu+Cc40BT11DYP3eohd3H23HO6/ZeOS9/qGV/zuIiKufmJnZc/VR1NXtdOyNb+u1Hh/cJ5kLJCiis9O6+Dc+DcFnczzLIK4PADHdPiUwzhbuM6kRPP88fah1q2T93qJWO3ggAxXlZrK8sYn1lIWvLC1i7qJCC3CS37uN1t0DDnpNb+ieOeN07Pa2jfyanaETwz4O80lPX5c7V1MmSEhTuMikDUcfe+va41n0r+xs7h7YvLcljbXkB6yoKWVteyJkL55ze3PTTJdLrXdDtqIfOei/wOwZ/xi/XQ3/XqZ83vzdmf6jFnw3+APgC4AuOWA56J4Kh5di2UZdHfDaQ5c3e6Q95+wRCI5az4l5BjSySUyjcZcrauvrZdriVF2pbeaGujRdqW6lv7wUg4DNWLchnbXkhZ5cXsraigOVl+fiTdbF2Kno7Tg38juOxk0LsNdDnzbYZjXjz5g9ERl+ORqa3Vn8s8ANxoT/yJBAIeU/mCuZAIAeC2V63VSB7xPJY+4xYH8jR9YxZTOEuCXWsrYfna1vZVtfKC3WtbKtro73HC7bcLD9rFhawrqLAC/2KQsqLcmZH//10cw6iA7HQHzwZxF7x7wf6h08YA30w0OstR2I/B3pPXh7o8+YCGoi9Ir3Dy0Pbeoe39Xd7r0hP3HK3NyPoVPhDkD3H69bKmRv7WeR1X+UUDr8fuS0rrN82ppnCXaZVNOrY39TphX1tGy/UtbLjyAn6Il6YzM3L8rpzygtZV1HAGQsKKMsPJW84ZiZyzjtZ9HedGvr9PSOWB/fpGn7fe8K7ztHdAl0tw8v9nWMf0xc4NfRHnhTSVX1FAAALRUlEQVRyS7zrHnmxn9mF+k3hVUhouJvZJuBbgB/4vnPuayO2/z3wUSACNADXOucOjvedCvf00xeJsvtYOy/UeV062+ra2FPvjc4ByAr4KC/MoXxuLhVFOVTMzaWiKJeKuTlUFOVSmBvMjNZ+qov0Dgd9d4s3uin+fXfzqzspmN+bxiKvxHudFP6jvM8uzOjfDhIW7mbmB/YAbwbqgGeAq51zL8XtcxHwtHOuy8yuAy50zr1nvO9VuGeGzt4ILx5uY+/xdmpbuqlr6aK2uZvali5au/pP2jccClA+SuhXzM2lvCiHvFCCHmYiyTF4Uuhs9O5a7mryfo76vhF620b/Hl8gFvhx4Z9T6J0kzBd7WezlG34x4v3IfUbbPnRxPDh8kXvoZ3D4uocvEHcdJHDyRfHBz/r8CTkpTTbcJ/Ov5Vxgn3PuldgX3w1cDgyFu3Pu0bj9nwLe/+rKlXSVFwpw3tJizltafMq2Ez391MWCvra5i7qWbmqbuzjY1Mmf9jbS3T9w0v7FeVmUFw22/L3wX1SYw4KCHBYUZjMn2TdlyfgCIW9Cufz5k9s/0hsL/AlOBi0HvaGwLup1RTkXW469iH/vvPdJYcOB/7pPwkWfn9ajTSbcFwG1ce/rgNeOs/9HgIdG22Bmm4HNAJWVlZMsUdLVnOwgZywMcsbCOadsc87R1NlHbXMXtbHQr4u1/HccbuORHcfoHzj5H2k4FGB+QTYLYq/5BTksLMhmfkE2CwtzmF+gE0BKCYRgzkLvlUjxJwBGnAhGnhiGLpYPXgyPW46OvEgev70vdiG9b/TPLjw7sX+mUST091wzez9QA7xhtO3OuVuBW8HrlknksSW9mBkl4RAl4RDrK099UMhA1HH8RA9HWrs52tbD0bbYz9Yejp7oYfexBho6ehnZ6zjyBLCgIMf7WZgztC7p0zLI9BrsjiG9L+JOJtwPAxVx78tj605iZm8C/hl4g3OuNzHliYzO7zMWFuawcJznzfYPRDl+oodjbT0caevhWFs3R1q990fbutl9rH3ME0BpfojivCyKw1kUh0OU5GVRkh+iOC9EcTiLknAWxXkhCnKCGgEks9Jkwv0ZYLmZLcEL9auA98bvYGbrgf8ENjnn6hNepcgUBP0+yotyKS/KHXOfvkiU+vaeWOt/+ATQ1NlHU0cv+xs72XKgheauvlNOAuDd1DU3L3YCCGfFTgjebxzxJwFvOTQ77+6VtDRhuDvnImb2SeBhvKGQtznndpjZl4EtzrkHgP8AwsC9saFsh5xzl01j3SIJkRWY+AQAXjdQS1cfTR1e6Dd29tHY3ktTZy9NHX00dvTR1NnLwaYuGjt66eobGPV7wqEAJbGgHz4BhCjJD1EaWx48UYRDAQ0NlSnTTUwi06CrL+KdCGK/ATR19NHQ0UtjRy+NHd6JobGjl6bOvqFpmUcKBXxDwV+SN3gSGD4xlMROAqX5XveQTgSZIZFDIUXkVcrNCpA7N0DF3PF/IwCIDERp7hwM/5ODv7G9l4aOXo609bDtcBvNnX0MRE9tkGX5fZTmeyeCsvwQpfkhSsMhyuYM/sz2toezCAXUNZQJFO4iSRbw+yibk03ZnOwJ943GuocaY91DDR29NLTH/Wzvpba5i2cPttA0xm8EhbnBk4K/ND9EWb4X/qVxJ4f87GBqTQgnJ1G4i6QQn88ojvXLQ/64+/YPRGnq6KO+vWco+Ovbe+OWe9h6qIX6E730RkafYCwU8JGb5Sc3K0BOlp+8LD85ce9zg35ys/zkZAVi+w1u9/YZWhccXg4F/GQFfGQFfDp5TCOFu0iaCvp9zI/dxDUe5xztvZFTTgDtPf109w3QNfSK0NU3QHffAPXtPUPLgz/7Bl79DJR+nxGKBX2W3zcU+oMngJB/8P3wtiz/yfsMbs8JDp9csoN+coLe8kk/Y8uhgC/tr1Eo3EUynJkxJzvInOwgy0rDU/6e/oFoXODHTgT9g+Hvve/sG6AvEh169UZi7wfi13kvb90AnX0RWrqiJ+3XG/cdUzmpmDEU9tmjnRRi60JBP9lB39D6weXsgJ/Q4PIo27KDPkKx9UG/JeVEonAXkYQI+n0U5PgoyJnZO3yjUUffQJSefu9kMvjbRPz77n7v/eAJpyf209sepbs/MrRfa3c/x9p6hk5Mvf0D9EQGTpnuYrJ8xtBJIDvgnQDe+9pKPnr+0gT/TZxM4S4iKc3nM7J9XngWTuNxBqKOnthJoicSpTt2AumNDNDTH41tGz7JeNuiw5/pjw6tLwmHprFSj8JdRGQS/D4jLxRImamn03vmHBGRDKVwFxFJQwp3EZE0pHAXEUlDCncRkTSkcBcRSUMKdxGRNKRwFxFJQ0l7WIeZNQAHp/jxEqAxgeVMt1SqN5VqhdSqN5VqhdSqN5VqhdOrd7FzrnSinZIW7qfDzLZM5kkks0Uq1ZtKtUJq1ZtKtUJq1ZtKtcLM1KtuGRGRNKRwFxFJQ6ka7rcmu4BXKZXqTaVaIbXqTaVaIbXqTaVaYQbqTck+dxERGV+qttxFRGQcKRfuZrbJzHab2T4zuyHZ9YzFzCrM7FEze8nMdpjZp5Jd02SYmd/MnjOzXyW7lvGYWaGZ3Wdmu8xsp5m9Ltk1jcfM/i72/8F2M7vLzMZ/sOkMM7PbzKzezLbHrZtrZr8xs72xn0XJrHHQGLX+R+z/hW1m9nMzm87ndrwqo9Ubt+0zZubMrCTRx02pcDczP3AzcAlwBnC1mZ2R3KrGFAE+45w7AzgP+MQsrjXep4CdyS5iEr4F/No5twpYxyyu2cwWAdcDNc65NYAfuCq5VZ3iDmDTiHU3AL9zzi0Hfhd7Pxvcwam1/gZY45xbC+wBPj/TRY3jDk6tFzOrAN4CHJqOg6ZUuAPnAvucc6845/qAu4HLk1zTqJxzR51zz8aW2/HCZ1FyqxqfmZUDbwe+n+xaxmNmBcAFwA8AnHN9zrnW5FY1oQCQY2YBIBc4kuR6TuKcewxoHrH6cuCHseUfAu+Y0aLGMFqtzrlHnHOR2NungPIZL2wMY/zdAvxv4B+BabnwmWrhvgiojXtfxywPTAAzqwLWA08nt5IJfRPvf7ZX/zj5mbUEaABuj3Uhfd/M8pJd1Ficc4eBb+C10I4Cbc65R5Jb1aTMc84djS0fA+Yls5hX4VrgoWQXMR4zuxw47Jx7YbqOkWrhnnLMLAz8FPi0c+5EsusZi5n9JVDvnNua7FomIQBsAL7rnFsPdDJ7ugxOEeurvhzvpLQQyDOz9ye3qlfHecPqZv3QOjP7Z7wu0TuTXctYzCwX+CfgC9N5nFQL98NARdz78ti6WcnMgnjBfqdz7mfJrmcCG4HLzOwAXnfXG83sv5Jb0pjqgDrn3OBvQvfhhf1s9SZgv3OuwTnXD/wM+Isk1zQZx81sAUDsZ32S6xmXmV0D/CXwPje7x3gvwzvRvxD791YOPGtm8xN5kFQL92eA5Wa2xMyy8C5KPZDkmkZlZobXJ7zTOXdjsuuZiHPu8865cudcFd7f6++dc7OydemcOwbUmtnK2KqLgZeSWNJEDgHnmVlu7P+Li5nFF4DjPAB8KLb8IeD+JNYyLjPbhNeleJlzrivZ9YzHOfeic67MOVcV+/dWB2yI/X+dMCkV7rELJp8EHsb7x3GPc25Hcqsa00bgA3gt4Odjr7clu6g08rfAnWa2DTgb+B9JrmdMsd8w7gOeBV7E+3c3q+6oNLO7gCeBlWZWZ2YfAb4GvNnM9uL99vG1ZNY4aIxavw3kA7+J/Vu7JalFxhmj3uk/7uz+7UVERKYipVruIiIyOQp3EZE0pHAXEUlDCncRkTSkcBcRSUMKdxGRNKRwFxFJQwp3EZE09P8BMKSBxwbUGioAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(tr_loss_hist, label = 'train')\n",
    "plt.plot(val_loss_hist, label = 'validation')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test acc: 94.62%\n"
     ]
    }
   ],
   "source": [
    "yhat = dnn.predict(sess = sess, X = x_tst)\n",
    "print('test acc: {:.2%}'.format(np.mean(yhat == y_tst)))\n",
    "sess.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Restore model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Example 1\n",
    "Restore my model at epoch 15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "x_data = tf.placeholder(dtype = tf.float32, shape = [None, 784])\n",
    "y_data = tf.placeholder(dtype = tf.int32, shape = [None])\n",
    "\n",
    "dnn_restore = DNNClassifier(X = x_data, y = y_data, n_of_classes = 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model_checkpoint_path: \"../graphs/lecture05/applied_example_wp/dnn-15\"\n",
      "all_model_checkpoint_paths: \"../graphs/lecture05/applied_example_wp/dnn-5\"\n",
      "all_model_checkpoint_paths: \"../graphs/lecture05/applied_example_wp/dnn-10\"\n",
      "all_model_checkpoint_paths: \"../graphs/lecture05/applied_example_wp/dnn-15\"\n",
      "\n"
     ]
    }
   ],
   "source": [
    "ckpt_list = tf.train.get_checkpoint_state(checkpoint_dir = '../graphs/lecture05/applied_example_wp/')\n",
    "print(ckpt_list) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ../graphs/lecture05/applied_example_wp/dnn-15\n"
     ]
    }
   ],
   "source": [
    "# restore my model at epoch 15\n",
    "sess = tf.Session()\n",
    "saver = tf.train.Saver()\n",
    "saver.restore(sess = sess, save_path = '../graphs/lecture05/applied_example_wp/dnn-15')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test acc: 94.62%\n"
     ]
    }
   ],
   "source": [
    "yhat = dnn_restore.predict(sess = sess, X = x_tst)\n",
    "print('test acc: {:.2%}'.format(np.mean(yhat == y_tst)))\n",
    "sess.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Example 2\n",
    "Restore my model at epoch 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "x_data = tf.placeholder(dtype = tf.float32, shape = [None, 784])\n",
    "y_data = tf.placeholder(dtype = tf.int32, shape = [None])\n",
    "\n",
    "dnn_restore = DNNClassifier(X = x_data, y = y_data, n_of_classes = 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model_checkpoint_path: \"../graphs/lecture05/applied_example_wp/dnn-15\"\n",
      "all_model_checkpoint_paths: \"../graphs/lecture05/applied_example_wp/dnn-5\"\n",
      "all_model_checkpoint_paths: \"../graphs/lecture05/applied_example_wp/dnn-10\"\n",
      "all_model_checkpoint_paths: \"../graphs/lecture05/applied_example_wp/dnn-15\"\n",
      "\n"
     ]
    }
   ],
   "source": [
    "ckpt_list = tf.train.get_checkpoint_state(checkpoint_dir = '../graphs/lecture05/applied_example_wp/')\n",
    "print(ckpt_list) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ../graphs/lecture05/applied_example_wp/dnn-10\n"
     ]
    }
   ],
   "source": [
    "# restore my model at epoch 10\n",
    "sess = tf.Session()\n",
    "saver = tf.train.Saver()\n",
    "saver.restore(sess = sess, save_path = '../graphs/lecture05/applied_example_wp/dnn-10')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test acc: 93.52%\n"
     ]
    }
   ],
   "source": [
    "yhat = dnn_restore.predict(sess = sess, X = x_tst)\n",
    "print('test acc: {:.2%}'.format(np.mean(yhat == y_tst)))\n",
    "sess.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}