{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 03 : Linear and Logistic Regression\n",
    "### Logistic Regression with tf.data\n",
    "Same content as [Lec03_Logistic Regression with tf.data.ipynb](https://nbviewer.jupyter.org/github/aisolab/CS20/blob/master/Lec03_Linear%20and%20Logistic%20Regression/Lec03_Logistic%20Regression%20with%20tf.data.ipynb), but written in a different style.\n",
    "\n",
    "* Creating the input pipeline with `tf.data`\n",
    "* Using `eager execution`  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.12.0\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "%matplotlib inline\n",
    "\n",
    "print(tf.__version__)\n",
    "tf.enable_eager_execution()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load and Pre-process data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load MNIST as (images, labels) pairs for the train and test splits\n",
    "(x_train, y_train), (x_tst, y_tst) = tf.keras.datasets.mnist.load_data()\n",
    "\n",
    "# divide pixel values by 255, flatten every image to a 784-element row, cast to float32\n",
    "x_train = (x_train / 255).reshape(-1, 784).astype(np.float32)\n",
    "x_tst = (x_tst / 255).reshape(-1, 784).astype(np.float32)\n",
    "\n",
    "# labels are used as sparse integer class ids (int32)\n",
    "y_train, y_tst = y_train.astype(np.int32), y_tst.astype(np.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(55000, 784) (55000,)\n",
      "(5000, 784) (5000,)\n"
     ]
    }
   ],
   "source": [
    "# draw 55000 distinct row indices at random; those rows form the training split\n",
    "tr_indices = np.random.choice(range(x_train.shape[0]), size = 55000, replace = False)\n",
    "x_tr, y_tr = x_train[tr_indices], y_train[tr_indices]\n",
    "\n",
    "# the rows that were not drawn (everything left after deleting tr_indices) form the validation split\n",
    "x_val, y_val = np.delete(x_train, tr_indices, axis = 0), np.delete(y_train, tr_indices, axis = 0)\n",
    "\n",
    "print(x_tr.shape, y_tr.shape)\n",
    "print(x_val.shape, y_val.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the graph of Softmax Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hyper-parameter settings: number of passes over the training data and mini-batch size\n",
    "epochs = 30\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# training pipeline: one (x, y) example per slice, shuffled through a 10000-element buffer, then batched\n",
    "tr_dataset = (tf.data.Dataset.from_tensor_slices((x_tr, y_tr))\n",
    "              .shuffle(buffer_size = 10000)\n",
    "              .batch(batch_size = batch_size))\n",
    "\n",
    "# validation pipeline: same slicing and batching, but no shuffling\n",
    "val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size = batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# weights: random-normal initialization with stddev sqrt(2 / (784 + 10)); bias: initialized to 0\n",
    "w = tf.Variable(\n",
    "    initial_value=tf.random_normal(shape=[784, 10],\n",
    "                                   stddev=np.sqrt(2. / (784 + 10)).astype(np.float32)),\n",
    "    name='weights')\n",
    "b = tf.Variable(initial_value=tf.zeros(shape=[10]), name='bias')\n",
    "\n",
    "# linear model: returns the unnormalized class scores x @ w + b\n",
    "# (x must have 784 columns to match the shape of w)\n",
    "def model(x):\n",
    "    return tf.matmul(x, w) + b\n",
    "\n",
    "# cross-entropy loss between the integer labels y and the scores model(x), which are passed as logits\n",
    "def loss_fn(model, x, y):\n",
    "    return tf.losses.sparse_softmax_cross_entropy(labels=y, logits=model(x))\n",
    "\n",
    "# plain gradient descent with a learning rate of 0.01 is used to minimize the loss\n",
    "opt = tf.train.GradientDescentOptimizer(learning_rate=.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create the summary file writer that the training loop uses to log losses for TensorBoard\n",
    "logdir = '../graphs/lecture03/logreg_tf_data_de/'\n",
    "summary_writer = tf.contrib.summary.create_file_writer(logdir=logdir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   5, tr_loss : 0.416, val_loss : 0.423\n",
      "epoch :  10, tr_loss : 0.359, val_loss : 0.374\n",
      "epoch :  15, tr_loss : 0.335, val_loss : 0.353\n",
      "epoch :  20, tr_loss : 0.321, val_loss : 0.340\n",
      "epoch :  25, tr_loss : 0.312, val_loss : 0.332\n",
      "epoch :  30, tr_loss : 0.305, val_loss : 0.325\n"
     ]
    }
   ],
   "source": [
    "# per-epoch mean losses: exactly one entry is appended per epoch\n",
    "tr_loss_hist = []\n",
    "val_loss_hist = []\n",
    "\n",
    "# step counter used by tf.contrib.summary.scalar when no explicit step is given;\n",
    "# it is advanced once per optimizer update so that summaries are not all written at step 0\n",
    "global_step = tf.train.get_or_create_global_step()\n",
    "\n",
    "for epoch in range(epochs):\n",
    "\n",
    "    avg_tr_loss = 0\n",
    "    avg_val_loss = 0\n",
    "    tr_step = 0\n",
    "    val_step = 0\n",
    "\n",
    "    with summary_writer.as_default(), tf.contrib.summary.always_record_summaries(): # for tensorboard\n",
    "        # for training: one gradient-descent update per mini-batch\n",
    "        for x_mb, y_mb in tr_dataset:\n",
    "            # record the forward pass so the loss can be differentiated w.r.t. w and b\n",
    "            with tf.GradientTape() as tape:\n",
    "                tr_loss = loss_fn(model=model, x=x_mb, y=y_mb)\n",
    "            tf.contrib.summary.scalar(name='tr_loss', tensor=tr_loss)\n",
    "            avg_tr_loss += tr_loss\n",
    "            tr_step += 1\n",
    "            grads = tape.gradient(target=tr_loss, sources=[w, b])\n",
    "            # passing global_step makes the optimizer increment it after applying the update\n",
    "            opt.apply_gradients(grads_and_vars=zip(grads, [w, b]), global_step=global_step)\n",
    "\n",
    "        # for validation: forward pass only, no parameter update\n",
    "        # (global_step does not advance here, so all validation summaries of an epoch share one step)\n",
    "        for x_mb, y_mb in val_dataset:\n",
    "            val_loss = loss_fn(model=model, x=x_mb, y=y_mb)\n",
    "            tf.contrib.summary.scalar(name='val_loss', tensor=val_loss)\n",
    "            avg_val_loss += val_loss\n",
    "            val_step += 1\n",
    "\n",
    "    # turn the accumulated sums into per-batch means and record them once per epoch\n",
    "    avg_tr_loss /= tr_step\n",
    "    avg_val_loss /= val_step\n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "    val_loss_hist.append(avg_val_loss)\n",
    "\n",
    "    if (epoch + 1) % 5 == 0:\n",
    "        print('epoch : {:3}, tr_loss : {:.3f}, val_loss : {:.3f}'.format(epoch + 1, avg_tr_loss, avg_val_loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x7f0eec166630>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3XmcXFWd9/HPr6qruqq39JqQdGdpQiD7RhuiyGZcQpAgm4A6iqPEURhwlHkGHbdhdMZ5PTwMOoMoMojjsIhxlOhEcQsDCsQkECALhEAS0glJd5be9+7z/HErne7qTnelqe5K3fq+X696ddW9p6t+JzS/e+qcc88x5xwiIuIvgVQHICIiyafkLiLiQ0ruIiI+pOQuIuJDSu4iIj6k5C4i4kNK7iIiPqTkLiLiQ0ruIiI+lJWqDy4tLXXTpk1L1ceLiKSlTZs2HXLOlQ1XLmXJfdq0aWzcuDFVHy8ikpbMbE8i5dQtIyLiQ0ruIiI+pOQuIuJDKetzFxF/6ezspLq6mra2tlSH4guRSISKigpCodCIfl/JXUSSorq6mvz8fKZNm4aZpTqctOac4/Dhw1RXV1NZWTmi91C3jIgkRVtbGyUlJUrsSWBmlJSUvKVvQUruIpI0SuzJ81b/LdOuW2bD7iM8taN2wPH8SIiPnzuNrKCuVyIiaZfcn9tzlH9bt7PfsWPbwC6pLGbB5MIURCUiqVZXV8dDDz3EZz7zmZP6vRUrVvDQQw9RWOiv3JF2yf1TF0znUxdM73fsmdcOc933n6W5oytFUYlIqtXV1fGd73xnQHLv6uoiK+vEqW7t2rWjHVpKpF1yH0w0HASgtaM7xZGISKrcdtttvPbaayxcuJBQKEQkEqGoqIiXX36ZHTt28IEPfIC9e/fS1tbGLbfcwqpVq4DjS6E0NTVx8cUX8853vpOnn36a8vJyHnvsMaLRaIprNjLDJnczux94P1DjnJs7yHkDvgWsAFqA651zzyU70KHkHEvunUruIqeCf/jFVrbtb0jqe86eVMBXL51zwvPf/OY32bJlC5s3b+aJJ57gkksuYcuWLb1TCe+//36Ki4tpbW3lbW97G1deeSUlJSX93uPVV1/l4Ycf5vvf/z4f/OAH+elPf8pHPvKRpNZjrCQy+vgAsHyI8xcDM2KPVcA9bz2skxMNecm9RS13EYlZsmRJvzni3/72t1mwYAFLly5l7969vPrqqwN+p7KykoULFwJw9tlns3v37rEKN+mGbbk75540s2lDFLkM+E/nnAOeNbNCM5vonHszSTEO61i3TJta7iKnhKFa2GMlNze39/kTTzzB7373O5555hlycnK48MILB51Dnp2d3fs8GAzS2to6JrGOhmTMGywH9vZ5XR07NmbUcheR/Px8GhsbBz1XX19PUVEROTk5vPzyyzz77LNjHN3YG9MBVTNbhdd1w5QpU5L2vseSuwZURTJXSUkJ5557LnPnziUajTJhwoTec8uXL+e73/0us2bN4qyzzmLp0qUpjHRsJCO57wMm93ldETs2gHPuXuBegKqqKpeEzwYgEDCyswIaUBXJcA899NCgx7Ozs/nVr3416Llj/eqlpaVs2bKl9/itt96a9PjGUjK6ZdYAHzXPUqB+LPvbj8kJB9VyFxGJSWQq5MPAhUCpmVUDXwVCAM657wJr8aZB7sSbCvnx0Qp2KNFQUH3uIiIxicyWuW6Y8w64MWkRjVA0HNRsGRGRGN+sshUNB2nR8gMiIoCPkntOKEsDqiIiMb5J7hENqIqI9PJNcs8JBdVyF5GE5eXlAbB//36uuuqqQctceOGFbNy4ccj3ueuuu2hpael9vWLFCurq6pIX6Aj5Jrl7fe5K7iJyciZNmsTq1atH/PvxyX3t2rWnxNrwvkrumi0jkrluu+027r777t7XX/va1/j617/OsmXLWLx4MfPmzeOxxx4b8Hu7d+9m7lxvwdvW
1lauvfZaZs2axeWXX95vbZlPf/rTVFVVMWfOHL761a8C3mJk+/fv56KLLuKiiy4CvCWEDx06BMCdd97J3LlzmTt3LnfddVfv582aNYsbbriBOXPm8N73vndU1rDxxXruoHnuIqeUX90GB15K7nueNg8u/uYJT19zzTV89rOf5cYbvZnZjz76KI8//jg333wzBQUFHDp0iKVLl7Jy5coT7k96zz33kJOTw/bt23nxxRdZvHhx77lvfOMbFBcX093dzbJly3jxxRe5+eabufPOO1m3bh2lpaX93mvTpk384Ac/YP369TjnOOecc7jgggsoKioak6WFfdNyzwl7fe7OJW1VAxFJI4sWLaKmpob9+/fzwgsvUFRUxGmnncYXv/hF5s+fz7vf/W727dvHwYMHT/geTz75ZG+SnT9/PvPnz+899+ijj7J48WIWLVrE1q1b2bZt25Dx/PGPf+Tyyy8nNzeXvLw8rrjiCp566ilgbJYW9k3LPRIK4hy0d/UQiS0kJiIpMkQLezRdffXVrF69mgMHDnDNNdfw4IMPUltby6ZNmwiFQkybNm3QpX6Hs2vXLu644w42bNhAUVER119//Yje55ixWFrYVy130MqQIpnsmmuu4ZFHHmH16tVcffXV1NfXM378eEKhEOvWrWPPnj1D/v7555/fu/jYli1bePHFFwFoaGggNzeXcePGcfDgwX6LkJ1oqeHzzjuPn//857S0tNDc3MzPfvYzzjvvvCTWdmi+abn3rune2U1RimMRkdSYM2cOjY2NlJeXM3HiRD784Q9z6aWXMm/ePKqqqpg5c+aQv//pT3+aj3/848yaNYtZs2Zx9tlnA7BgwQIWLVrEzJkzmTx5Mueee27v76xatYrly5czadIk1q1b13t88eLFXH/99SxZsgSAT37ykyxatGjMdneyVPVRV1VVueHmj56Mxzbv45ZHNvO7z13AGePzkva+IpKY7du3M2vWrFSH4SuD/Zua2SbnXNVwv+ubbhlt2CEicpxvkntO2Oth0l2qIiI+Su7RsFcVrQwpkjqaipw8b/Xf0j/JPeS13HWXqkhqRCIRDh8+rASfBM45Dh8+TCQSGfF7+Ge2TGwqpO5SFUmNiooKqqurqa2tTXUovhCJRKioqBjx7/smuffOc1fLXSQlQqEQlZWVqQ5DYnzTLRPRbBkRkV4JJXczW25mr5jZTjO7bZDzU83s92b2opk9YWYj/y4xQrpDVUTkuGGTu5kFgbuBi4HZwHVmNjuu2B3Afzrn5gO3A/+c7ECHEwoGyAoYLeqWERFJqOW+BNjpnHvdOdcBPAJcFldmNvCH2PN1g5wfE1FttSciAiSW3MuBvX1eV8eO9fUCcEXs+eVAvpmVxL+Rma0ys41mtnE0RtSjISV3ERFI3oDqrcAFZvY8cAGwDxiQZZ1z9zrnqpxzVWVlZUn66OOOrekuIpLpEpkKuQ+Y3Od1RexYL+fcfmItdzPLA650zo35DrER7cYkIgIk1nLfAMwws0ozCwPXAmv6FjCzUjM79l5fAO5PbpiJydE+qiIiQALJ3TnXBdwEPA5sBx51zm01s9vNbGWs2IXAK2a2A5gAfGOU4h1SNBzU2jIiIiR4h6pzbi2wNu7YV/o8Xw2sTm5oJy8ayuJIc2eqwxARSTnf3KEKx6ZCquUuIuKr5J4T0mwZERHwWXL3+tyV3EVEfJfcNVtGRMRvyT0UpLPb0dndk+pQRERSylfJXWu6i4h4fJXctaa7iIjHV8lda7qLiHh8ldyjIe2jKiICfkvu6nMXEQH8ltzV5y4iAvgsueeEvaVy1HIXkUznq+QeDXvV0cqQIpLpfJbcvZa77lIVkUznr+Su2TIiIoDPkrvuUBUR8fgquWdnBTDTbBkREV8ldzMjGgoquYtIxksouZvZcjN7xcx2mtltg5yfYmbrzOx5M3vRzFYkP9TERENBWtQtIyIZbtjkbmZB4G7gYmA2cJ2ZzY4r9iW8jbMXAdcC30l2oImK
hoO0qeUuIhkukZb7EmCnc+5151wH8AhwWVwZBxTEno8D9icvxJMTDWk3JhGRrATKlAN7+7yuBs6JK/M14Ddm9tdALvDupEQ3Ajlh7aMqIpKsAdXrgAeccxXACuBHZjbgvc1slZltNLONtbW1Sfro/iIaUBURSSi57wMm93ldETvW1yeARwGcc88AEaA0/o2cc/c656qcc1VlZWUji3gYarmLiCSW3DcAM8ys0szCeAOma+LKvAEsAzCzWXjJfXSa5sOIhoNaW0ZEMt6wyd051wXcBDwObMebFbPVzG43s5WxYp8HbjCzF4CHgeudc260gh5KNJRFW6c2yBaRzJbIgCrOubXA2rhjX+nzfBtwbnJDG5loOKCWu4hkPF/doQremu7qcxeRTOe75B4JBWnr7KGnJyW9QiIipwTfJfdjK0O2dan1LiKZy3fJXWu6i4j4MbmHtUm2iIj/kntIG3aIiPguueeo5S4i4r/krj53ERE/Jvdjs2XULSMiGcy3yV0tdxHJZL5L7jkhb0UFDaiKSCbzXXKPhL0qtWp9GRHJYL5L7jlhtdxFRHyX3DVbRkTEh8k9GDDCWQG13EUko/kuuYPXetdNTCKSyXyZ3HPCSu4iktl8mdyjoSAt6pYRkQzmz+QeDtKmlruIZLCEkruZLTezV8xsp5ndNsj5fzWzzbHHDjOrS36oiYuGgpotIyIZbdgNss0sCNwNvAeoBjaY2ZrYptgAOOf+pk/5vwYWjUKsCYuGgzS26SYmEclcibTclwA7nXOvO+c6gEeAy4Yofx3wcDKCGynNlhGRTJdIci8H9vZ5XR07NoCZTQUqgT+c4PwqM9toZhtra2tPNtaE5YSDmucuIhkt2QOq1wKrnXODZlbn3L3OuSrnXFVZWVmSP/q4aFh97iKS2RJJ7vuAyX1eV8SODeZaUtwlAxANZWk9dxHJaIkk9w3ADDOrNLMwXgJfE1/IzGYCRcAzyQ3x5EXDAVo6unDOpToUEZGUGDa5O+e6gJuAx4HtwKPOua1mdruZrexT9FrgEXcKZNSccBY9Djq6e1IdiohISgw7FRLAObcWWBt37Ctxr7+WvLDemkjo+CbZ2VnBFEcjIjL2fHmHao622hORDOfL5H5sTXdNhxSRTOXP5B4+3i0jIpKJ/Jnc1XIXkQyX0IDqKaVmO+zfPPB4dj7MvATM1OcuIhkv/ZL7q7+B335l8HOf/ANUnN1vtoyISCZKv+S++GMwa2X/Y/V74YeXwqEdUHF2b8u9tVMrQ4pIZkq/5B4t9B59FZSDBeDI616R3gFV3cQkIpnJHwOqWWEYN7k3ueeEvGtWS4da7iKSmfyR3AGKT+9N7pGwVy0tHiYimcqXyT0cDBAMmGbLiEjG8ldyb6uDliOYmbcbk1ruIpKh/JPcS6Z7P/sMqmoqpIhkKv8k9+LTvZ/Hkrta7iKSwfyT3AunAnZ8xoy22hORDOaf5B6KwLiK4zNmQkHNlhGRjOWf5A5QXKmWu4gIvkvup/fvc1dyF5EMlVByN7PlZvaKme00s9tOUOaDZrbNzLaa2UPJDTNBxadDy2ForfNmy6hbRkQy1LBry5hZELgbeA9QDWwwszXOuW19yswAvgCc65w7ambjRyvgIRUfnw6plruIZLJEWu5LgJ3Oudedcx3AI8BlcWVuAO52zh0FcM7VJDfMBPWZDun1uWttGRHJTIkk93Jgb5/X1bFjfZ0JnGlmfzKzZ81sebICPClF07yfR3YRCQdp69SqkCKSmZK15G8WMAO4EKgAnjSzec65ur6FzGwVsApgypQpSfroPsI5kD/Ja7nnZ9HR3UNXdw9ZQX+NG4uIDCeRrLcPmNzndUXsWF/VwBrnXKdzbhewAy/Z9+Ocu9c5V+WcqyorKxtpzEOLzZiJxlaG1KCqiGSiRJL7BmCGmVWaWRi4FlgTV+bneK12zKwUr5vm9STGmbjYXPdo2PtSokFVEclEwyZ351wXcBPwOLAdeNQ5t9XMbjezY/vd
PQ4cNrNtwDrgb51zh0cr6CGVTIfmGvJpBdRyF5HMlFCfu3NuLbA27thX+jx3wOdij9SKzZgp7fR6jnSXqohkIv+NNMaSe2FrNaCWu4hkJv8l96JKAApa3wDU5y4imcl/yT07D/ImkNus5C4imStZ89xPLcWnE23cA8A/rd3OPf/7Wu+pgMFn330m555RmqroRERGnf9a7tCb3FcumMSkwijRULD3sW1/A49s2Dv8e4iIpDHfttyt6UG+feWZEM7td+rGB5/juT1HUxSYiMjY8G3LHYAjuwacWjSlkH11rdQ0tI1xUCIiY8fnyX3gTbKLpxYB8Nwbar2LiH/5NLl70yEHS+5zJhUQDgZ47o26AedERPzCn8k9Mg5ySgdN7tlZQeaWF6jfXUR8zZ/JHfrtpxpv8ZQiXtxXT0eX1nsXEX/KzOQ+tYiOrh62vdkwxkGJiIwN/yb3kunQsA86WwecWjwlNqiqrhkR8Sn/JvdjM2aO7h5w6rRxESaOi2jGjIj4lj9vYoLjM2b+6yoIRY8ftwAs+zKLp5TzvGbMiIhP+bflftp8eNsNMHkJnDbv+KOtHv707d6bmQ7qZiYR8SH/ttyDIbjkjoHHn7oTfv8PLF3aCHj97hfPmzjGwYmIjC7/ttxPZO6VAMw89JvYzUzqdxcR/8m85F40FSYvJWvrauZOytedqiLiSwkldzNbbmavmNlOM7ttkPPXm1mtmW2OPT6Z/FCTaN5VUPsyy8uO8JJuZhIRHxo2uZtZELgbuBiYDVxnZrMHKfpj59zC2OO+JMeZXHOugEAWy7qepKOrh63761MdkYhIUiXScl8C7HTOve6c6wAeAS4b3bBGWW4JTH8X0978FUaPumZExHcSSe7lQN+ti6pjx+JdaWYvmtlqM5s82BuZ2Soz22hmG2tra0cQbhLNu5pgYzXvy9+jQVUR8Z1kDaj+ApjmnJsP/Bb44WCFnHP3OueqnHNVZWVlSfroETprBWRF+VDOep7XMgQi4jOJJPd9QN+WeEXsWC/n3GHnXHvs5X3A2ckJbxRl58HMFSxpeZKa+iYO1OtmJhHxj0SS+wZghplVmlkYuBZY07eAmfW9C2glsD15IY6ieVcT6azjnYGX1DUjIr4y7B2qzrkuM7sJeBwIAvc757aa2e3ARufcGuBmM1sJdAFHgOtHMebkmb4MFynkip6n+d6TF/HMa4f7na6aVsRlCwcbXhARObWZcy4lH1xVVeU2btyYks/u5xe30PHcI1xo99Fmkd7D7Z3ddHY7/njbRYzPjwzxBiIiY8fMNjnnqoYr59+1ZRI172rCmx7g6fe+AZUX9B7eV9fK+374Bj/4027+bvnMFAYoInLylNynvAMKKuDxL/Y7XA6sy5vEJc9+k89cOJ38SCg18YmIjICSeyAAH30Marb1P95cQ9n/fJ6/7HqUh/88l1XnT09NfCIiI6DkDlB6hveIt38zNzz/IH/x5Lu4/h2VhLMyb501EUlPylZDec/t9GSP4286vstjz+8dvryIyClCyX0oOcVkve8feVtgB3t+fy89PamZWSQicrKU3IdhCz/MoZKz+UTrAzz1wsupDkdEJCFK7sMJBCi86t/It1a6H/9yqqMREUmIknsCsibOYfu0j/Kutt/y8vpfpzocEZFh6Q7VBLU01VN3x9nkWTsd+RX9zvVEiii99h4CxVNTFJ2IZIpE71BVyz1BOXnjeHrR/+XZrhm8cDS73yNy8Hkavv9+aErxGvUiIjFquZ8E5xz1rZ3ET5q5/+GHuXHvrbiSM8hZ9WuIjEtNgCLie2q5jwIzozAnTHFu/8cNH/4QX8r+O0JHdtD1Xx+EztZUhyoiGU7JPQnGRUN8+COf4PNdnyFQvR73k49Bd2eqwxKRDKblB5Jk8ZQi/vye6/nyb5r4xo774eefhqpP9C8UCMLEBZCVnZogRSRjKLkn0arzTuf6167jzl3NfO6lH8NLPxlYaOIC+OB/QtG0MY9PRDKHknsSBQLG
nR9cwMV3Xc2O8BKunJ3X73xe+0HO2fH/CHzvfLji+3Dm+1IUqYj4nWbLjII/7TzEXz6wgfaungHnzils4Ed5/0740BY471a46Ited42ISAISnS2TUHI3s+XAt/D2UL3POffNE5S7ElgNvM05N2Tm9nNyB2hu76K1s7vfsZ01TXzqR5sozOpizRlrGLf9YW/3p/M+DxY3tl0yHQomjWHEIpIOkpbczSwI7ADeA1QDG4DrnHPb4srlA/8DhIGbMj25n8j2Nxv4i//4Mz3O8Yt3vE7501+G7vaBBYPZXtI/9xYIaQ9XEfEkcw/VJcBO59zrsTd+BLgMiNu6iH8E/gX425OMNaPMmljAT/7q7XzkvvUsf2oaD171e+bn1vUv5Lph0w/hiX+CFx6GFXfAjHenJmARSUuJJPdyoO9OFdXAOX0LmNliYLJz7n/MTMl9GJWluTwaS/DX/HgfHzpnClkB61dmzpnf4NLFH8XW3goPXgmzLoULvwjZ/QdpCedBTvEYRi8i6eAtz5YxswBwJ3B9AmVXAasApkyZ8lY/Oq2VF0Z59FNv56/+axMPrt/T71yPg46uHtbMGs8/f2QdZS99D568A7b/YuAbWQDmXOF14UyYPUbRi8ipLpE+97cDX3POvS/2+gsAzrl/jr0eB7wGNMV+5TTgCLByqH73TO1zT0RPj+MHT+/mX379MnnZWfzT5XNZXtEJu54C4v571b4MG38AHU0w8/1w/q0waVFK4haR0ZfMAdUsvAHVZcA+vAHVDznntp6g/BPArRpQfetePdjI5x59gZf21XPFonK+unIO46KhgQVbjsD678H6e6CtHirPh3GT+5cxg8oLve4dDdCKpK1kT4VcAdyFNxXyfufcN8zsdmCjc25NXNknUHJPms7uHv79Dzv593U76e5xWP+ueQoiIW68aDrXv6OScFcTbLgPNj8IXXEzcDpboOUwRItgwYfg7I9B2VljVxERSYqkJvfRoOR+cl6qrue32w9C3H+vF6rr+d8dtVSW5vL3K2axbNZ4LP4KANDTA7v+FzY9AC//Enq6YMo7YOL8gWULp3ot/MLJA8+JSEopuWeQda/U8PVfbuO12mbOm1HKly6ZTWVpbr8yZhAKxm6UaqqBzQ95Lfymg/3fzAHt9d7zSYth9kqYtdK7qUpEUk7JPcN0dvfwo2f2cNfvdtDQ1jVomfNmlPKp86dz7hklg7fujzn8GmxfA9vWwP7nvGN5EyAQN7kqnAcz3uMN5E5eomUURMaAknuGOtLcwX8/Vz1gXZvGti5++lw1tY3tzJ5YwKcuOJ0V8yYeb82fSN0b3hTMmvh71oDGA7DrSejugJxSOOtimPFeyM7vX84CMGEu5Ja8xdqJiJK7DNDe1c1jz+/n3qdeZ2dNE+WFURZMHrgl4NSSXK5cXM4Z4/MHeZc4bQ2w83deP/6O30BH44nLjp8DlefBtHfC1HN185XICCi5ywn19DjWvVLDA0/v5kB9W79zDth1qJnuHsfCyYVcdXYFl86fxLicQaZgxutqhwMvDdyFqrsdqjfC7j/CG89CV2wbwuAgm5bkT4DyKig/23tMXADhnJFVVMSHlNxlxGoa23js+f2s3lTNKwcbCWcFWDS5kKxg/3763HAW75o5nvfOOY3i3HBib97V4fXj7/mT1+rvx8HRPbBvE9THVrywYGx1zLgxgnAOTFzoXQAqzva6fbTDlWQAJXd5y5xzbN3fwE827mXr/vhEDG/Wt7GvrpVgwDinspiL501k2czx5Ef6D7wGzMjNPsmVLhoPeheB6o3QsH/g+daj3kWgucZ7HQx78/bjvw1YAIpPh/GzYPxsb4mGgnIG3DAgkiaU3GXUHUv+v9ryJr/acoDXa5tPWPb0slwuOLOMC84sY+npJURCSZhZ4xzUV3tJft8mbymGnv5r6NPd4c3+aexzgcguGDjoC5Bb5l0Ajl0Ixs+E/EmDXwh0cZAUUXKXMeWc49WaJp7eeYiunv5/U22d
3fx591HWv36Y9q4ewlkBzqksZuK4gcsgTCqM8o7ppSycXEg4a5iZPCej9SjUbPdm/dS8fLzfv7cCQMM+r0zTgaHfKxj2bvQqroSiSm8/3HHlA6eKWsDrUiqaBpGBA9ciI6HkLqects5u1u86wpM7avnTzkPUt/YfeHUODja24RxEQgHeNq2Yt08v4awJ+QTiWsrRcJBZEwsGX2vnrWo54n0LqNkGzYcGnu9ogqO7vceR3UPPEDomWuQl+cKpEBpkgDhaBEVTj5cpnKKBZBmUkrukpfqWTtbvOszTrx3mmdcO88rBoRNnZWku88rHMb9iHGedlk84bt5+MGDMGJ+f2GyfkXDOuxg07GPAip3dXdBQffxCcHS3d99AV0f8m3gXkfhvE1kRBgwkZ2V7YwYFE71vBQXlXndS/DaNFoC88ZAfK5dTCoEkfhOSlFFyF1841NRO9dHWAcfrWjrYur+BF6vreKm6nv1xUzrjnV6ay8LJhSyIPYoGSfbFuWHyI6N0ERiOc96yEHV7vBlDdbsHmU2EtwBcw5vexaRh//EB5eEEQt5dxlmDzGrKzofc8d7FILfUex7OHVguKwL5px1/RAo19pACSu6SUWob23mttomeuP7+9u4etu1vYPPeOjbvraO2cZD9avs4rSDCjAl5TC/LY8aEPMbnR+LbzkRCQaaV5jBpXJRAIMXJrasDWo8MWFCOnk7vYtH4pncxaNzvzUDqibsHwTlvmejmWu/RVDOwzIlkRSBaPPi3hmihd6HIKfEe0WIIDjJjKpTrlY0UeuMS0cLBp7SGcrxvKFriIql7qIqc8srysynLH3ye+0VnjQe8Qd/99W28VF1HS0f/WTXH+vt3Hmzi1ZomHt24d0CZeNlZASpLc6kszaWiKEowEN8lBFOKczhjfD5njM8bnfGBrLDXih5M4Qh2O3MO2uqgc5BvQp0t3pITjW96C841vgktRweW6+nyBrBbDsOR171uq/ZBvoWcLAt43yryT/O6m3KKGdBtZebNhoqMg0jsZ3a+d79EvEiB112VU+K9l88uHErukjHMjPLCKOWF0WHL9vQ49te3crR5YCu2sb2T3Yda2HWoiV2HmnnlQCN/eLkmvsed7h5Hd59vEuPzszljfB454YFFaMjQAAAIwklEQVT/25XkhplSksPUkhymFucypSRndC4GwzHzBndP9E800tVBuzvB9V/vCOego9m7mLTVQWvsZ/wdzgDtjccvKI0HvJvcDrw0sJzr9sp2NA08NySLfWsYpOLBLG+RvHCe112VnTf4eEggGLuwHLu4DHFhOW2eN9tqFCm5iwwiEDAqinKoKBr8/DsSyHHdPY7qoy3srPG+Dbx6sInXDzVR1xI3SwjYvLeOQ039u4yCARu0S6i8MEpFkfeYXJxDWX42wbjuoaAZ4wsiTC6KUpqXnfruo+AJLlShyOgsKNfd5X1baKv3kn38hYVYd1TLYWg+7P1sOTRwkxvwLjYdTccfTQehc+A4ED3dxz9zwKU+ziV3QvEnRlq7hCi5i4ySYMCYWpLL1JJcls2aMGz55vYu3jjSwp7DLew53ExD28AWbHN7N9VHW6k+2sKfdx2hsX3w5Z37CgcDlBdFmVQYITtrYCuyIJLV2601Pj9CWX72oHcUR0IBJuRHKMwJDb1k9KkgmOV1taRicbqeHu8i0FbvJfvBxjULJo16GEruIqeI3OwsZk0sYNbEgoR/p761k9rYvQF9dXY7DjS0Un20lX1HvZ/761tpaO1/MXA4Xq3ppKahfcAy0ScSzgowoSCb0woiFOWEB9yDEAhAUU6Y0rxsSvOzKcsLU5KXTdYg3x4Kc8KMP8HFJG0FArGumcT/O44GH/2LimSecdHQCfvmZ09KPLk452hq76KmsZ2ahnbaOgcOJjd3dHGwoZ2ahjYONLRxsKGNPYdbBpTr6unhSHMHR1sSnHUD5GVnMb4gmwn5EXKzB367yA4FKc09fsEozcv2vkHElTMzxkVDFOV4
/y5Zw+1X4GMJJXczWw58C2+D7Pucc9+MO/9XwI1AN9AErHLODbK7g4icisyM/EiI/EiI6WV5SXnPzm4vydc2tnOkuYPu+K8XDo62dHgXjMY2ahraOdjQxpv1Ay8KrR3d1Da103iCXcZOJD+S5SX5Qb415ISzKIh65wsi3sVgsCUvIqEgRblhSnLDFMce+ZEs4kdEAuZ9+8oJB0+Jbqthk7uZBYG7gfcA1cAGM1sTl7wfcs59N1Z+JXAnsHwU4hWRNBEKBphQEGFCwcA1hEaqrbObw80dHGpsH3RMoqvH0dDaydHmDupaO6lr6aS+tZOeuAuLc9DS0UV9qzfzqb7VK9fVM7BrqrP75O4FMoO8cJaX6LODBAdJ9Dcvm8GlC0a33z2RlvsSYKdz7nUAM3sEuAzoTe7Oub6TWHMZdqhYROTkHZstlMh01mTp6OqhrqWDw80dHIk9mgYZyO7qcbS0d9Hc3kVj7GdzezdukHQ4FtNcE0nu5cDePq+rgXPiC5nZjcDngDDwrsHeyMxWAasApkwZwQ0WIiJjLJwVYHxBhPFJ/AYyFpI22uCcu9s5Nx34O+BLJyhzr3OuyjlXVVZWlqyPFhGROIkk933A5D6vK2LHTuQR4ANvJSgREXlrEknuG4AZZlZpZmHgWmBN3wJmNqPPy0uAV5MXooiInKxh+9ydc11mdhPwON5UyPudc1vN7HZgo3NuDXCTmb0b6ASOAh8bzaBFRGRoCc1zd86tBdbGHftKn+e3JDkuERF5CzL39i0RER9TchcR8SEldxERH0rZNntmVgvsGeGvlwKDbEuftvxUHz/VBVSfU5mf6gKJ12eqc27YG4VSltzfCjPbmMgegunCT/XxU11A9TmV+akukPz6qFtGRMSHlNxFRHwoXZP7vakOIMn8VB8/1QVUn1OZn+oCSa5PWva5i4jI0NK15S4iIkNIu+RuZsvN7BUz22lmt6U6npNlZvebWY2ZbelzrNjMfmtmr8Z+FqUyxkSZ2WQzW2dm28xsq5ndEjuervWJmNmfzeyFWH3+IXa80szWx/7mfhxbQC8tmFnQzJ43s1/GXqdzXXab2UtmttnMNsaOpevfWqGZrTazl81su5m9Pdl1Savk3mfLv4uB2cB1ZjY7tVGdtAcYuAXhbcDvnXMzgN/HXqeDLuDzzrnZwFLgxth/j3StTzvwLufcAmAhsNzMlgL/Avyrc+4MvIXxPpHCGE/WLcD2Pq/TuS4AFznnFvaZMpiuf2vfAn7tnJsJLMD7b5Tcujjn0uYBvB14vM/rLwBfSHVcI6jHNGBLn9evABNjzycCr6Q6xhHW6zG8vXbTvj5ADvAc3q5jh4Cs2PF+f4On8gNv74Xf4+2M9kvA0rUusXh3A6Vxx9Lubw0YB+wiNuY5WnVJq5Y7g2/5V56iWJJpgnPuzdjzA8CEVAYzEmY2DVgErCeN6xPrxtgM1AC/BV4D6pxzxzbNTKe/ubuA/wMc2/W5hPStC3h7M//GzDbFtuyE9PxbqwRqgR/EuszuM7NcklyXdEvuvue8y3ZaTWEyszzgp8BnXf/N0tOuPs65bufcQrxW7xJgZopDGhEzez9Q45zblOpYkuidzrnFeN2yN5rZ+X1PptHfWhawGLjHObcIaCauCyYZdUm35H6yW/6li4NmNhEg9rMmxfEkzMxCeIn9Qefcf8cOp219jnHO1QHr8LouCs3s2N4H6fI3dy6w0sx24219+S68ft50rAsAzrl9sZ81wM/wLr7p+LdWDVQ759bHXq/GS/ZJrUu6Jfdht/xLU2s4vnvVx/D6rk95ZmbAfwDbnXN39jmVrvUpM7PC2PMo3vjBdrwkf1WsWFrUxzn3BedchXNuGt7/J39wzn2YNKwLgJnlmln+sefAe4EtpOHfmnPuALDXzM6KHVoGbCPZdUn14MIIBiNWADvw+kL/PtXxjCD+h4E38bYkrMabrVCCN/D1KvA7oDjVcSZYl3fifXV8Edgce6xI4/rMB56P1WcL
8JXY8dOBPwM7gZ8A2amO9STrdSHwy3SuSyzuF2KPrcf+30/jv7WFwMbY39rPgaJk10V3qIqI+FC6dcuIiEgClNxFRHxIyV1ExIeU3EVEfEjJXUTEh5TcRUR8SMldRMSHlNxFRHzo/wMoK/XApeClfgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# draw the recorded training and validation loss histories on the same axes\n",
    "for hist, label in ((tr_loss_hist, 'train'), (val_loss_hist, 'validation')):\n",
    "    plt.plot(hist, label = label)\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc : 91.70%\n"
     ]
    }
   ],
   "source": [
    "# predicted class = index of the largest score in each row of model(x_tst)\n",
    "yhat = np.argmax(model(x_tst), axis = 1)\n",
    "# accuracy = fraction of test examples whose predicted class equals the label\n",
    "print('acc : {:.2%}'.format((yhat == y_tst).mean()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}