{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 11 : Recurrent Neural Networks\n",
    "Simple example for Many to Many Classification (Simple pos tagger) by Recurrent Neural Networks. \n",
    "\n",
    "### Many to Many Classification by RNN\n",
    "- Creating the **data pipeline** with `tf.data`\n",
    "- Preprocessing word sequences (variable input sequence length) using `padding technique` by `user function (pad_seq)`\n",
    "- Using `tf.nn.embedding_lookup` for getting vector of tokens (eg. word, character)\n",
    "- Training **many to many classification** with `tf.contrib.seq2seq.sequence_loss`\n",
    "- Masking unvalid token with `tf.sequence_mask`\n",
    "- Creating the model as **Class**\n",
    "- Reference\n",
    "    - https://github.com/aisolab/sample_code_of_Deep_learning_Basics/blob/master/DLEL/DLEL_12_2_RNN_(toy_example).ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.8.0\n"
     ]
    }
   ],
   "source": [
    "import os, sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import string\n",
    "%matplotlib inline\n",
    "\n",
    "slim = tf.contrib.slim\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare example data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences = [['I', 'feel', 'hungry'],\n",
    "     ['tensorflow', 'is', 'very', 'difficult'],\n",
    "     ['tensorflow', 'is', 'a', 'framework', 'for', 'deep', 'learning'],\n",
    "     ['tensorflow', 'is', 'very', 'fast', 'changing']]\n",
    "pos = [['pronoun', 'verb', 'adjective'],\n",
    "     ['noun', 'verb', 'adverb', 'adjective'],\n",
    "     ['noun', 'verb', 'determiner', 'noun', 'preposition', 'adjective', 'noun'],\n",
    "     ['noun', 'verb', 'adverb', 'adjective', 'verb']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'<pad>': 0, 'I': 1, 'a': 2, 'changing': 3, 'deep': 4, 'difficult': 5, 'fast': 6, 'feel': 7, 'for': 8, 'framework': 9, 'hungry': 10, 'is': 11, 'learning': 12, 'tensorflow': 13, 'very': 14}\n"
     ]
    }
   ],
   "source": [
    "# word dic\n",
    "word_list = []\n",
    "for elm in sentences:\n",
    "    word_list += elm\n",
    "word_list = list(set(word_list))\n",
    "word_list.sort()\n",
    "word_list = ['<pad>'] + word_list\n",
    "\n",
    "word_dic = {word : idx for idx, word in enumerate(word_list)}\n",
    "print(word_dic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['<pad>', 'adjective', 'adverb', 'determiner', 'noun', 'preposition', 'pronoun', 'verb']\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'<pad>': 0,\n",
       " 'adjective': 1,\n",
       " 'adverb': 2,\n",
       " 'determiner': 3,\n",
       " 'noun': 4,\n",
       " 'preposition': 5,\n",
       " 'pronoun': 6,\n",
       " 'verb': 7}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# pos dic\n",
    "pos_list = []\n",
    "for elm in pos:\n",
    "    pos_list += elm\n",
    "pos_list = list(set(pos_list))\n",
    "pos_list.sort()\n",
    "pos_list = ['<pad>'] + pos_list\n",
    "print(pos_list)\n",
    "\n",
    "pos_dic = {pos : idx for idx, pos in enumerate(pos_list)}\n",
    "pos_dic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: '<pad>',\n",
       " 1: 'adjective',\n",
       " 2: 'adverb',\n",
       " 3: 'determiner',\n",
       " 4: 'noun',\n",
       " 5: 'preposition',\n",
       " 6: 'pronoun',\n",
       " 7: 'verb'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pos_idx_to_dic = {elm[1] : elm[0] for elm in pos_dic.items()}\n",
    "pos_idx_to_dic"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create pad_seq function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_seq(sequences, max_len, dic):\n",
    "    seq_len, seq_indices = [], []\n",
    "    for seq in sequences:\n",
    "        seq_len.append(len(seq))\n",
    "        seq_idx = [dic.get(char) for char in seq]\n",
    "        seq_idx += (max_len - len(seq_idx)) * [dic.get('<pad>')] # 0 is idx of meaningless token \"<pad>\"\n",
    "        seq_indices.append(seq_idx)\n",
    "    return seq_len, seq_indices"
   ]
  },
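  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check (illustration only; the `demo_` variables below are added for this sketch and are not part of the original flow): `pad_seq` returns the true lengths and right-pads each index list with 0 (`<pad>`) up to `max_len`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustration only : pad two short sequences to length 5 using word_dic\n",
    "demo_len, demo_indices = pad_seq(sequences = [['I', 'feel'], ['tensorflow', 'is', 'fast']],\n",
    "                                 max_len = 5, dic = word_dic)\n",
    "print(demo_len) # true lengths of the sequences\n",
    "print(demo_indices) # index lists right-padded with 0 ('<pad>')"
   ]
  },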
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pre-process data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3, 4, 7, 5] (4, 10)\n"
     ]
    }
   ],
   "source": [
    "max_length = 10\n",
    "X_length, X_indices = pad_seq(sequences = sentences, max_len = max_length, dic = word_dic)\n",
    "print(X_length, np.shape(X_indices))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 10)\n"
     ]
    }
   ],
   "source": [
    "y = [elm + ['<pad>'] * (max_length - len(elm)) for elm in pos]\n",
    "y = [list(map(lambda el : pos_dic.get(el), elm)) for elm in y]\n",
    "print(np.shape(y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[6, 7, 1, 0, 0, 0, 0, 0, 0, 0],\n",
       " [4, 7, 2, 1, 0, 0, 0, 0, 0, 0],\n",
       " [4, 7, 3, 4, 5, 1, 4, 0, 0, 0],\n",
       " [4, 7, 2, 1, 7, 0, 0, 0, 0, 0]]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define SimPosRNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimPosRNN:\n",
    "    def __init__(self, X_length, X_indices, y, n_of_classes, hidden_dim, max_len, word_dic):\n",
    "        \n",
    "        # Data pipeline\n",
    "        with tf.variable_scope('input_layer'):\n",
    "            self._X_length = X_length\n",
    "            self._X_indices = X_indices\n",
    "            self._y = y\n",
    "            \n",
    "            one_hot = tf.eye(len(word_dic), dtype = tf.float32)\n",
    "            self._one_hot = tf.get_variable(name='one_hot_embedding', initializer = one_hot,\n",
    "                                            trainable = False) # embedding vector training 안할 것이기 때문\n",
    "            self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n",
    "            \n",
    "        # RNN cell (many to many)\n",
    "        with tf.variable_scope('rnn_cell'):\n",
    "            rnn_cell = tf.contrib.rnn.BasicRNNCell(num_units = hidden_dim,\n",
    "                                                   activation = tf.nn.tanh)\n",
    "            score_cell = tf.contrib.rnn.OutputProjectionWrapper(cell = rnn_cell, output_size = n_of_classes)\n",
    "            self._outputs, _ = tf.nn.dynamic_rnn(cell = score_cell, inputs = self._X_batch,\n",
    "                                                 sequence_length = self._X_length,\n",
    "                                                 dtype = tf.float32)\n",
    "        \n",
    "        with tf.variable_scope('seq2seq_loss'):\n",
    "            masks = tf.sequence_mask(lengths = self._X_length, maxlen = max_len, dtype = tf.float32)\n",
    "            self.seq2seq_loss = tf.contrib.seq2seq.sequence_loss(logits = self._outputs, targets = self._y,\n",
    "                                                                 weights = masks)\n",
    "    \n",
    "        with tf.variable_scope('prediction'):\n",
    "            self._prediction = tf.argmax(input = self._outputs,\n",
    "                                         axis = 2, output_type = tf.int32)\n",
    "    \n",
    "    def predict(self, sess, X_length, X_indices):\n",
    "        feed_prediction = {self._X_length : X_length, self._X_indices : X_indices}\n",
    "        return sess.run(self._prediction, feed_dict = feed_prediction)"
   ]
  },
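  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loss above is weighted by `tf.sequence_mask`, so padded positions contribute nothing. A minimal sketch of what those weights look like for this data (illustration only; it runs in a throwaway session and the `demo_` names are added here):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustration only : the per-timestep weights used by sequence_loss;\n",
    "# each row has ones for valid tokens and zeros for padded positions\n",
    "with tf.Session() as demo_sess:\n",
    "    demo_masks = tf.sequence_mask(lengths = X_length, maxlen = max_length, dtype = tf.float32)\n",
    "    print(demo_sess.run(demo_masks))"
   ]
  },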
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a model of SimPosRNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n"
     ]
    }
   ],
   "source": [
    "# hyper-parameter#\n",
    "lr = .003\n",
    "epochs = 100\n",
    "batch_size = 2\n",
    "total_step = int(np.shape(X_indices)[0] / batch_size)\n",
    "print(total_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?,), (?, 10), (?, 10)), types: (tf.int32, tf.int32, tf.int32)>\n"
     ]
    }
   ],
   "source": [
    "## create data pipeline with tf.data\n",
    "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n",
    "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n",
    "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n",
    "tr_iterator = tr_dataset.make_initializable_iterator()\n",
    "print(tr_dataset)"
   ]
  },
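  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before wiring the iterator into the model, we can peek at one shuffled mini-batch (illustration only, with added `demo_` names; the iterator is re-initialized at the start of every epoch below, so this does not affect training):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustration only : pull a single shuffled mini-batch from the pipeline\n",
    "with tf.Session() as demo_sess:\n",
    "    demo_sess.run(tr_iterator.initializer)\n",
    "    demo_length, demo_X, demo_y = demo_sess.run(tr_iterator.get_next())\n",
    "    print(demo_length, demo_X.shape, demo_y.shape)"
   ]
  },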
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_pos_rnn = SimPosRNN(X_length = X_length_mb, X_indices = X_indices_mb, y = y_mb,\n",
    "                        n_of_classes = 8, hidden_dim = 16, max_len = max_length, word_dic = word_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creat training op and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "## create training op\n",
    "opt = tf.train.AdamOptimizer(learning_rate = lr)\n",
    "training_op = opt.minimize(loss = sim_pos_rnn.seq2seq_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :  10, tr_loss : 1.716\n",
      "epoch :  20, tr_loss : 1.214\n",
      "epoch :  30, tr_loss : 0.785\n",
      "epoch :  40, tr_loss : 0.502\n",
      "epoch :  50, tr_loss : 0.347\n",
      "epoch :  60, tr_loss : 0.246\n",
      "epoch :  70, tr_loss : 0.183\n",
      "epoch :  80, tr_loss : 0.142\n",
      "epoch :  90, tr_loss : 0.109\n",
      "epoch : 100, tr_loss : 0.087\n"
     ]
    }
   ],
   "source": [
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "tr_loss_hist = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    avg_tr_loss = 0\n",
    "    tr_step = 0\n",
    "    \n",
    "    sess.run(tr_iterator.initializer)\n",
    "    try:\n",
    "        while True:\n",
    "            _, tr_loss = sess.run(fetches = [training_op, sim_pos_rnn.seq2seq_loss])\n",
    "            avg_tr_loss += tr_loss\n",
    "            tr_step += 1\n",
    "            \n",
    "    except tf.errors.OutOfRangeError:\n",
    "        pass\n",
    "    \n",
    "    avg_tr_loss /= tr_step\n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "    if (epoch + 1) % 10 == 0:\n",
    "        print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))"
   ]
  },
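  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The average loss per epoch was collected in `tr_loss_hist`; plotting it with the matplotlib setup from the imports shows the training curve."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot the average training loss per epoch\n",
    "plt.plot(tr_loss_hist, label = 'train loss')\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('sequence loss')\n",
    "plt.legend()"
   ]
  },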
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[6, 7, 1, 0, 0, 0, 0, 0, 0, 0],\n",
       "       [4, 7, 2, 1, 0, 0, 0, 0, 0, 0],\n",
       "       [4, 7, 3, 4, 5, 1, 4, 0, 0, 0],\n",
       "       [4, 7, 2, 1, 7, 0, 0, 0, 0, 0]], dtype=int32)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "yhat = sim_pos_rnn.predict(sess = sess, X_length = X_length, X_indices = X_indices)\n",
    "yhat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[6, 7, 1, 0, 0, 0, 0, 0, 0, 0],\n",
       " [4, 7, 2, 1, 0, 0, 0, 0, 0, 0],\n",
       " [4, 7, 3, 4, 5, 1, 4, 0, 0, 0],\n",
       " [4, 7, 2, 1, 7, 0, 0, 0, 0, 0]]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['pronoun', 'verb', 'adjective', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']\n",
      "['noun', 'verb', 'adverb', 'adjective', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']\n",
      "['noun', 'verb', 'determiner', 'noun', 'preposition', 'adjective', 'noun', '<pad>', '<pad>', '<pad>']\n",
      "['noun', 'verb', 'adverb', 'adjective', 'verb', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']\n"
     ]
    }
   ],
   "source": [
    "yhat = [list(map(lambda elm : pos_idx_to_dic.get(elm), row)) for row in yhat]\n",
    "for elm in yhat:\n",
    "    print(elm)"
   ]
  }
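  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a final check (illustration only, added for this sketch), token-level accuracy over the valid (non-padded) positions compares the decoded predictions against the original `pos` tags; `zip` truncates each predicted row to the true sequence length."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustration only : token-level accuracy over the non-padded positions\n",
    "hits = [p == t for preds, trues in zip(yhat, pos) for p, t in zip(preds, trues)]\n",
    "print('token accuracy : {:.3f}'.format(np.mean(hits)))"
   ]
  }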
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}