{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 12 : Seq2Seq with Attention\n",
    "Simple example for Seq2Seq (Machine Translation) with Attention by Encoder RNN and Decoder RNN.\n",
    "\n",
    "### Seq2Seq with Attention by Encoder RNN and Decoder RNN\n",
    "- Creating the **data pipeline** with `tf.data`\n",
    "- Preprocessing word sequences (variable input sequence length) using `padding technique` by `user function (pad_seq)`\n",
    "- Using `tf.nn.embedding_lookup` for getting vector of tokens (eg. word, character)\n",
    "- Training **many to many classification** with `tf.contrib.seq2seq.sequence_loss`\n",
    "- Masking unvalid token with `tf.sequence_mask`\n",
    "- Using **attention mechanism** by `tf.contrib.seq2seq.LuongAttention`, `tf.contrib.seq2seq.AttentionWrapper`\n",
    "- Using `tf.contrib.seq2seq.dynamic_decode`\n",
    "- Training with `tf.contrib.seq2seq.TrainingHelper`\n",
    "- Translating with `tf.contrib.seq2seq.GreedyEmbeddingHelper`\n",
    "- Creating the model as **Class**\n",
    "- Reference\n",
    "    - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/03%20-%20Seq2Seq.py\n",
    "    - https://github.com/HiJiGOO/tf_nmt_tutorial\n",
    "    - https://github.com/hccho2/RNN-Tutorial\n",
    "    - https://www.tensorflow.org/tutorials/seq2seq\n",
    "    - https://github.com/j-min/tf_tutorial_plus/tree/master/RNN_seq2seq/contrib_seq2seq\n",
    "    - https://gist.github.com/ilblackdragon/c92066d9d38b236a21d5a7b729a10f12"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.10.0\n"
     ]
    }
   ],
   "source": [
    "import os, sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import string\n",
    "from pprint import pprint\n",
    "%matplotlib inline\n",
    "\n",
    "s2s = tf.contrib.seq2seq\n",
    "print(tf.__version__)"
   ]
  },
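  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration (a demo cell, not part of the model): `tf.sequence_mask` turns valid sequence lengths into a 0/1 mask. `SimpleNMT` uses the same call below to zero out the loss on `<pad>` positions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hypothetical valid lengths [3, 5] of two sequences padded to maxlen = 6\n",
    "demo_mask = tf.sequence_mask(lengths = [3, 5], maxlen = 6, dtype = tf.float32)\n",
    "with tf.Session() as demo_sess:\n",
    "    print(demo_sess.run(demo_mask)) # 1. at valid steps, 0. at padding"
   ]
  },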
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare example data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "sources = [['I', 'feel', 'hungry'],\n",
    "     ['tensorflow', 'is', 'very', 'difficult'],\n",
    "     ['tensorflow', 'is', 'a', 'framework', 'for', 'deep', 'learning'],\n",
    "     ['tensorflow', 'is', 'very', 'fast', 'changing']]\n",
    "targets = [['나는', '배가', '고프다'],\n",
    "           ['텐서플로우는', '매우', '어렵다'],\n",
    "           ['텐서플로우는', '딥러닝을', '위한', '프레임워크이다'],\n",
    "           ['텐서플로우는', '매우', '빠르게', '변화한다']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'<pad>': 0, 'I': 1, 'a': 2, 'changing': 3, 'deep': 4, 'difficult': 5, 'fast': 6, 'feel': 7, 'for': 8, 'framework': 9, 'hungry': 10, 'is': 11, 'learning': 12, 'tensorflow': 13, 'very': 14}\n",
      "15\n"
     ]
    }
   ],
   "source": [
    "# word dic for sentences\n",
    "source_words = []\n",
    "for elm in sources:\n",
    "    source_words += elm\n",
    "source_words = list(set(source_words))\n",
    "source_words.sort()\n",
    "source_words = ['<pad>'] + source_words\n",
    "\n",
    "source_dic = {word : idx for idx, word in enumerate(source_words)}\n",
    "print(source_dic)\n",
    "print(len(source_dic))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: '<pad>',\n",
       " 1: 'I',\n",
       " 2: 'a',\n",
       " 3: 'changing',\n",
       " 4: 'deep',\n",
       " 5: 'difficult',\n",
       " 6: 'fast',\n",
       " 7: 'feel',\n",
       " 8: 'for',\n",
       " 9: 'framework',\n",
       " 10: 'hungry',\n",
       " 11: 'is',\n",
       " 12: 'learning',\n",
       " 13: 'tensorflow',\n",
       " 14: 'very'}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "source_idx_dic = {elm[1] : elm[0] for elm in source_dic.items()}\n",
    "source_idx_dic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'<pad>': 0, '<start>': 1, '<end>': 2, '고프다': 3, '나는': 4, '딥러닝을': 5, '매우': 6, '배가': 7, '변화한다': 8, '빠르게': 9, '어렵다': 10, '위한': 11, '텐서플로우는': 12, '프레임워크이다': 13}\n",
      "14\n"
     ]
    }
   ],
   "source": [
    "# word dic for translations\n",
    "target_words = []\n",
    "for elm in targets:\n",
    "    target_words += elm\n",
    "target_words = list(set(target_words))\n",
    "target_words.sort()\n",
    "target_words =  ['<pad>']+ ['<start>'] + ['<end>'] + \\\n",
    "                    target_words # 번역문의 시작과 끝을 알리는 'start', 'end' token 추가\n",
    "\n",
    "target_dic = {word : idx for idx, word in enumerate(target_words)}\n",
    "print(target_dic)\n",
    "print(len(target_dic))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: '<pad>',\n",
       " 1: '<start>',\n",
       " 2: '<end>',\n",
       " 3: '고프다',\n",
       " 4: '나는',\n",
       " 5: '딥러닝을',\n",
       " 6: '매우',\n",
       " 7: '배가',\n",
       " 8: '변화한다',\n",
       " 9: '빠르게',\n",
       " 10: '어렵다',\n",
       " 11: '위한',\n",
       " 12: '텐서플로우는',\n",
       " 13: '프레임워크이다'}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target_idx_dic = {elm[1] : elm[0] for elm in target_dic.items()}\n",
    "target_idx_dic"
   ]
  },
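  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of `tf.nn.embedding_lookup` with a one-hot (identity) embedding matrix, mirroring how `SimpleNMT` maps token indices to vectors below (the `demo_*` names are illustrative only)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# identity matrix as a fixed one-hot embedding, as used inside SimpleNMT\n",
    "demo_embeddings = tf.eye(num_rows = len(source_dic), dtype = tf.float32)\n",
    "demo_lookup = tf.nn.embedding_lookup(params = demo_embeddings,\n",
    "                                     ids = [source_dic.get('I'), source_dic.get('feel')])\n",
    "with tf.Session() as demo_sess:\n",
    "    print(demo_sess.run(demo_lookup)) # one-hot rows for 'I' (index 1) and 'feel' (index 7)"
   ]
  },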
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create pad_seq function for sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_seq_enc(sequences, max_len, dic):\n",
    "    seq_len = []\n",
    "    seq_indices = []\n",
    "    for seq in sequences:\n",
    "        seq_len.append(len(seq))\n",
    "        seq_idx = [dic.get(word) for word in seq]\n",
    "        seq_idx += (max_len - len(seq_idx)) * [dic.get('<pad>')] \n",
    "        seq_indices.append(seq_idx)        \n",
    "    return seq_len, seq_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_seq_dec(sequences, max_len, dic):\n",
    "    seq_input_len = []\n",
    "    seq_input_indices = []\n",
    "    seq_target_indices = []\n",
    "    \n",
    "    # for decoder input\n",
    "    for seq in sequences:\n",
    "        seq_input_idx = [dic.get('<start>')] + [dic.get(token) for token in seq]\n",
    "        seq_input_len.append(len(seq_input_idx))\n",
    "        seq_input_idx += (max_len - len(seq_input_idx)) * [dic.get('<pad>')] \n",
    "        seq_input_indices.append(seq_input_idx)\n",
    "        \n",
    "    # for decoder output\n",
    "    for seq in sequences:\n",
    "        seq_target_idx = [dic.get(token) for token in seq] + [dic.get('<end>')]\n",
    "        seq_target_idx += (max_len - len(seq_target_idx)) * [dic.get('<pad>')]\n",
    "        seq_target_indices.append(seq_target_idx)\n",
    "        \n",
    "    return seq_input_len, seq_input_indices, seq_target_indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pre-process example data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3, 4, 7, 5] (4, 10)\n"
     ]
    }
   ],
   "source": [
    "# for encoder\n",
    "source_max_len = 10\n",
    "X_length, X_indices = pad_seq_enc(sequences = sources, max_len = source_max_len, dic = source_dic)\n",
    "print(X_length, np.shape(X_indices))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4, 4, 5, 5]\n",
      "[[1, 4, 7, 3, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      " [1, 12, 6, 10, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      " [1, 12, 5, 11, 13, 0, 0, 0, 0, 0, 0, 0],\n",
      " [1, 12, 6, 9, 8, 0, 0, 0, 0, 0, 0, 0]]\n",
      "[[4, 7, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      " [12, 6, 10, 2, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      " [12, 5, 11, 13, 2, 0, 0, 0, 0, 0, 0, 0],\n",
      " [12, 6, 9, 8, 2, 0, 0, 0, 0, 0, 0, 0]]\n"
     ]
    }
   ],
   "source": [
    "# for decoder\n",
    "target_max_len = 12\n",
    "y_length, y_input_indices, y_target_indices = pad_seq_dec(sequences = targets, max_len = target_max_len,\n",
    "                                                             dic = target_dic)\n",
    "pprint(y_length)\n",
    "pprint(y_input_indices)\n",
    "pprint(y_target_indices)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define SimpleNMT\n",
    "Encoder RNN, Decoder RNN, Attention"
   ]
  },
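  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference (not from the original notes, but the standard formulation): `tf.contrib.seq2seq.LuongAttention` scores each encoder output $\\bar{h}_s$ against the current decoder state $h_t$ with the multiplicative form from Luong et al. (2015),\n",
    "\n",
    "$$\\mathrm{score}(h_t, \\bar{h}_s) = h_t^{\\top} W \\bar{h}_s, \\qquad \\alpha_{ts} = \\frac{\\exp(\\mathrm{score}(h_t, \\bar{h}_s))}{\\sum_{s'} \\exp(\\mathrm{score}(h_t, \\bar{h}_{s'}))}, \\qquad c_t = \\sum_s \\alpha_{ts} \\bar{h}_s$$\n",
    "\n",
    "`AttentionWrapper` combines the context vector $c_t$ with the decoder cell output at every step, and `memory_sequence_length` ensures padded encoder positions receive zero attention weight."
   ]
  },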
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleNMT:\n",
    "    def __init__(self, s_len, s_indices, t_len, t_input_indices, t_output_indices,\n",
    "                 t_max_len = target_max_len, s_dic = source_dic, t_dic = target_dic,\n",
    "                 n_of_classes = len(target_dic), enc_hdim = 32, dec_hdim = 16):\n",
    "        \n",
    "        with tf.variable_scope('input_layer'):\n",
    "            # s : source, t : target\n",
    "            self._s_len = s_len\n",
    "            self._s_indices = s_indices\n",
    "            self._t_len = t_len\n",
    "            self._t_input_indices = t_input_indices\n",
    "            self._t_output_indices = t_output_indices\n",
    "            self._s_dic = s_dic\n",
    "            self._t_dic = t_dic\n",
    "            self._t_max_len = target_max_len\n",
    "            \n",
    "            s_embeddings = tf.eye(num_rows = len(self._s_dic), dtype = tf.float32)\n",
    "            s_embeddings = tf.get_variable(name = 's_embeddings', initializer = s_embeddings,\n",
    "                                           trainable = False)\n",
    "            s_batch = tf.nn.embedding_lookup(params = s_embeddings, ids = self._s_indices)\n",
    "        \n",
    "        with tf.variable_scope('encoder'):\n",
    "            \n",
    "            enc_cell = tf.contrib.rnn.BasicRNNCell(num_units = enc_hdim, activation = tf.nn.tanh)\n",
    "            enc_outputs, _ = tf.nn.dynamic_rnn(cell = enc_cell, inputs = s_batch,\n",
    "                                               sequence_length = self._s_len, dtype = tf.float32)\n",
    "            \n",
    "        with tf.variable_scope('pipe'):\n",
    " \n",
    "            t_embeddings = tf.eye(num_rows = len(self._t_dic))\n",
    "            t_embeddings = tf.get_variable(name = 'embeddings',\n",
    "                                           initializer = t_embeddings,\n",
    "                                           trainable = False)\n",
    "            t_batch = tf.nn.embedding_lookup(params = t_embeddings, ids = self._t_input_indices)\n",
    "\n",
    "            batch_size = tf.reduce_sum(tf.ones_like(tensor = self._s_len))\n",
    "            tr_tokens = tf.tile(input = [self._t_max_len], multiples = [batch_size])\n",
    "            trans_tokens = tf.tile(input = [self._t_dic.get('<start>')], multiples = [batch_size])\n",
    "             \n",
    "        with tf.variable_scope('decoder'):\n",
    "            \n",
    "            dec_cell = tf.contrib.rnn.BasicRNNCell(num_units = dec_hdim, activation = tf.nn.tanh)\n",
    "            \n",
    "            # Applying attention-mechanism\n",
    "            attn = s2s.LuongAttention(num_units = dec_hdim,\n",
    "                                      memory = enc_outputs,\n",
    "                                      memory_sequence_length = self._s_len, dtype = tf.float32)\n",
    "            attn_cell = s2s.AttentionWrapper(cell = dec_cell, attention_mechanism = attn)\n",
    "            dec_initial_state = attn_cell.zero_state(batch_size = batch_size, dtype = tf.float32)\n",
    "            output_layer = tf.layers.Dense(units = n_of_classes,\n",
    "                                           kernel_initializer = \\\n",
    "                                           tf.contrib.layers.xavier_initializer(uniform = False))\n",
    "            \n",
    "            with tf.variable_scope('training'):\n",
    "                \n",
    "                tr_helper = s2s.TrainingHelper(inputs = t_batch,\n",
    "                                               sequence_length = tr_tokens)\n",
    "                tr_decoder = s2s.BasicDecoder(cell = attn_cell, helper = tr_helper,\n",
    "                                              initial_state = dec_initial_state,\n",
    "                                              output_layer = output_layer)\n",
    "                self._tr_outputs,_,_ = s2s.dynamic_decode(decoder = tr_decoder,\n",
    "                                                          impute_finished = True,\n",
    "                                                          maximum_iterations = self._t_max_len)\n",
    "            with tf.variable_scope('translation'):\n",
    "                \n",
    "                trans_helper = s2s.GreedyEmbeddingHelper(embedding = t_embeddings,\n",
    "                                                         start_tokens = trans_tokens,\n",
    "                                                         end_token = self._t_dic.get('<end>'))\n",
    "\n",
    "                trans_decoder = s2s.BasicDecoder(cell = attn_cell, helper = trans_helper,\n",
    "                                                 initial_state = dec_initial_state,\n",
    "                                                 output_layer = output_layer)\n",
    "            \n",
    "                self._trans_outputs, _, _ = s2s.dynamic_decode(decoder = trans_decoder,\n",
    "                                                               impute_finished = True,\n",
    "                                                               maximum_iterations = self._t_max_len * 2)\n",
    "            \n",
    "        with tf.variable_scope('seq2seq_loss'):\n",
    "            masking = tf.sequence_mask(lengths = self._t_len,\n",
    "                                       maxlen = self._t_max_len, dtype = tf.float32)\n",
    "            self.__seq2seq_loss = s2s.sequence_loss(logits = self._tr_outputs.rnn_output,\n",
    "                                                    targets = self._t_output_indices,\n",
    "                                                    weights = masking)\n",
    "            \n",
    "    def translate(self, sess, s_len, s_indices):\n",
    "        feed_translation = {self._s_len : s_len, self._s_indices : s_indices}\n",
    "        return sess.run(self._trans_outputs.sample_id, feed_dict = feed_translation)\n",
    "    \n",
    "    @property\n",
    "    def loss(self):\n",
    "        return self.__seq2seq_loss"
   ]
  },
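  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note the two helpers above: during training, `TrainingHelper` feeds the ground-truth previous token at each step (teacher forcing), while at translation time `GreedyEmbeddingHelper` embeds the argmax of the previous step's logits and feeds it back in. A minimal numpy sketch of one greedy step (the `demo_*` names are hypothetical, not part of the model):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one greedy decoding step, sketched with numpy\n",
    "demo_logits = np.array([0.1, 2.0, 0.3, 0.5], dtype = np.float32) # hypothetical step logits\n",
    "demo_embedding = np.eye(4, dtype = np.float32) # one-hot embedding, as in SimpleNMT\n",
    "next_token = int(np.argmax(demo_logits)) # greedy choice of the next token\n",
    "next_input = demo_embedding[next_token] # fed to the decoder cell at the next step\n",
    "print(next_token, next_input)"
   ]
  },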
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a model of SimpleNMT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n"
     ]
    }
   ],
   "source": [
    "# hyper-parameter#\n",
    "lr = .003\n",
    "epochs = 1000\n",
    "batch_size = 2\n",
    "total_step = int(np.shape(X_indices)[0] / batch_size)\n",
    "print(total_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?,), (?, 10), (?,), (?, 12), (?, 12)), types: (tf.int32, tf.int32, tf.int32, tf.int32, tf.int32)>\n"
     ]
    }
   ],
   "source": [
    "## create data pipeline with tf.data\n",
    "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y_length, y_input_indices, y_target_indices))\n",
    "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n",
    "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n",
    "tr_iterator = tr_dataset.make_initializable_iterator()\n",
    "print(tr_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_length_mb, X_indices_mb, y_length_mb, y_input_indices_mb, y_target_indices_mb = tr_iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_nmt = SimpleNMT(s_len = X_length_mb, s_indices  = X_indices_mb,\n",
    "                    t_len = y_length_mb, t_input_indices = y_input_indices_mb,\n",
    "                    t_output_indices = y_target_indices_mb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creat training op and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "## create training op\n",
    "opt = tf.train.AdamOptimizer(learning_rate = lr)\n",
    "training_op = opt.minimize(loss = sim_nmt.loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch : 100, tr_loss : 0.413\n",
      "epoch : 200, tr_loss : 0.368\n",
      "epoch : 300, tr_loss : 0.259\n",
      "epoch : 400, tr_loss : 0.034\n",
      "epoch : 500, tr_loss : 0.016\n",
      "epoch : 600, tr_loss : 0.010\n",
      "epoch : 700, tr_loss : 0.007\n",
      "epoch : 800, tr_loss : 0.005\n",
      "epoch : 900, tr_loss : 0.004\n",
      "epoch : 1000, tr_loss : 0.003\n"
     ]
    }
   ],
   "source": [
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "tr_loss_hist = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    avg_tr_loss = 0\n",
    "    tr_step = 0\n",
    "    \n",
    "    sess.run(tr_iterator.initializer)\n",
    "    try:\n",
    "        while True:\n",
    "            _, tr_loss = sess.run(fetches = [training_op, sim_nmt.loss])\n",
    "            avg_tr_loss += tr_loss\n",
    "            tr_step += 1\n",
    "            \n",
    "    except tf.errors.OutOfRangeError:\n",
    "        pass\n",
    "    \n",
    "    avg_tr_loss /= tr_step\n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "    if (epoch + 1) % 100 == 0:\n",
    "        print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))"
   ]
  },
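  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above collects `tr_loss_hist`; plotting it shows the full loss curve behind the printed checkpoints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(tr_loss_hist)\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('average training loss')\n",
    "plt.show()"
   ]
  },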
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 4,  7,  3,  2,  0],\n",
       "       [12,  6, 10,  2,  0],\n",
       "       [12,  5, 11, 13,  2],\n",
       "       [12,  6,  9,  8,  2]], dtype=int32)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "yhat = sim_nmt.translate(sess = sess, s_len = X_length, s_indices = X_indices)\n",
    "yhat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['나는', '배가', '고프다', '<end>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']\n",
      "['텐서플로우는', '매우', '어렵다', '<end>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']\n",
      "['텐서플로우는', '딥러닝을', '위한', '프레임워크이다', '<end>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']\n",
      "['텐서플로우는', '매우', '빠르게', '변화한다', '<end>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']\n"
     ]
    }
   ],
   "source": [
    "# 원래 문장\n",
    "originals = list(map(lambda elm : [target_idx_dic.get(idx) for idx in elm], y_target_indices))\n",
    "for original in originals:\n",
    "    print(original)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['나는', '배가', '고프다', '<end>', '<pad>']\n",
      "['텐서플로우는', '매우', '어렵다', '<end>', '<pad>']\n",
      "['텐서플로우는', '딥러닝을', '위한', '프레임워크이다', '<end>']\n",
      "['텐서플로우는', '매우', '빠르게', '변화한다', '<end>']\n"
     ]
    }
   ],
   "source": [
    "# 한글 넣은 번역문장\n",
    "translations = list(map(lambda elm : [target_idx_dic.get(idx) for idx in elm], yhat))\n",
    "for translation in translations:\n",
    "    print(translation)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}