{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to simply use tf.data\n",
    "`tf.data` package를 사용하는 방법에 관한 예시, 예시 데이터는 `numpy` package를 이용하여 간단하게 data에 해당하는 `X`, target에 해당하는 `y`를 생성하여 `tf.data` package의 각종 module, function을 이용한다. epoch 마다 validation data에 대해서 validation을 하는 상황을 가정"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Template\n",
    "for문을 활용, model을 training시 data pipeline으로써 아래의 function과 method를 사용하는 방법에 대한 예시\n",
    "- **Dataset class**\n",
    "    - `tf.data.Dataset.from_tensor_slices`으로 Dataset class의 instance를 생성\n",
    "        - train data에 대한 Dataset class의 instance, tr_data\n",
    "        - validation data에 대한 Dataset class의 instance, val_data\n",
    "    - 아래와 같은 method를 활용하여 training 시 필요한 요소를 지정\n",
    "        - instance의 `shuffle` method를 활용하여, shuffling\n",
    "        - instanec의 `batch` method를 활용하여, batch size 지정\n",
    "        - for문으로 전체 epoch를 control하므로 `repeat` method는 활용하지 않음\n",
    "\n",
    "\n",
    "- **Iterator class**\n",
    "    - Dataset class의 instance에서 `make_initializable_iterator` method로 Iterator class의 instance를 생성\n",
    "        - train data에 대한 iterator class의 instance, tr_iterator\n",
    "        - validation data에 대한 iterator class의 instance, val_iterator\n",
    "        - 주의사항 : `make_initializable_iterator` method로 Iterator class의 instance를 생성할 경우, random_seed를 고정 X\n",
    "            - random_seed를 고정할 경우, 서로 다른 epoch의 step 별 mini-batch의 구성이 완전히 똑같아지기 때문\n",
    "    - Anonymous iterator를 `tf.data.Iterator.from_string_handle`로 생성\n",
    "        - `string_handle` argument에 `tf.placeholder`를 이용\n",
    "            - tr_iterator를 활용할 것인지, val_iterator를 활용할 것인지 조절"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.12.0\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(12, 2) (12,)\n"
     ]
    }
   ],
   "source": [
    "# 전체 데이터의 개수가 12개인 임의의 데이터셋 생성\n",
    "X = np.c_[np.arange(12), np.arange(12)]\n",
    "y = np.arange(12)\n",
    "\n",
    "print(X.shape, y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(8, 2) (8,)\n",
      "(4, 2) (4,)\n"
     ]
    }
   ],
   "source": [
    "# 위의 데이터를 train, validation으로 split\n",
    "X_tr = X[:8]\n",
    "y_tr = y[:8]\n",
    "\n",
    "X_val = X[8:]\n",
    "y_val = y[8:]\n",
    "\n",
    "print(X_tr.shape, y_tr.shape)\n",
    "print(X_val.shape, y_val.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch : 3, batch_size : 2, total_steps : 4\n"
     ]
    }
   ],
   "source": [
    "n_epoch = 3\n",
    "batch_size = 2\n",
    "total_steps = int(X_tr.shape[0] / batch_size)\n",
    "print('epoch : {}, batch_size : {}, total_steps : {}'.format(n_epoch, batch_size, total_steps))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?, 2), (?,)), types: (tf.int64, tf.int64)>\n",
      "<BatchDataset shapes: ((?, 2), (?,)), types: (tf.int64, tf.int64)>\n"
     ]
    }
   ],
   "source": [
    "tr_data = tf.data.Dataset.from_tensor_slices((X_tr, y_tr)) # 0th dimension의 size가 같아야\n",
    "tr_data = tr_data.shuffle(buffer_size = 30)\n",
    "tr_data = tr_data.batch(batch_size = batch_size)\n",
    "\n",
    "val_data = tf.data.Dataset.from_tensor_slices((X_val, y_val))\n",
    "val_data = val_data.batch(batch_size = batch_size)\n",
    "\n",
    "print(tr_data)\n",
    "print(val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_iterator = tr_data.make_initializable_iterator()\n",
    "val_iterator = val_data.make_initializable_iterator()\n",
    "\n",
    "handle = tf.placeholder(dtype = tf.string)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "iterator = tf.data.Iterator.from_string_handle(string_handle = handle,\n",
    "                                               output_shapes = tr_iterator.output_shapes,\n",
    "                                               output_types = tr_iterator.output_types)\n",
    "X_mb, y_mb = iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch : 1 training start\n",
      "step : 1\n",
      "[[0 0]\n",
      " [1 1]] [0 1]\n",
      "step : 2\n",
      "[[4 4]\n",
      " [3 3]] [4 3]\n",
      "step : 3\n",
      "[[5 5]\n",
      " [2 2]] [5 2]\n",
      "step : 4\n",
      "[[7 7]\n",
      " [6 6]] [7 6]\n",
      "epoch : 1 training finished\n",
      "at epoch : 1, validation start\n",
      "step : 1\n",
      "[[8 8]\n",
      " [9 9]] [8 9]\n",
      "step : 2\n",
      "[[10 10]\n",
      " [11 11]] [10 11]\n",
      "validation finished\n",
      "epoch : 2 training start\n",
      "step : 1\n",
      "[[3 3]\n",
      " [5 5]] [3 5]\n",
      "step : 2\n",
      "[[4 4]\n",
      " [0 0]] [4 0]\n",
      "step : 3\n",
      "[[1 1]\n",
      " [2 2]] [1 2]\n",
      "step : 4\n",
      "[[7 7]\n",
      " [6 6]] [7 6]\n",
      "epoch : 2 training finished\n",
      "at epoch : 2, validation start\n",
      "step : 1\n",
      "[[8 8]\n",
      " [9 9]] [8 9]\n",
      "step : 2\n",
      "[[10 10]\n",
      " [11 11]] [10 11]\n",
      "validation finished\n",
      "epoch : 3 training start\n",
      "step : 1\n",
      "[[4 4]\n",
      " [3 3]] [4 3]\n",
      "step : 2\n",
      "[[6 6]\n",
      " [1 1]] [6 1]\n",
      "step : 3\n",
      "[[0 0]\n",
      " [2 2]] [0 2]\n",
      "step : 4\n",
      "[[7 7]\n",
      " [5 5]] [7 5]\n",
      "epoch : 3 training finished\n",
      "at epoch : 3, validation start\n",
      "step : 1\n",
      "[[8 8]\n",
      " [9 9]] [8 9]\n",
      "step : 2\n",
      "[[10 10]\n",
      " [11 11]] [10 11]\n",
      "validation finished\n"
     ]
    }
   ],
   "source": [
    "# n_tr_step, n_val_step 변수와 관련된 코드는 step 수 확인을 위해 넣어놓음\n",
    "sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))\n",
    "sess = tf.Session(config = sess_config)\n",
    "tr_handle, val_handle = sess.run([tr_iterator.string_handle(), val_iterator.string_handle()])\n",
    "\n",
    "for epoch in range(n_epoch):\n",
    "    \n",
    "    print('epoch : {} training start'.format(epoch + 1))\n",
    "    sess.run(tr_iterator.initializer) # run tr_iterator\n",
    "    n_tr_step = 0\n",
    "    \n",
    "    while True:\n",
    "        try:\n",
    "            n_tr_step += 1\n",
    "            X_tmp, y_tmp = sess.run([X_mb, y_mb], feed_dict = {handle : tr_handle})\n",
    "            print('step : {}'.format(n_tr_step))\n",
    "            print(X_tmp, y_tmp)\n",
    "        \n",
    "        except:\n",
    "            print('epoch : {} training finished'.format(epoch + 1))\n",
    "            break\n",
    "\n",
    "    print('at epoch : {}, validation start'.format(epoch + 1))        \n",
    "    sess.run(val_iterator.initializer)\n",
    "    n_val_step = 0\n",
    "    while True:\n",
    "        try:\n",
    "            n_val_step += 1\n",
    "            X_tmp, y_tmp = sess.run([X_mb, y_mb], feed_dict = {handle : val_handle})\n",
    "            \n",
    "            print('step : {}'.format(n_val_step))\n",
    "            print(X_tmp, y_tmp)\n",
    "        except:\n",
    "            print('validation finished')\n",
    "            break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}