{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 03 : Linear and Logistic Regression\n",
    "### Linear Regression with tf.data\n",
    "\n",
    "**Reference**\n",
    "\n",
    "* https://jhui.github.io/2017/11/21/TensorFlow-Importing-data/\n",
    "* https://towardsdatascience.com/how-to-use-dataset-in-tensorflow-c758ef9e4428\n",
    "* https://stackoverflow.com/questions/47356764/how-to-use-tensorflow-dataset-api-with-training-and-validation-sets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.12.0\n"
     ]
    }
   ],
   "source": [
    "import os, sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "from pprint import pprint\n",
    "%matplotlib inline\n",
    "\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build input pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['../data/lecture03/example_with_data/train_dir/birth_life_2010_tr_1.txt',\n",
      " '../data/lecture03/example_with_data/train_dir/birth_life_2010_tr_10.txt',\n",
      " '../data/lecture03/example_with_data/train_dir/birth_life_2010_tr_11.txt',\n",
      " '../data/lecture03/example_with_data/train_dir/birth_life_2010_tr_12.txt',\n",
      " '../data/lecture03/example_with_data/train_dir/birth_life_2010_tr_13.txt',\n",
      " '../data/lecture03/example_with_data/train_dir/birth_life_2010_tr_14.txt',\n",
      " '../data/lecture03/example_with_data/train_dir/birth_life_2010_tr_2.txt',\n",
      " '../data/lecture03/example_with_data/train_dir/birth_life_2010_tr_3.txt',\n",
      " '../data/lecture03/example_with_data/train_dir/birth_life_2010_tr_4.txt',\n",
      " '../data/lecture03/example_with_data/train_dir/birth_life_2010_tr_5.txt',\n",
      " '../data/lecture03/example_with_data/train_dir/birth_life_2010_tr_6.txt',\n",
      " '../data/lecture03/example_with_data/train_dir/birth_life_2010_tr_7.txt',\n",
      " '../data/lecture03/example_with_data/train_dir/birth_life_2010_tr_8.txt',\n",
      " '../data/lecture03/example_with_data/train_dir/birth_life_2010_tr_9.txt']\n"
     ]
    }
   ],
   "source": [
    "# the training data is split over several files in train_dir/.\n",
    "# sort the listing so the file order does not depend on the file system, and keep only\n",
    "# .txt files so stray entries (e.g. .DS_Store) are never parsed as data.\n",
    "# note: `train_dir` holds a list of file paths (name kept because later cells use it)\n",
    "train_root = '../data/lecture03/example_with_data/train_dir/'\n",
    "train_dir = [os.path.join(train_root, name)\n",
    "             for name in sorted(os.listdir(train_root))\n",
    "             if name.endswith('.txt')]\n",
    "pprint(train_dir, compact = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'../data/lecture03/example_with_data/val_dir/birth_life_2010_val.txt'\n"
     ]
    }
   ],
   "source": [
    "# the validation data is a single file, so `val_dir` holds a file path rather than a directory\n",
    "val_dir = '../data/lecture03/example_with_data/val_dir/birth_life_2010_val.txt'\n",
    "pprint(object = val_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hyper parameters\n",
    "epochs = 100\n",
    "batch_size = 8\n",
    "\n",
    "# fix the graph-level seed before any dataset op is created, so the shuffle order of the\n",
    "# training pipeline is reproducible after a kernel restart\n",
    "random_seed = 777\n",
    "tf.set_random_seed(random_seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# datasets construction\n",
    "def parse_record(record):\n",
    "    '''Decode one tab-separated text line and keep only its 2nd and 3rd fields (x, y).\n",
    "\n",
    "    record_defaults declare the three columns as (string, float, float);\n",
    "    the leading string column is dropped.\n",
    "    '''\n",
    "    fields = tf.decode_csv(records = record,\n",
    "                           record_defaults = [[''], [.0], [.0]],\n",
    "                           field_delim = '\\t')\n",
    "    return fields[1:]\n",
    "\n",
    "shuffle_buffer_size = 200\n",
    "\n",
    "# for training dataset: parsed, shuffled, then batched\n",
    "tr_dataset = tf.data.TextLineDataset(filenames = train_dir)\n",
    "tr_dataset = tr_dataset.map(parse_record)\n",
    "tr_dataset = tr_dataset.shuffle(buffer_size = shuffle_buffer_size)\n",
    "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n",
    "tr_iterator = tr_dataset.make_initializable_iterator()\n",
    "\n",
    "# for validation dataset: same parsing, no shuffling\n",
    "val_dataset = tf.data.TextLineDataset(filenames = val_dir)\n",
    "val_dataset = val_dataset.map(parse_record)\n",
    "val_dataset = val_dataset.batch(batch_size = batch_size)\n",
    "val_iterator = val_dataset.make_initializable_iterator()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# handle construction. A string handle lets the same X, Y tensors read from either the\n",
    "# training or the validation iterator, selected at run time via feed_dict = {handle : ...}.\n",
    "# Both datasets share the same parsing and batching, so the training iterator's types and\n",
    "# shapes describe the validation iterator as well.\n",
    "handle = tf.placeholder(dtype = tf.string, shape = [])\n",
    "iterator = tf.data.Iterator.from_string_handle(string_handle = handle,\n",
    "                                               output_types = tr_iterator.output_types,\n",
    "                                               output_shapes = tr_iterator.output_shapes)\n",
    "X, Y = iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# session configured with allow_growth, so TensorFlow claims GPU memory as it needs it\n",
    "# instead of reserving it all up front\n",
    "sess_config = tf.ConfigProto(gpu_options = tf.GPUOptions(allow_growth = True))\n",
    "sess = tf.Session(config = sess_config)\n",
    "# no variable initialization here: the weight and bias do not exist yet (they are created\n",
    "# when the model graph is defined) and are initialized in the training cell"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the graph of a simple linear regression model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# trainable scalar parameters, both starting from 0\n",
    "w = tf.get_variable(name = 'weight', initializer = tf.constant(.0))\n",
    "b = tf.get_variable(name = 'bias', initializer = tf.constant(.0))\n",
    "\n",
    "# linear hypothesis evaluated on the current mini-batch\n",
    "yhat = X * w + b\n",
    "\n",
    "# mean squared error over the mini-batch, plus a scalar summary op for TensorBoard\n",
    "residual = Y - yhat\n",
    "mse_loss = tf.reduce_mean(tf.square(residual))\n",
    "mse_loss_summ = tf.summary.scalar(name = 'mse_loss', tensor = mse_loss)\n",
    "\n",
    "# plain gradient descent with a fixed learning rate of 0.01\n",
    "opt = tf.train.GradientDescentOptimizer(learning_rate = .01)\n",
    "training_op = opt.minimize(loss = mse_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one event-file writer per split (separate log directories); both also record the graph\n",
    "default_graph = tf.get_default_graph()\n",
    "train_writer = tf.summary.FileWriter(logdir = '../graphs/lecture03/linreg_mse_with_tf_data/train', graph = default_graph)\n",
    "val_writer = tf.summary.FileWriter(logdir = '../graphs/lecture03/linreg_mse_with_tf_data/val', graph = default_graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   0, tr_loss : 1757.168, val_loss : 1228.869\n",
      "epoch :  10, tr_loss : 376.397, val_loss : 347.431\n",
      "epoch :  20, tr_loss : 119.488, val_loss : 131.623\n",
      "epoch :  30, tr_loss : 57.805, val_loss : 56.236\n",
      "epoch :  40, tr_loss : 40.083, val_loss : 41.398\n",
      "epoch :  50, tr_loss : 35.828, val_loss : 37.847\n",
      "epoch :  60, tr_loss : 33.837, val_loss : 38.415\n",
      "epoch :  70, tr_loss : 33.972, val_loss : 37.400\n",
      "epoch :  80, tr_loss : 33.842, val_loss : 37.901\n",
      "epoch :  90, tr_loss : 33.386, val_loss : 38.766\n"
     ]
    }
   ],
   "source": [
    "# `epochs` and `batch_size` come from the hyper parameters cell. The session `sess` created\n",
    "# earlier (configured with GPU allow_growth) is reused instead of opening a second,\n",
    "# unconfigured session and leaking the first one.\n",
    "\n",
    "def run_epoch(iterator, iterator_handle, train = False):\n",
    "    '''Run one full pass over `iterator` and return the mean of its per-batch MSE losses.\n",
    "\n",
    "    iterator        : initializable iterator to consume (tr_iterator or val_iterator)\n",
    "    iterator_handle : string handle of `iterator`, fed to the `handle` placeholder\n",
    "    train           : if True, also run `training_op`, i.e. one gradient step per mini-batch\n",
    "\n",
    "    Raises ValueError if the iterator yields no batch at all.\n",
    "    '''\n",
    "    sess.run(iterator.initializer)\n",
    "    # mse_loss is fetched last, so [-1] is the batch loss in both cases\n",
    "    fetches = [training_op, mse_loss] if train else [mse_loss]\n",
    "    total_loss, steps = 0., 0\n",
    "    try:\n",
    "        while True:\n",
    "            total_loss += sess.run(fetches = fetches, feed_dict = {handle : iterator_handle})[-1]\n",
    "            steps += 1\n",
    "    except tf.errors.OutOfRangeError:\n",
    "        pass  # the iterator is exhausted, i.e. the epoch is over\n",
    "    if steps == 0:\n",
    "        raise ValueError('the iterator produced no batches')\n",
    "    return total_loss / steps\n",
    "\n",
    "def loss_summary(value):\n",
    "    '''Wrap a number in a scalar Summary tagged 'mse_loss' for a FileWriter.'''\n",
    "    return tf.Summary(value = [tf.Summary.Value(tag = 'mse_loss', simple_value = float(value))])\n",
    "\n",
    "tr_loss_hist = []\n",
    "val_loss_hist = []\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "tr_handle, val_handle = sess.run(fetches = [tr_iterator.string_handle(), val_iterator.string_handle()])\n",
    "\n",
    "# the writers are closed at the end of this cell; reopen() does nothing on an open writer and\n",
    "# lets a re-run of this cell keep logging instead of having its summaries dropped\n",
    "train_writer.reopen()\n",
    "val_writer.reopen()\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    avg_tr_loss = run_epoch(tr_iterator, tr_handle, train = True)\n",
    "    avg_val_loss = run_epoch(val_iterator, val_handle)\n",
    "\n",
    "    # log the epoch averages (the values printed and plotted below) rather than\n",
    "    # the loss of whichever mini-batch happened to run last\n",
    "    train_writer.add_summary(loss_summary(avg_tr_loss), global_step = epoch)\n",
    "    val_writer.add_summary(loss_summary(avg_val_loss), global_step = epoch)\n",
    "\n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "    val_loss_hist.append(avg_val_loss)\n",
    "\n",
    "    if epoch % 10 == 0:\n",
    "        print('epoch : {:3}, tr_loss : {:.3f}, val_loss : {:.3f}'.format(epoch, avg_tr_loss, avg_val_loss))\n",
    "\n",
    "train_writer.close()\n",
    "val_writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# learning curves: average MSE per epoch on each split (trailing ';' hides the Legend repr)\n",
    "plt.plot(tr_loss_hist, label = 'train')\n",
    "plt.plot(val_loss_hist, label = 'validation')\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('average mse loss')\n",
    "plt.title('Training and validation loss')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# reload the full tab-separated data file only to draw the fitted line over the raw points;\n",
    "# read_csv with sep='\\t' is equivalent to read_table\n",
    "data = pd.read_csv('../data/lecture03/example_with_placeholder/birth_life_2010.txt', sep = '\\t')\n",
    "w_out, b_out = sess.run([w, b])\n",
    "x_obs, y_obs = data.iloc[:, 1], data.iloc[:, 2]\n",
    "plt.plot(x_obs, y_obs, 'bo', label = 'Real data')\n",
    "plt.plot(x_obs, x_obs * w_out + b_out, 'r', label = 'Predicted data')\n",
    "plt.xlabel(data.columns[1])\n",
    "plt.ylabel(data.columns[2])\n",
    "plt.legend();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}