{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 03 : Linear and Logistic Regression\n",
    "### Logistic Regression with ce loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.12.0\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "%matplotlib inline\n",
    "\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load and Pre-process data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# scale pixel values by 1/255 and flatten every image into a 784-dimensional vector\n",
    "(x_train, y_train), (x_tst, y_tst) = tf.keras.datasets.mnist.load_data()\n",
    "x_train = x_train / 255\n",
    "x_train = x_train.reshape(-1, 784)\n",
    "x_tst = x_tst / 255\n",
    "x_tst = x_tst.reshape(-1, 784)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(55000, 784) (55000,)\n",
      "(5000, 784) (5000,)\n"
     ]
    }
   ],
   "source": [
    "# randomly pick 55,000 examples for training; the rows that were not picked form the validation set\n",
    "tr_indices = np.random.choice(x_train.shape[0], size = 55000, replace = False)\n",
    "\n",
    "x_tr = x_train[tr_indices]\n",
    "y_tr = y_train[tr_indices]\n",
    "\n",
    "x_val = np.delete(arr = x_train, obj = tr_indices, axis = 0)\n",
    "y_val = np.delete(arr = y_train, obj = tr_indices, axis = 0)\n",
    "\n",
    "print(x_tr.shape, y_tr.shape)\n",
    "print(x_val.shape, y_val.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the graph of Softmax Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create placeholders for X (flattened 784-pixel images) and Y (integer class labels)\n",
    "X = tf.placeholder(dtype = tf.float32, shape = [None, 784])\n",
    "Y = tf.placeholder(dtype = tf.int32, shape = [None])\n",
    "\n",
    "# create weights (Xavier-initialized) and bias (initialized to 0)\n",
    "w = tf.get_variable(name = 'weights', shape = [784, 10], dtype = tf.float32,\n",
    "                    initializer = tf.contrib.layers.xavier_initializer())\n",
    "b = tf.get_variable(name = 'bias', shape = [10], dtype = tf.float32,\n",
    "                    initializer = tf.zeros_initializer())\n",
    "# construct model: one score (logit) per class for every input row\n",
    "score = tf.matmul(X, w) + b\n",
    "\n",
    "# use the cross entropy as loss function (labels are class indices, not one-hot vectors)\n",
    "ce_loss = tf.losses.sparse_softmax_cross_entropy(labels = Y, logits = score)\n",
    "ce_loss_summ = tf.summary.scalar(name = 'ce_loss', tensor = ce_loss) # for tensorboard\n",
    "\n",
    "# using gradient descent with learning rate of 0.01 to minimize loss\n",
    "opt = tf.train.GradientDescentOptimizer(learning_rate=.01)\n",
    "training_op = opt.minimize(ce_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "859\n"
     ]
    }
   ],
   "source": [
    "epochs = 30\n",
    "batch_size = 64\n",
    "# number of mini-batch updates per epoch (a final partial batch is not counted)\n",
    "total_step = x_tr.shape[0] // batch_size\n",
    "print(total_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# separate writers so that train and validation curves land in different TensorBoard runs\n",
    "train_writer = tf.summary.FileWriter(logdir = '../graphs/lecture03/logreg_tf_placeholder/train',\n",
    "                                     graph = tf.get_default_graph())\n",
    "val_writer = tf.summary.FileWriter(logdir = '../graphs/lecture03/logreg_tf_placeholder/val',\n",
    "                                   graph = tf.get_default_graph())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   5, tr_loss : 0.43, val_loss : 0.42\n",
      "epoch :  10, tr_loss : 0.36, val_loss : 0.36\n",
      "epoch :  15, tr_loss : 0.34, val_loss : 0.34\n",
      "epoch :  20, tr_loss : 0.33, val_loss : 0.33\n",
      "epoch :  25, tr_loss : 0.32, val_loss : 0.32\n",
      "epoch :  30, tr_loss : 0.31, val_loss : 0.31\n"
     ]
    }
   ],
   "source": [
    "# the session is left open on purpose: the evaluation cell below reuses it\n",
    "sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))\n",
    "sess = tf.Session(config = sess_config)\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "tr_loss_hist = []\n",
    "val_loss_hist = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    avg_tr_loss = 0\n",
    "    avg_val_loss = 0\n",
    "\n",
    "    for step in range(total_step):\n",
    "        # draw one random mini-batch from each of the training and validation splits\n",
    "        batch_indices = np.random.choice(x_tr.shape[0], size = batch_size, replace = False)\n",
    "        val_indices = np.random.choice(x_val.shape[0], size = batch_size, replace = False)\n",
    "\n",
    "        batch_xs = x_tr[batch_indices]\n",
    "        batch_ys = y_tr[batch_indices]\n",
    "        val_xs = x_val[val_indices]\n",
    "        val_ys = y_val[val_indices]\n",
    "\n",
    "        # a single run applies the update and fetches the loss together with its summary\n",
    "        _, tr_loss, tr_loss_summ = sess.run(fetches = [training_op, ce_loss, ce_loss_summ],\n",
    "                                            feed_dict = {X : batch_xs, Y : batch_ys})\n",
    "\n",
    "        # validation only evaluates the loss; training_op is not fetched, so no update happens\n",
    "        val_loss, val_loss_summ = sess.run(fetches = [ce_loss, ce_loss_summ],\n",
    "                                           feed_dict = {X : val_xs, Y : val_ys})\n",
    "        avg_tr_loss += tr_loss / total_step\n",
    "        avg_val_loss += val_loss / total_step\n",
    "\n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "    val_loss_hist.append(avg_val_loss)\n",
    "    # NOTE: the summaries written here come from the last mini-batch of the epoch,\n",
    "    # not from the epoch averages that are printed and plotted\n",
    "    train_writer.add_summary(tr_loss_summ, global_step = epoch)\n",
    "    val_writer.add_summary(val_loss_summ, global_step = epoch)\n",
    "\n",
    "    if (epoch + 1) % 5 == 0:\n",
    "        print('epoch : {:3}, tr_loss : {:.2f}, val_loss : {:.2f}'.format(epoch + 1, avg_tr_loss, avg_val_loss))\n",
    "\n",
    "train_writer.close()\n",
    "val_writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x7f71b02694a8>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3XmYXHWd7/H3t/beqvesnY3sO0k6CRAQEMQIlwAim+OMcZS4wDCMo3cYr1eQGWccRQa9A/KA4uhcJWaiYJyJolyDgCymIyRkTwhZOksn6fS+V9fv/lGVptN0uitJdaqr+vN6nnq6zqlTp74nBZ8+/TvnfI855xARkcziSXUBIiKSfAp3EZEMpHAXEclACncRkQykcBcRyUAKdxGRDKRwFxHJQAp3EZEMpHAXEclAvlR9cElJiRs/fnyqPl5EJC1t2LDhuHOutL/lUhbu48ePp6KiIlUfLyKSlsxsXyLLaVhGRCQDKdxFRDKQwl1EJAOlbMxdRDJLR0cHlZWVtLa2prqUjBAKhSgrK8Pv95/V+xXuIpIUlZWV5OXlMX78eMws1eWkNecc1dXVVFZWMmHChLNah4ZlRCQpWltbKS4uVrAngZlRXFx8Tn8FKdxFJGkU7Mlzrv+W6Rfu+16F5x8A3R5QROS00i/cD70BL/8rtNamuhIRGURqa2t57LHHzvh91157LbW1mZcn6RfuOfGrbpuOp7YOERlUThfukUikz/etXbuWgoKCgSorZdIu3DuzimNPFO4i0s19993H22+/zYUXXsjChQu57LLLWLZsGTNmzADgxhtvZMGCBcycOZMnnnii633jx4/n+PHj7N27l+nTp3PnnXcyc+ZMrrnmGlpaWlK1Oecs7U6FXLWthTuA9voqAqkuRkR69dVfbmHrofqkrnPGqDD3Xz/ztK9//etfZ/Pmzbz55pu88MILXHfddWzevLnrVMKnnnqKoqIiWlpaWLhwITfffDPFxcWnrGPXrl08/fTTPPnkk9x666387Gc/42Mf+1hSt+N8SbtwDxWMAKC55ojCXUROa9GiRaecI/6d73yHZ555BoADBw6wa9eu94T7hAkTuPDCCwFYsGABe/fuPW/1JlvahXtu4TAA2uqOprgSETmdvvawz5ecnJyu5y+88ALPP/88r776KtnZ2VxxxRW9nkMeDAa7nnu93rQelkm7MfeicC61LoeOeoW7iLwrLy+PhoaGXl+rq6ujsLCQ7Oxstm/fzmuvvXaeqzv/0m7PvSQ3QLULE2g6lupSRGQQKS4uZsmSJcyaNYusrCyGDx/e9drSpUt5/PHHmT59OlOnTuWiiy5KYaXnR9qFe1FOgK3kM6ZZZ8uIyKl+8pOf9Do/GAzyq1/9qtfXTo6rl5SUsHnz5q75X/jCF5Je3/nU77CMmT1lZkfNbPNpXjcz+46Z7TazTWY2P/llvis36KOWMP62EwP5MSIiaS2RMfd/B5b28fqHgMnxxwrgu+de1umZGU2+QkLtCncRkdPpN9ydcy8CfSXpDcCPXMxrQIGZjUxWgb1pCxaR01kH0c6B/BgRkbSVjLNlRgMHuk1XxucNmEioCA8OWmoG8mNERNLWeT0V0sxWmFmFmVUcO3b2Z7u47JP9ZXTGjIhIb5IR7geBMd2my+Lz3sM594Rzrtw5V15aWnrWH+jJi73XNepcdxGR3iQj3NcAfxE/a+YioM45dzgJ6z0tfzh2/qquUhWRs5WbmwvAoUOH+MhHPtLrMldccQUVFRV9rueRRx6hubm5a3qwtBBO5FTIp4FXgalmVmlmnzSzz5jZZ+KLrAX2ALuBJ4HPDVi1cVkFsXBvrjky0B8lIhlu1KhRrF69+qzf3zPcB0sL4UTOlrnDOTfSOed3zpU5577vnHvcOfd4/HXnnLvLOTfROTfbOdf3r7kkyCsspdMZrdpzF5G4++67j0cffbRr+oEHHuAf//Efueqqq5g/fz6zZ8/mF7/4xXv
et3fvXmbNmgVAS0sLt99+O9OnT+emm246pbfMZz/7WcrLy5k5cyb3338/EGtGdujQIa688kquvPJK4N0WwgAPP/wws2bNYtasWTzyyCNdn3c+Wgun3RWqACV52Zwgj86GqlSXIiK9+dV9cOSt5K5zxGz40NdP+/Jtt93Gvffey1133QXAqlWreO6557jnnnsIh8McP36ciy66iGXLlp32/qTf/e53yc7OZtu2bWzatIn589+9JvNrX/saRUVFdHZ2ctVVV7Fp0ybuueceHn74YdatW0dJSckp69qwYQM/+MEPeP3113HOsXjxYi6//HIKCwvPS2vhtGscBlCcG+CEC0NTdapLEZFBYt68eRw9epRDhw6xceNGCgsLGTFiBF/60peYM2cOV199NQcPHqSq6vQ7hS+++GJXyM6ZM4c5c+Z0vbZq1Srmz5/PvHnz2LJlC1u3bu2znpdffpmbbrqJnJwccnNz+fCHP8xLL70EnJ/Wwmm5516UE2CvC5PXov4yIoNSH3vYA+mWW25h9erVHDlyhNtuu40f//jHHDt2jA0bNuD3+xk/fnyvrX7788477/DQQw+xfv16CgsLWb58+Vmt56Tz0Vo4LffcQ34vdZ4CAuovIyLd3HbbbaxcuZLVq1dzyy23UFdXx7Bhw/D7/axbt459+/b1+f73ve99Xc3HNm/ezKZNmwCor68nJyeH/Px8qqqqTmlCdrpWw5dddhnPPvsszc3NNDU18cwzz3DZZZclcWv7lpZ77gDN/kKyOzalugwRGURmzpxJQ0MDo0ePZuTIkfzZn/0Z119/PbNnz6a8vJxp06b1+f7PfvazfOITn2D69OlMnz6dBQsWADB37lzmzZvHtGnTGDNmDEuWLOl6z4oVK1i6dCmjRo1i3bp1XfPnz5/P8uXLWbRoEQCf+tSnmDdv3nm7u5M5587LB/VUXl7u+jt/tC9Pf/Mu7mj6v/DlY+DTDfdEUm3btm1Mnz491WVklN7+Tc1sg3OuvL/3puWwDEAkK35kulkHVUVEekrbcCdH/WVERE4nbcPdmxsL92ijwl1ksEjVMG8mOtd/y7QN90C8v0xLrS5kEhkMQqEQ1dXVCvgkcM5RXV1NKBQ663Wk7dkyWYXx/jK1R8hJcS0iAmVlZVRWVnIu7bzlXaFQiLKysrN+f9qGe0FhCe3OS3ud9txFBgO/38+ECRNSXYbEpe2wTHFekBOE6WzQVaoiIj2lb7jnBKl2YaxZfwKKiPSUtuFemO2nmjC+Fp3nLiLSU9qGu8/rod5TQLBd4S4i0lPahjtAa6CI7I7U385KRGSwSetwbw8WEXIt0N7c/8IiIkNIQuFuZkvNbIeZ7Taz+3p5fZyZ/T8z22RmL5jZ2Z+ceQaiXf1ldMaMiEh3idwg2ws8CnwImAHcYWYzeiz2EPAj59wc4EHgn5NdaK+15cbDvUnhLiLSXSJ77ouA3c65Pc65dmAlcEOPZWYAv4s/X9fL6wPCmxe7SjXSoBtli4h0l0i4jwYOdJuujM/rbiPw4fjzm4A8MyvuuSIzW2FmFWZWkYxLlIP5wwBorjlyzusSEckkyTqg+gXgcjN7A7gcOAh09lzIOfeEc67cOVdeWlp6zh+aUzgCgJYatSAQEekukd4yB4Ex3abL4vO6OOcOEd9zN7Nc4Gbn3ICfo1hQUEiLC9ChYRkRkVMksue+HphsZhPMLADcDqzpvoCZlZjZyXX9PfBUcsvsXXFugGrC6ukuItJDv+HunIsAdwPPAduAVc65LWb2oJktiy92BbDDzHYCw4GvDVC9pyjJCXLchTGdCikicoqEWv4659YCa3vM+0q356uB1cktrX/hLB81hBnRqhYEIiLdpfUVqmZGo7eQUPuJVJciIjKopHW4Q6y/TE6kBnRrLxGRLmkf7h2hYvyuA9oaUl2KiMigkfbh7rLj10rpoKqISJe0D3fLjV8Mpf4yIiJd0j7cfeFYC4I23ShbRKRL2od7KD/WPKxJ/WVERLq
kfbjnxvvLtNVqz11E5KS0D/fC/DzqXZba/oqIdJP24V6SG6Taqb+MiEh3aR/uxbkBThDG26IWBCIiJ6V9uGcHfNRYPn71lxER6ZL24Q7Q7CskS/1lRES6ZES4twaKyOmshWg01aWIiAwKGRHukaxivEShdcBv/iQikhYyItxddknsiVoQiIgAGRLuntxYCwLXpHPdRUQgwXA3s6VmtsPMdpvZfb28PtbM1pnZG2a2ycyuTX6ppxcMx5qHtdQq3EVEIIFwNzMv8CjwIWAGcIeZzeix2JeJ3Vt1HrEbaD+W7EL7EiqM9ZdpVn8ZEREgsT33RcBu59we51w7sBK4occyDgjHn+cDh5JXYv9yTvaXUWdIEREgsRtkjwYOdJuuBBb3WOYB4Ddm9ldADnB1UqpLUHFeNidcrvrLiIjEJeuA6h3AvzvnyoBrgf8ws/es28xWmFmFmVUcO5a8XjAluUFOuLDOlhERiUsk3A8CY7pNl8XndfdJYBWAc+5VIASU9FyRc+4J51y5c668tLT07CruRVFOgGrCeFsU7iIikFi4rwcmm9kEMwsQO2C6pscy+4GrAMxsOrFwP29tGgM+D3WeAoJtakEgIgIJhLtzLgLcDTwHbCN2VswWM3vQzJbFF/tb4E4z2wg8DSx3zrmBKro3Lb5CQh015/MjRUQGrUQOqOKcWwus7THvK92ebwWWJLe0M9MWLCKnqR46I+BNaLNERDJWRlyhCrH+Mh4ctGhoRkQkY8KdnPgB2ibdkUlEJGPC3ZcbC/dO3W5PRCRzwj2QrxYEIiInZUy4Z8X7y7Qo3EVEMifcw4XDiDgPbfVqQSAikjHhXpwXooY8og0acxcRyZxwzwlw3IV1toyICBkU7gXZAU4QxtdanepSRERSLmPC3esxGr3qLyMiAhkU7gDN/iKyI+ovIyKSUeHeHioiO9oEkbZUlyIiklIZFe7RUHHsiW7aISJDXEaFu8VbENCscBeRoS2jwt2bNwyADl3IJCJDXEaFezB/BABNJ9SCQESGtowK95yieH+Z2qoUVyIikloZFe4FBUW0OR8dDQp3ERnaEgp3M1tqZjvMbLeZ3dfL6/9qZm/GHzvNrDb5pfavKDdENWH1lxGRIa/fm42amRd4FPgAUAmsN7M18fumAuCc+5tuy/8VMG8Aau1XcW6AfS5Mvs6WEZEhLpE990XAbufcHudcO7ASuKGP5e8Ank5GcWcqL+ijhnx8rWpBICJDWyLhPho40G26Mj7vPcxsHDAB+N25l3bmzIwmXwGhdjUPE5GhLdkHVG8HVjvnOnt70cxWmFmFmVUcOzYw4+ItgSJyIikZ8hcRGTQSCfeDwJhu02Xxeb25nT6GZJxzTzjnyp1z5aWlpYlXeQY6QkUEXSu0Nw3I+kVE0kEi4b4emGxmE8wsQCzA1/RcyMymAYXAq8kt8cxEs0piT9RfRkSGsH7D3TkXAe4GngO2Aaucc1vM7EEzW9Zt0duBlc45NzClJsZzsr+Mwl1EhrB+T4UEcM6tBdb2mPeVHtMPJK+ss+cLx65Sba07QqgsxcWIiKRIRl2hChDKjzUPa67RVaoiMnRlXLjnFsWah7XUqnmYiAxdGRfuhQUFNLkgEbX9FZEhLOPCvTg3yAkXxumAqogMYZkX7jkBqglj6i8jIkNYxoV7yO+l1vIJtKkFgYgMXRkX7gBNvkKy2tU8TESGrowM97ZgETmROkjt9VQiIimTkeHeESrGTwe01ae6FBGRlMjIcHfZ6i8jIkNbRob7yf4yrlHnuovI0JSR4R6I95dprlW4i8jQlJHhnlUYC/emmsMprkREJDUyMtxzC2P9Zdpr1TxMRIamjAz3ovxc6l02kYaBuZWfiMhgl5HhXpwb4LgL45oU7iIyNGVkuBdmx/rL+FrUgkBEhqaMDHe/10ODR/1lRGToSijczWypme0ws91mdt9plrnVzLaa2RYz+0lyyzxzzf4isjpqUl2GiEhK9HsPVTPzAo8CHwAqgfV
mtsY5t7XbMpOBvweWOOdqzGzYQBWcqLZgEbmNdRCNgicj/0ARETmtRFJvEbDbObfHOdcOrARu6LHMncCjzrkaAOdcyq8e6swqxkMUWrT3LiJDTyLhPho40G26Mj6vuynAFDP7g5m9ZmZLe1uRma0wswozqzh2bIDPZMmOtSBAN+0QkSEoWeMVPmAycAVwB/CkmRX0XMg594Rzrtw5V15aWpqkj+7dyf4ykYaU/xEhInLeJRLuB4Ex3abL4vO6qwTWOOc6nHPvADuJhX3KBPNjw/5NJ46ksgwRkZRIJNzXA5PNbIKZBYDbgTU9lnmW2F47ZlZCbJhmTxLrPGNZhSMBaKlVfxkRGXr6DXfnXAS4G3gO2Aascs5tMbMHzWxZfLHngGoz2wqsA77onEvpSeZ5hcOIOqND4S4iQ5C5FN2Krry83FVUVAzY+ncfbaTu365gWm4zOX+7Ebz+AfssEZHzxcw2OOfK+1suY08AL8kN8G+RG8lpPgibfprqckREzquMDfdwyM9LzONQ9lR46VvQGUl1SSIi503GhrvHY1w5bTjfbL4eTuyBLT9PdUkiIudNxoY7wN1XTuLZ1gupzpkELz4Ua0UgIjIEZHS4zx1TwKWTh/FQy/VwfAds+0WqSxIROS8yOtwhtvf+0+YF1GWP1967iAwZGR/uiy8opnx8CY+0L4OqzbBjbapLEhEZcBkf7gB3vX8SP2pcSEP2GHjxG5Cic/tFRM6XIRHu75tcwozRRTwWWQaHN8Ku36a6JBGRATUkwt3MuPv9k/he/WKas0Zq711EMt6QCHeAD0wfzoThBTzpboDK9bDnhVSXJCIyYIZMuHs8xl1XTuKx2otoDQ2DF7+Z6pJERAbMkAl3gOtmj2RkcQE/8twA+/4Ae/+Q6pJERAbEkAp3n9fDZ6+YyMMnLqE9WBwbexcRyUBDKtwBbppXRlF+Piv9N8bG3Q+sT3VJIiJJN+TCPeDz8OnLJ/L140voCBZq711EMtKQC3eA2xaOITs3n2dCN8Ku38ChN1JdkohIUiUU7ma21Mx2mNluM7uvl9eXm9kxM3sz/vhU8ktNnpDfy6cum8CDVUuIBMKxnjMiIhmk33A3My/wKPAhYAZwh5nN6GXRnzrnLow/vpfkOpPuYxeNw5uVz9rsG2H7f8GRzakuSUQkaRLZc18E7HbO7XHOtQMrgRsGtqyBlxv08Ykl4/nykUvp9OfCS9p7F5HMkUi4jwYOdJuujM/r6WYz22Rmq81sTFKqG2DLLxlPZyCf5/NuhC3PwO7nU12SiEhSJOuA6i+B8c65OcBvgR/2tpCZrTCzCjOrOHbsWJI++uwVZAf484vHc+/hq2gvmgY//zTUH051WSIi5yyRcD8IdN8TL4vP6+Kcq3bOtcUnvwcs6G1FzrknnHPlzrny0tLSs6k36T556QSi3iy+XfQl6GiGn98J0c5UlyUick4SCff1wGQzm2BmAeB2YE33BcxsZLfJZcC25JU4sErzgtyxaCyPb/Wzs/wB2PsS/F7nvotIeus33J1zEeBu4Dliob3KObfFzB40s2Xxxe4xsy1mthG4B1g+UAUPhL+9ZgqTSnO5+ZXx1E29BX7/L7Dn96kuS0TkrJlLUV/z8vJyV1FRkZLP7s3B2hZufPQP5Hva+HXO/fja6+EzL0PusFSXJiLSxcw2OOfK+1tuSF6h2pvRBVl8/+PlVDZ7+Ly7F9daBz9foRtqi0haUrh3M6esgG/fPo9fHinkx4Wfgz3r4OVvpbosEZEzpnDv4YMzR/C/rp3Olw8sYEvxNbDun2DfK6kuS0TkjCjce/HJSyfw5xeN57aDt9KQVQarPwlN1akuS0QkYQr3XpgZ918/g/Kp4/ho3WeINh2DZz6t8XcRSRsK99PweT3820fn01E6m691/gXs/i28+n9SXZaISEIU7n3IDfr4wScW8l+BD/E7z8W4578KB/6Y6rJERPqlcO/HyPwsvr98Efd13EmVlRD9z+XQfCLVZYm
I9EnhnoBZo/P5pzsu5dOtdxOtr8I99UH1fxeRQU3hnqCrZwznxuuu5y/a/ye1J44TffJK+OOTkKIrfEVE+qJwPwOfWDKBD/6P27iZb/JC+wxY+wXqf3CLhmlEZNBRuJ+hj18ynjV/dyPbrvwe32A5oX2/o+ZbC9lb8VyqSxMR6aJwPwu5QR93vX8yn7nvIVbP+yF1ET9jf3kba79zNzsP16S6PBERhfu5CIf8fPTG6yn8m1fZOuw6rj3xH9R/9xru/9GvePtYY6rLE5EhTC1/k6hx/dP4f/V5WjuNL0U+RWjuzXz+mimMLshKdWkikiHU8jcFchfeQfDuV8geOZVH/d/mos338+GH1vAvv95OfWtHqssTkSFE4Z5sRRPw3/kbWHIvH/G+yO8D9+J56Vss/cav+eEre+noVH8aERl4CveB4PXDB76Kfe41QlPezxf9q/gv7mH7f3+HpQ+v47ktR0jVcJiIDA0JhbuZLTWzHWa228zu62O5m83MmVm/40FDQukUuP3H8Je/oXD0FP7Z/31+0PJX/PzHj3Pr46/wxn6dWSMiA6PfA6pm5gV2Ah8AKoH1wB3Oua09lssD/hsIAHc75/o8WpqJB1T75BzsWIv77QNY9U42MpUH225n5Owr+Lul0xhTlJ3qCkUkDSTzgOoiYLdzbo9zrh1YCdzQy3L/APwL0HpGlQ4VZjDtOuxzr8L132Z2bi0/C36Vm3Z8gRXf+gn3rnyD322v0pi8iCSFL4FlRgMHuk1XAou7L2Bm84Exzrn/NrMvJrG+zOP1wYLleGbfAq89xpUvP8KV9kVe3D6PZzct5n8HF/O+2RNZNncUiycU4fFYqisWkTSUSLj3ycw8wMPA8gSWXQGsABg7duy5fnR6C+TA+76IZ8Ffwivf5vJN/8kVDY/RwZO8tHEuKysW8b9yLubKuZNYNncUc8ryMVPQi0hiEhlzvxh4wDn3wfj03wM45/45Pp0PvA2cvCRzBHACWNbXuPuQG3PvTzQKleth67NEtzyDp+EwHebn951zWRNZzM6CS7l67kRuuHAUk4fnpbpaEUmRRMfcEwl3H7EDqlcBB4kdUP2oc27LaZZ/AfiCDqieg2gUKv8IW+JB33iEdvys65zLrzsX0jhiEe9fvIDrLxxNbvCc//gSkTSStHCPr+xa4BHACzzlnPuamT0IVDjn1vRY9gUU7snTFfTP0Ln5GbxNVQAcdkW86abSNmohU8qvZvq8SzCvP8XFishAS2q4DwSF+1mIRqHqLdz+16jd/hJW+ToFHUcBaCHEiYLZFEy7jJxJS6BsIYTyU1ywiCSbwn2IaD62j42v/Jqa7S8xtuktpts+vOZwGB3D5xCdcAWeSe/HP+4izB9Kdbkico4U7kPQ7qONPPv6dva88XumtG9miWcz82w3PovS7IL8ielUeOeyMTiPI8GJZAd9ZAe8DA+HuH7uKC6dVIJXp16KDGoK9yGsozPKy7uOc6S+lY7mOoqO/ZGR1a8xtuZ1Stv2AVDnLeSt4Dz+5L2QtQ2T2NEaZng4m5vmj+bm+WVMGpab4q0Qkd4o3KV3dZWw5wV4e13sZ/NxACLeLA55RvFWawl73Ehc4QVMmTGPSxYtJlw8vO91RjuhtS52L9mWGmg5AZE2yC6C7OLYI6sodgGXiJwThbv0LxqFo1vgwOtQvQeqdxM5vgtP7T48rrNrsSZPmEjhBeSNmhLrV9ESD/GTYd5aByTw31GoAHJK3g387OLY9PBZMPZiyB89UFsqkjESDXftSg1lHg+MmB17xPkAOjtwNXvZv3MT27e8Qd3B7Yw+eogLqn+Hx+un0RumxZtHs/cCWkJhWnPzaQ/k0+bPpyNQQEewgPzcbC4f46fAxffom45Dc3X8cRxq98PBP8Wmo/EbmRSMg3FLYNzFsZ9FF8R68ojIGdOeu/SrLdLJuu1HWbPxEMcb22mPROnojD1izx3tp0zH5vk8xtXTh3P7ojFcNrm094O1nRGoegv2vQr
7/gD7X40FPkDOMBh3ybuPYTPA4z2/Gy8yyGhYRlJq99FGVlUcYPWGSk40tTO6IItbysu4tXwMo/q6p6xzcHwn7Hsl9tj/KtTF+tZFg2GsrBwbszh2Hn9Zuc7llyFH4S6DQnskym+3VrFy/X5e3h07eHv5lFJuXziWq6YPw+89tet0a0cnu482srOqgZ1VjeyqaqDuyNuU1W9koWcHC7y7mGIH8BA7l78hPJlo2UJyJl6Cf9xFUDwx4aGcSGeUqIOATzckk/ShcJdB58CJZlZVHGBVxQGq6tsoyQ1y8/zR+LzGzqpYoO8/0czJ/yT9XuOCklwmD89lyvA8CrP97DnexMGqo2RXvcHY5i0s8Oxknmc3YWsGoMET5kh4Ng2BYTS5EI0uSEM0SF1ngLrOALWRINXtfqo7fNREgrQSoDAvh+GFeYwsCjO6OMzo4nzGFOcxriSHguxACv/FRN5L4S6DVqQzyu93HmPl+gP8bvtRDBhfksPU4XldQT5leC7jinPes2ffXXN7hD3Hmnj7aD01+97Cf2g9xTUbmdi+g0IayLFWsmg76zrbnI+I+Yiaj6gnQGNoFM0FU4iWTic0ehaFEy4kr3ikWjHLeaVwl7RQ39pByOdN6tCIc+7dwI12QntTt0fjqc87mqGzI/aIdtDe3kpdYxP1jS00NjfT2NxCc0sr7a1NFLcfYortp8gauz6r2oU54BvPsewLaMqfTLR0OvkjJzJ7dCHDwsHYMYSTp4n2fG4GucNjN1Q/V50dULUZDm+KXV9QOg0KJ+jaggykUyElLYRDye9kecqetMcLoXDskYAAUBp/9BSNOo43trL5SCVNB96i88hWgie2U9C4mykNz5Hd8GzsPmVvnEGx3kAsiIfPghGz4j9nxwL6dJyDmr1wcMO7j8MbIdLjDpfeABRPhtKpsc84+bN4YnJ+ocigpj13kWSIRqHuAK0HN3P80DvsO9HM/hPN7Ktupq41gsMIej2UFWczrjiXccXZjC0MkVW/N7bHfWQzNB19d315o+JhPzMW+KH82HUBJ8M8fmVx1BukOjyDXf6pvN4+gRfqRzEq1MZFeceY4TvM2M79FDW/g69+P3byrwaPD4onxU4tHbMYxiyh1wJfAAAJRElEQVSK/UJR4KcFDcuIDALOOQ7WtrBhXw0b9tVQsbeG7UfqicZHZYpzApTkBinJDTI+1MgM288F0XcY2bqH4qad5NS9jblIbF0YNdkT2Ombyuvt43m+YQzbOsuIEGsAN3VEHlOG5XGiuZ1dVQ3s63ZwOs/bzvsKa1iUe4yZ/sOM6dxPcf02fI2HYgv4s2H0gljQnzzVtK+/HiRlFO4ig1RjW4Q399fyxv4aDtW1cryxjeONbRxriP1s7Yh2LRugg0l2kDxa2OLG0Ug2owuymD4yj+kjw12PcUXZ77mZekt7J28fe/e00tjPBiprWrqWmRSq5frCSi4OvM3kti0U1G3DTraeKJkaD/tFkFUI5gEs9tOs23Pefc3jA18w/gjFfnp7TJ/BhWgn80kHrd+lcBdJQ845mto7Od7QxrHGNo7HA78z6pg2Msz0EWHys89t+KSpLcKuo41sP1zPWwfr2Hywjm1HGmiPRMmilYuC+7kmvJdyz07GNW8m0FGfpK2LiZ195CdqXpyLXa/gXOxQs3MQjf90QNRBMyEOBcbTEJ4Cw2eSP24OZZPnMqwwnHjot9QSOf42jYd30nJkFxZpJTxuDtlj5saGqNJoSErhLiIJ6+iMsrOqgc0H63jrYB1vHaxn2+F6OiIRxtpRcmiN7aATxQBPt58AhsODw2edBIgQpD3+s4OgdRCgI/acDoIWe80Ti3V8HiPoMwI+LwGvEfB5CHg9BHxGwOsh2FFLuGE3IzsO4Cc2RBVxHvbZSKpCE2kpnIpv5EwKx86is7mGtqO78ZzYQ6BhH+HmA5S0HyTsTv0FFXEefBarvQM/x7PG01I0Hf/IWRRdMJ+csXMhd1ji/4DR+F87CfxV4pw
j6jjreyck+x6qS4FvE7uH6vecc1/v8fpngLuATqARWOGc29rXOhXuIoNbR2eUXVWNbDtcT1skisfA4zE8ZrHnZlj8p9cTm2dm+L2Gz+PB5zF8Xg9eT7d5XovN93jIDfnIC/n6vJbhFJ0d1FVuo2r3G7RUbsJfvZ2ixt2MiB55z6JRZxy2Eqp8o6kJldGcM5bOwgl4iieSM3wSmIfq/VtoP/QWWSe2Udr8NlPYxwir6VpHnSefE9kT8fiDhGgnQAd+144v2oYv2oansx3rbMUirRCNAIbLHUYkexgtwVLqfSVUewo5Ei2gsiOfPa157GzOYVtDiK/cMIdby8ec1feStHA3My+wE/gAsRO91gN3dA9vMws7F/vVaGbLgM8555b2tV6Fu4gkg2tr4Ojbm6jevxlfdhG5I6dQVDaZUFZ2wuuIRmMHvvcd2E/NO2/SeWQzObU7GNbyDs5FaSVAm/PThr/reSsB2vAT9caOJ/jpJLv9OCXUMtxqGG41FFOPx07N2CgeKi/5B8Zec/dZbW8yz3NfBOx2zu2Jr3glcAPQFe4ngz0uh4Sae4uInDsL5jF8xhKGz1hy1uvweIwxRdmMKZoGc6ed8lpbpJO6lg7qWzqoiz86WzrobO6gqSXSNc85x4j8EO35ITrDIVx+Fi7PS4mrw9N0BBqqoOEwnsYqxk65+Fw3u1+JhPto4EC36Upgcc+FzOwu4PPErgN5f28rMrMVwAqAsWPHnmmtIiLnXdDnZViel2F5Z3uD+VwoOP83oknaNd/OuUedcxOBvwO+fJplnnDOlTvnyktLe7sGUEREkiGRcD8IdB/5L4vPO52VwI3nUpSIiJybRMJ9PTDZzCaYWQC4HVjTfQEzm9xt8jpgV/JKFBGRM9XvmLtzLmJmdwPPETsV8inn3BYzexCocM6tAe42s6uBDqAG+PhAFi0iIn1LqCukc24tsLbHvK90e/7XSa5LRETOge4vJiKSgRTuIiIZSOEuIpKBUtY4zMyOAfvO8u0lwPEkljMYZNo2Zdr2QOZtU6ZtD2TeNvW2PeOcc/1eKJSycD8XZlaRSG+FdJJp25Rp2wOZt02Ztj2Qedt0LtujYRkRkQykcBcRyUDpGu5PpLqAAZBp25Rp2wOZt02Ztj2Qedt01tuTlmPuIiLSt3TdcxcRkT6kXbib2VIz22Fmu83svlTXc67MbK+ZvWVmb5pZWt6aysyeMrOjZra527wiM/utme2K/yxMZY1n4jTb84CZHYx/T2+a2bWprPFMmdkYM1tnZlvNbIuZ/XV8flp+T31sT9p+T2YWMrM/mtnG+DZ9NT5/gpm9Hs+8n8YbOPa/vnQalknkln/pxsz2AuXOubQ9N9fM3kfs3rk/cs7Nis/7BnDCOff1+C/hQufc36WyzkSdZnseABqdcw+lsrazZWYjgZHOuT+ZWR6wgVhr7uWk4ffUx/bcSpp+T2ZmQI5zrtHM/MDLwF8TuwnSz51zK83scWCjc+67/a0v3fbcu27555xrJ9Y7/oYU1zTkOedeBE70mH0D8MP48x+SRj3+T7M9ac05d9g596f48wZgG7G7rKXl99TH9qQtF9MYn/THH47Yne1Wx+cn/B2lW7j3dsu/tP5CiX15vzGzDfHbEGaK4c65w/HnR4DhqSwmSe42s03xYZu0GL7ojZmNB+YBr5MB31OP7YE0/p7MzGtmbwJHgd8CbwO1zrlIfJGEMy/dwj0TXeqcmw98CLgrPiSQUVxs7C99xv96911gInAhcBj4VmrLOTtmlgv8DLi3x43t0/J76mV70vp7cs51OucuJHbHu0XAtH7eclrpFu5nesu/Qc85dzD+8yjwDLEvNBNUxcdFT46PHk1xPefEOVcV/x8vCjxJGn5P8XHcnwE/ds79PD47bb+n3rYnE74nAOdcLbAOuBgoMLOT995IOPPSLdz7veVfOjGznPjBIMwsB7gG2Nz3u9LGGt69I9fHgV+ksJZzdjIA424izb6n+MG67wPbnHM
Pd3spLb+n021POn9PZlZqZgXx51nEThzZRizkPxJfLOHvKK3OlgGIn9r0CO/e8u9rKS7prJnZBcT21iF2V6yfpOP2mNnTwBXEOthVAfcDzwKrgLHEun/e6pxLi4OUp9meK4j9qe+AvcCnu41VD3pmdinwEvAWEI3P/hKxceq0+5762J47SNPvyczmEDtg6iW2473KOfdgPCdWAkXAG8DHnHNt/a4v3cJdRET6l27DMiIikgCFu4hIBlK4i4hkIIW7iEgGUriLiGQghbuISAZSuIuIZCCFu4hIBvr/Bp8YiRsf2eIAAAAASUVORK5CYII=\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(tr_loss_hist, label = 'train')\n", "plt.plot(val_loss_hist, label = 'validation')\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "acc : 91.73%\n" ] } ], "source": [ "yhat = np.argmax(sess.run(score, feed_dict = {X : x_tst}), axis = 1)\n", "print('acc : {:.2%}'.format(np.mean(yhat == y_tst)))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }