{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 04 : Eager execution\n",
    "### Custom training walkthrough\n",
    "Categorizing Iris flowers by species using TensorFlow's eager execution.\n",
    "\n",
    "This guide uses these high-level TensorFlow concepts:\n",
    "\n",
    "* Enable an [eager execution](https://www.tensorflow.org/guide/eager?hl=ko) development environment,\n",
    "* Import data with the [Datasets API](https://www.tensorflow.org/guide/datasets?hl=ko)\n",
    "* Build models and layers with TensorFlow's [Keras API](https://keras.io/getting-started/sequential-model-guide/)  \n",
    "  \n",
    "  \n",
    "* Reference\n",
    "    + https://www.tensorflow.org/tutorials/eager/custom_training_walkthrough?hl=ko"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.12.0\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "import os, sys\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "# eager execution must be enabled once, before any TF ops are created in this process\n",
    "tf.enable_eager_execution()\n",
    "\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3-class Iris classifier: two ReLU hidden layers over the 4 input features,\n",
    "# followed by a linear layer producing the 3 class logits\n",
    "model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Dense(10, activation = tf.nn.relu, input_shape = (4,)),\n",
    "    tf.keras.layers.Dense(10, activation = tf.nn.relu),\n",
    "    tf.keras.layers.Dense(3)\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(model, features, label):\n",
    "    \"\"\"Returns the scalar sparse softmax cross-entropy loss for one mini-batch.\n",
    "\n",
    "    Args:\n",
    "        model: callable mapping `features` to per-class logits.\n",
    "        features: batch of input features accepted by `model`.\n",
    "        label: integer class indices (one per example), not one-hot.\n",
    "    \"\"\"\n",
    "    score = model(features)\n",
    "    return tf.losses.sparse_softmax_cross_entropy(labels = label, logits = score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[<tf.Variable 'dense/kernel:0' shape=(4, 10) dtype=float32, numpy=\n",
      "array([[-0.6016106 ,  0.21193182,  0.15714651,  0.15312278, -0.28133038,\n",
      "         0.00746506, -0.148848  , -0.1853224 ,  0.18714899, -0.25901568],\n",
      "       [ 0.02185094, -0.08642298,  0.5632899 , -0.43347245,  0.1226458 ,\n",
      "        -0.49662638, -0.23893097,  0.49837744,  0.39349937, -0.3385544 ],\n",
      "       [-0.27739072, -0.31779853,  0.64550364,  0.33254373,  0.23269486,\n",
      "        -0.2784947 ,  0.06808609,  0.21770716,  0.23193008, -0.16453654],\n",
      "       [-0.5709686 , -0.24807453,  0.17099643, -0.20884967, -0.2562401 ,\n",
      "         0.34987652,  0.52932966, -0.1443029 ,  0.42169094,  0.46493506]],\n",
      "      dtype=float32)>, <tf.Variable 'dense/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>, <tf.Variable 'dense_1/kernel:0' shape=(10, 10) dtype=float32, numpy=\n",
      "array([[ 0.4289506 ,  0.10450268, -0.04721606, -0.43415242, -0.47025183,\n",
      "        -0.361524  ,  0.0418207 , -0.0864554 , -0.0556342 ,  0.15549362],\n",
      "       [-0.44529125,  0.37460685,  0.01090676, -0.05481359, -0.27588695,\n",
      "         0.29222095, -0.41036385, -0.3664556 , -0.53682727,  0.31951696],\n",
      "       [-0.12785482,  0.03862995,  0.35421294,  0.01573461,  0.34672868,\n",
      "        -0.3728515 ,  0.2816208 , -0.29301217,  0.11257732, -0.04643917],\n",
      "       [-0.03176916,  0.5232763 ,  0.35986876,  0.33061004,  0.26544142,\n",
      "        -0.22060388, -0.14162892,  0.17554426, -0.49126202,  0.08825725],\n",
      "       [-0.32824028,  0.21503484, -0.54436713, -0.50431615, -0.16609886,\n",
      "        -0.20675969, -0.00118977, -0.23783684,  0.11200953,  0.393723  ],\n",
      "       [ 0.04101104, -0.14152196, -0.47661275, -0.46934345,  0.10372204,\n",
      "         0.5409229 ,  0.27753627, -0.10152608,  0.41975296, -0.2332783 ],\n",
      "       [-0.33580142,  0.4723308 , -0.24120337,  0.18304974, -0.05336449,\n",
      "        -0.37312585,  0.2332313 , -0.40402335,  0.36896586, -0.21605268],\n",
      "       [-0.32411113,  0.01418477, -0.23935765,  0.521438  ,  0.44553947,\n",
      "         0.3212368 ,  0.02267909,  0.17611271, -0.22656313, -0.45787033],\n",
      "       [-0.40987664, -0.01963365,  0.25425416, -0.12491649, -0.5176469 ,\n",
      "         0.35165775,  0.4348483 ,  0.48482704,  0.08984548, -0.42127824],\n",
      "       [ 0.4564724 ,  0.3914752 , -0.29007906,  0.18757671,  0.19446594,\n",
      "        -0.36466774, -0.11595219, -0.1236656 ,  0.26194257,  0.1289767 ]],\n",
      "      dtype=float32)>, <tf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>, <tf.Variable 'dense_2/kernel:0' shape=(10, 3) dtype=float32, numpy=\n",
      "array([[-0.55178356,  0.5969744 , -0.20873287],\n",
      "       [ 0.6190039 , -0.36399648,  0.49799395],\n",
      "       [-0.02199602, -0.13632727, -0.6334673 ],\n",
      "       [-0.5608484 ,  0.47014666,  0.43292272],\n",
      "       [ 0.24740952,  0.5413171 ,  0.3306861 ],\n",
      "       [-0.6199806 ,  0.5774143 , -0.6118181 ],\n",
      "       [-0.02008146, -0.57885826, -0.11845094],\n",
      "       [-0.45090094, -0.34696335, -0.42613342],\n",
      "       [-0.21294937,  0.35269964, -0.12020266],\n",
      "       [-0.20987946,  0.14686137, -0.5125678 ]], dtype=float32)>, <tf.Variable 'dense_2/bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>]\n"
     ]
    }
   ],
   "source": [
    "# in eager mode, variables are tracked on the model object itself rather than in a graph collection\n",
    "print(model.trainable_variables) # different from tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import and parse the training dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['iris_test.csv', 'iris_training.csv']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "os.listdir('../data/lecture04/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define parsing function\n",
    "def parse_single_example(record):\n",
    "    \"\"\"Parses one CSV line into (list of 4 scalar float features, int32 label).\n",
    "\n",
    "    The record_defaults give each of the four feature columns a 0.0 default;\n",
    "    the empty list for the fifth column marks the label as required (its dtype\n",
    "    is inferred as float32, hence the cast below).\n",
    "    NOTE(review): assumes the CSV files contain no header row - confirm, since\n",
    "    tf.decode_csv would fail on a non-numeric header field.\n",
    "    \"\"\"\n",
    "    decoded = tf.decode_csv(record, [[.0],[.0],[.0],[.0],[]])\n",
    "    features = decoded[:4]\n",
    "    label = tf.cast(x = decoded[4], dtype = tf.int32)\n",
    "    return features, label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# training hyperparameters\n",
    "epochs = 10\n",
    "batch_size = 8\n",
    "learning_rate = .03\n",
    "tr_dataset = tf.data.TextLineDataset(filenames = '../data/lecture04/iris_training.csv')\n",
    "tr_dataset = tr_dataset.map(parse_single_example)\n",
    "# shuffle buffer of 200 presumably covers the whole training file - confirm file size\n",
    "tr_dataset = tr_dataset.shuffle(200).batch(batch_size = batch_size)\n",
    "opt = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)\n",
    "# NOTE(review): global_step is conventionally an integer variable; float works but is unusual\n",
    "global_step = tf.Variable(0.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   1, ce_loss : 1.154\n",
      "epoch :   2, ce_loss : 0.829\n",
      "epoch :   3, ce_loss : 0.706\n",
      "epoch :   4, ce_loss : 0.591\n",
      "epoch :   5, ce_loss : 0.553\n",
      "epoch :   6, ce_loss : 0.553\n",
      "epoch :   7, ce_loss : 0.449\n",
      "epoch :   8, ce_loss : 0.451\n",
      "epoch :   9, ce_loss : 0.514\n",
      "epoch :  10, ce_loss : 0.395\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(epochs):\n",
    "    avg_loss = 0\n",
    "    tr_step = 0\n",
    "    \n",
    "    for mb_x, mb_y in tr_dataset:\n",
    "        # record the forward pass so gradients can be taken w.r.t. the model variables\n",
    "        with tf.GradientTape() as tape:\n",
    "            tr_loss = loss_fn(model, mb_x, mb_y)\n",
    "        grads = tape.gradient(tr_loss, model.variables)\n",
    "        opt.apply_gradients(zip(grads, model.variables), global_step = global_step)\n",
    "        \n",
    "        avg_loss += tr_loss\n",
    "        tr_step += 1\n",
    "    \n",
    "    # average mini-batch loss over the epoch; the original used `for...else`, which is\n",
    "    # misleading here (without a `break`, the else-clause always runs) - plain post-loop\n",
    "    # code is equivalent and clearer\n",
    "    avg_loss /= tr_step\n",
    "    print('epoch : {:3}, ce_loss : {:.3f}'.format(epoch + 1, avg_loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate the model on the test dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test pipeline reuses the training parser; batch of 30 presumably holds the\n",
    "# entire test file in a single batch - confirm against the CSV's row count\n",
    "tst_dataset = tf.data.TextLineDataset(filenames = '../data/lecture04/iris_test.csv')\n",
    "tst_dataset = tst_dataset.map(parse_single_example)\n",
    "tst_dataset = tst_dataset.batch(batch_size = 30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pull the single test batch, then take the predicted class as the argmax over the 3 logits\n",
    "tst_x, tst_y = next(iter(tst_dataset))\n",
    "tst_yhat = tf.argmax(model(tst_x), axis = -1, output_type = tf.int32) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test accuracy : 96.67%\n"
     ]
    }
   ],
   "source": [
    "# accuracy = fraction of test examples where the predicted class matches the label\n",
    "print('test accuracy : {:.2%}'.format(np.mean(tf.equal(tst_y, tst_yhat))))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}