{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 11 : Recurrent Neural Networks\n",
    "A simple example of many-to-one classification (word sentiment classification) using a stacked bi-directional Gated Recurrent Unit (GRU) with dropout.\n",
    "\n",
    "### Many to One Classification by Stacked Bi-directional GRU with Drop out\n",
    "- Creating the **data pipeline** with `tf.data`\n",
    "- Preprocessing variable-length word sequences by padding them with a user-defined function (`pad_seq`)\n",
    "- Using `tf.nn.embedding_lookup` to get the vector of each token (e.g. word, character)\n",
    "- Creating the model as **Class**\n",
    "- Applying **Drop out** to model by `tf.contrib.rnn.DropoutWrapper`\n",
    "- Applying **Stacking** and **dynamic rnn** to model by `tf.contrib.rnn.stack_bidirectional_dynamic_rnn`\n",
    "- Reference\n",
    "    - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n",
    "    - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n",
    "    - https://pozalabs.github.io/blstm/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.8.0\n"
     ]
    }
   ],
   "source": [
    "import os, sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import string\n",
    "%matplotlib inline\n",
    "\n",
    "slim = tf.contrib.slim\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare example data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# toy corpus; y holds one one-hot label per entry of words:\n",
    "# [1., 0.] is used for 'good', 'amazing', 'so good', 'awesome' and [0., 1.] for 'bad', 'bull shit'\n",
    "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n",
    "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'abcdefghijklmnopqrstuvwxyz *'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Character quantization\n",
    "char_space = string.ascii_lowercase \n",
    "char_space = char_space + ' ' + '*'\n",
    "char_space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n"
     ]
    }
   ],
   "source": [
    "# the index of every character is its position in char_space\n",
    "char_dic = dict(zip(char_space, range(len(char_space))))\n",
    "print(char_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create pad_seq function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_seq(sequences, max_len, dic):\n",
    "    seq_len, seq_indices = [], []\n",
    "    for seq in sequences:\n",
    "        seq_len.append(len(seq))\n",
    "        seq_idx = [dic.get(char) for char in seq]\n",
    "        seq_idx += (max_len - len(seq_idx)) * [dic.get('*')] # 27 is idx of meaningless token \"*\"\n",
    "        seq_indices.append(seq_idx)\n",
    "    return seq_len, seq_indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Apply pad_seq function to data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# every word is padded to 10 characters (the longest example, 'bull shit', has 9)\n",
    "max_length = 10\n",
    "X_length, X_indices = pad_seq(words, max_length, char_dic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4, 3, 7, 7, 9, 7]\n",
      "(6, 10)\n"
     ]
    }
   ],
   "source": [
    "# number of real characters per word, then the shape of the padded index matrix\n",
    "print(X_length)\n",
    "print(np.shape(X_indices))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define CharStackedBiGRU class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharStackedBiGRU:\n",
    "    \"\"\"Character-level many-to-one classifier: stacked bi-directional GRU with dropout.\n",
    "\n",
    "    The constructor builds the whole graph (one-hot lookup, recurrent layers, logits,\n",
    "    loss and prediction). ce_loss is the tensor to minimize and predict() evaluates\n",
    "    the arg-max of the logits.\n",
    "    \"\"\"\n",
    "    def __init__(self, X_length, X_indices, y, n_of_classes, hidden_dims, dic):\n",
    "        \"\"\"Build the graph on top of the given input tensors.\n",
    "\n",
    "        Args:\n",
    "            X_length: tensor passed to the RNN as sequence_length\n",
    "            X_indices: tensor of token indices looked up in the one-hot table\n",
    "            y: one-hot labels consumed by the softmax cross-entropy loss\n",
    "            n_of_classes: number of units of the final fully connected layer\n",
    "            hidden_dims: list with the number of GRU units of each stacked layer\n",
    "            dic: token-to-index vocabulary; only its size is used (one-hot width)\n",
    "        \"\"\"\n",
    "        # data pipeline\n",
    "        with tf.variable_scope('input_layer'):\n",
    "            self._X_length, self._X_indices, self._y = X_length, X_indices, y\n",
    "\n",
    "            # identity matrix used as a one-hot embedding table\n",
    "            # (trainable = False because the embedding vectors are not going to be trained)\n",
    "            self._one_hot = tf.get_variable(name = 'one_hot_embedding',\n",
    "                                            initializer = tf.eye(len(dic), dtype = tf.float32),\n",
    "                                            trainable = False)\n",
    "            self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n",
    "            self._keep_prob = tf.placeholder(dtype = tf.float32)\n",
    "\n",
    "        # Stacked Bi-directional GRU with Drop out\n",
    "        with tf.variable_scope('stacked_bi-directional_gru'):\n",
    "\n",
    "            def dropout_gru_stack():\n",
    "                # one GRU cell per entry of hidden_dims, each wrapped with output dropout\n",
    "                stack = []\n",
    "                for n_units in hidden_dims:\n",
    "                    gru = tf.contrib.rnn.GRUCell(num_units = n_units, activation = tf.nn.tanh)\n",
    "                    stack.append(tf.contrib.rnn.DropoutWrapper(cell = gru, output_keep_prob = self._keep_prob))\n",
    "                return stack\n",
    "\n",
    "            # forward cells are created first, then the backward cells\n",
    "            fw_stack = dropout_gru_stack()\n",
    "            bw_stack = dropout_gru_stack()\n",
    "\n",
    "            # the per-step outputs are not needed, only the final states of each direction\n",
    "            _outputs, fw_states, bw_states = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(\n",
    "                cells_fw = fw_stack, cells_bw = bw_stack,\n",
    "                inputs = self._X_batch,\n",
    "                sequence_length = self._X_length,\n",
    "                dtype = tf.float32)\n",
    "\n",
    "            # final state of the last layer, forward and backward joined along axis 1\n",
    "            final_state = tf.concat([fw_states[-1], bw_states[-1]], axis = 1)\n",
    "\n",
    "        with tf.variable_scope('output_layer'):\n",
    "            # no activation: _score holds the raw logits\n",
    "            self._score = slim.fully_connected(inputs = final_state, num_outputs = n_of_classes,\n",
    "                                               activation_fn = None)\n",
    "\n",
    "        with tf.variable_scope('loss'):\n",
    "            self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n",
    "\n",
    "        with tf.variable_scope('prediction'):\n",
    "            self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n",
    "\n",
    "    def predict(self, sess, X_length, X_indices, keep_prob = 1.):\n",
    "        \"\"\"Return the predicted class index of every example.\n",
    "\n",
    "        Args:\n",
    "            sess: session in which the prediction op is run\n",
    "            X_length: values fed in place of the length tensor given to the constructor\n",
    "            X_indices: values fed in place of the index tensor given to the constructor\n",
    "            keep_prob: value fed to the dropout placeholder (default 1., i.e. no dropout)\n",
    "        \"\"\"\n",
    "        feed_prediction = {\n",
    "            self._X_length : X_length,\n",
    "            self._X_indices : X_indices,\n",
    "            self._keep_prob : keep_prob,\n",
    "        }\n",
    "        return sess.run(self._prediction, feed_dict = feed_prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a model of CharStackedBiGRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    }
   ],
   "source": [
    "# hyper-parameter#\n",
    "lr = .003\n",
    "epochs = 10\n",
    "batch_size = 2\n",
    "total_step = int(np.shape(X_indices)[0] / batch_size)\n",
    "print(total_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n"
     ]
    }
   ],
   "source": [
    "## create data pipeline with tf.data\n",
    "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n",
    "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n",
    "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n",
    "tr_iterator = tr_dataset.make_initializable_iterator()\n",
    "print(tr_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# symbolic tensors for one mini-batch; every sess.run that evaluates them advances the iterator\n",
    "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the graph on the iterator tensors: 2 output classes, two stacked layers of 16 GRU units per direction\n",
    "char_stacked_bi_gru = CharStackedBiGRU(X_length = X_length_mb, X_indices = X_indices_mb, \n",
    "                                       y = y_mb, n_of_classes = 2, hidden_dims = [16,16], dic = char_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create training op and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "## create training op\n",
    "opt = tf.train.AdamOptimizer(learning_rate = lr)\n",
    "training_op = opt.minimize(loss = char_stacked_bi_gru.ce_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   1, tr_loss : 0.710\n",
      "epoch :   2, tr_loss : 0.637\n",
      "epoch :   3, tr_loss : 0.611\n",
      "epoch :   4, tr_loss : 0.564\n",
      "epoch :   5, tr_loss : 0.540\n",
      "epoch :   6, tr_loss : 0.478\n",
      "epoch :   7, tr_loss : 0.421\n",
      "epoch :   8, tr_loss : 0.402\n",
      "epoch :   9, tr_loss : 0.334\n",
      "epoch :  10, tr_loss : 0.302\n"
     ]
    }
   ],
   "source": [
    "# the session is left open because the prediction cell further down reuses it\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "# mean training loss of every epoch\n",
    "tr_loss_hist = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    avg_tr_loss = 0\n",
    "    tr_step = 0\n",
    "    \n",
    "    # restart the iterator so that this epoch goes through the dataset from the beginning\n",
    "    sess.run(tr_iterator.initializer)\n",
    "    try:\n",
    "        while True:\n",
    "            # one optimizer step on the next mini-batch, with keep_prob = .5 for the dropout wrappers\n",
    "            _, tr_loss = sess.run(fetches = [training_op, char_stacked_bi_gru.ce_loss],\n",
    "                                  feed_dict = {char_stacked_bi_gru._keep_prob : .5})\n",
    "            avg_tr_loss += tr_loss\n",
    "            tr_step += 1\n",
    "            \n",
    "    # raised by the iterator once all batches are consumed, i.e. the epoch is over\n",
    "    except tf.errors.OutOfRangeError:\n",
    "        pass\n",
    "    \n",
    "    # sum of the batch losses divided by the number of batches\n",
    "    avg_tr_loss /= tr_step\n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "    \n",
    "    print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x11c3b9160>]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl8VPW9//HXJxthCWuCbMGwBJFVYGRVq7UqKkLdARdQKaKgtnp7r7a9tT+9bfW21tq6Ampdqoi44VLRXtxlC7LJaggKYQ2LYSfb5/fHjDZSMAMkOcnM+/l4zMPMmXMy75mHvOfknO+cr7k7IiISHxKCDiAiItVHpS8iEkdU+iIicUSlLyISR1T6IiJxRKUvIhJHVPoiInFEpS8iEkdU+iIicSQp6AAHS09P96ysrKBjiIjUKvPnz9/q7hkVrVfjSj8rK4ucnJygY4iI1Cpm9lU06+nwjohIHFHpi4jEEZW+iEgcUemLiMQRlb6ISBxR6YuIxBGVvohIHImZ0i8pLeP3by0nf8feoKOIiNRYMVP663bs47m5a7li8hw279wfdBwRkRopZkq/XXp9nrq2L1t3HeCKyXPYtvtA0JFERGqcqErfzAab2UozyzWz2w/x+P1mtjByW2VmX5d7bJSZfRG5jarM8Afr3bYJj48+mXXb93LV43Mp3FtclU8nIlLrVFj6ZpYIPAScC3QBRphZl/LruPvP3P0kdz8J+CvwcmTbpsCdQD+gL3CnmTWp3JfwXf3bN2Pi1SFyt+xm1JNz2X2gpCqfTkSkVolmT78vkOvuee5eBEwBhn3P+iOA5yM/nwO86+7b3X0H8C4w+FgCR+MHnTJ4cGQvlqwv5Nq/zWNfUWlVP6WISK0QTem3BtaVu58fWfZvzOx4oB0w80i2NbOxZpZjZjkFBQXR5K7Q2V1bcP/lJzHvy+2MfSaHAyUqfhGRyj6ROxyY5u5H1LDuPtHdQ+4eysio8HLQURvasxX3XtyDj77Yyvi/L6C4tKzSfreISG0UTemvBzLL3W8TWXYow/nXoZ0j3bZKXBbK5O5hXfnn8s389IWFlJZ5dT69iEiNEs0kKvOAbDNrR7iwhwMjD17JzDoDTYBZ5RbPAH5X7uTt2cAdx5T4KFw1IIt9xaX87q0VpCYl8odLepCQYNUdQ0QkcBWWvruXmNkEwgWeCDzh7kvN7C4gx92nR1YdDkxxdy+37XYzu5vwBwfAXe6+vXJfQnTGntaBfUVl3P/PVdRNSeDuYd0wU/GLSHyxch1dI4RCIa+q6RLdnXveXsFjH+Txk1Pb8YvzTlTxi0hMMLP57h6qaL0aN0duVTIzbh/cmf1FpUz6aA11U5K49axOQccSEak2cVX6EC7+Oy/oyr7iUv7yf19QNzmRG07vEHQsEZFqEXelD5CQYPz+oh7sLy7j3rdXUDc5gdGD2gUdS0SkysVl6QMkJhj3XdaTAyWl/Ob1ZaQmJzK8b9ugY4mIVKmYucrm0UhOTOAvI3rxg04Z3PHKEl5dUK1fIRARqXZxXfoAdZISeeyqPvRv14zbXlzE259vDDqSiEiVifvSB0hNTmTyqBA92zTipucX8N6KLUFHEhGpEir9iPp1knjymr6c0CKNcc/O59PcrUFHEhGpdCr9chrVTebpa/txfLN6jHk6h/lfBfLlYRGRKqPSP0jT+ik8O6YfxzVMZfQT81iSXxh0JBGRSqPSP4Tmaan8fUw/GtVL5qon5rBi086gI4mIVAqV/mG0alyX58b0JzUpkSsnz2F1we6gI4mIHDOV/vdo26wez47pB8AVk+awbvvegBOJiBwblX4FOjZvwDPX9WNfcSkjJs1mY+G+oCOJiBw1lX4UTmzZkGeu60vh3mKumDSHgl0Hgo4kInJUoip9MxtsZivNLNfMbj/MOpeZ2TIzW2pmz5VbXmpmCyO36Yfatjbo0aYxT15zMhsL93Pl5Dns
2FMUdCQRkSNWYembWSLwEHAu0AUYYWZdDlonm/A0iIPcvSvw03IP73P3kyK3oZUXvfqFspoyeVSINdv2cPUTc9m5vzjoSCIiRySaPf2+QK6757l7ETAFGHbQOj8BHnL3HQDuHrPXMRjUMZ3HruzDik07uebJeew5UBJ0JBGRqEVT+q2BdeXu50eWldcJ6GRmn5jZbDMbXO6xVDPLiSz/8aGewMzGRtbJKSgoOKIXEIQzOjfnL8N7sWDtDsY8lcP+4tKgI4mIRKWyTuQmAdnA6cAIYJKZNY48dnxk3saRwJ/N7N+mqXL3ie4ecvdQRkZGJUWqWud2b8l9l/Vk9pptjHt2PgdKVPwiUvNFU/rrgcxy99tElpWXD0x392J3XwOsIvwhgLuvj/w3D3gf6HWMmWuMC3u14XcXduf9lQXc8vxCSkrLgo4kIvK9oin9eUC2mbUzsxRgOHDwKJxXCe/lY2bphA/35JlZEzOrU275IGBZJWWvEUb0bcuvh3Th7aWbuO3FRZSWedCRREQOq8LpEt29xMwmADOAROAJd19qZncBOe4+PfLY2Wa2DCgFfu7u28xsIPCYmZUR/oC5x91jqvQBrj2lHftLSvnft1dSNzmR31/UHTMLOpaIyL8x95q1ZxoKhTwnJyfoGEflvndW8teZuYwemMUvzz+R5ER9901EqoeZzY+cP/1ecTsxelW49axO7CsqZfLHa3hj8UYu6t2aS/u0Ifu4tKCjiYgA2tOvdO7OzBVbeGHeOmau2EJJmXNSZmMuC2UypGdLGqYmBx1RRGJQtHv6Kv0qtHX3AV5dsJ6pOetYtXk3qckJnNutJZf2aUP/9s1ISNBxfxGpHCr9GsTdWZxfyIvz1/Hawg3s2l9CmyZ1uaRPGy7p04Y2TeoFHVFEajmVfg21v7iUGUs38WJOPp+sDk++PrBDMy4LZXJO1xakJicGnFBEaiOVfi2Qv2MvL81fz7TP1rFu+z7SUpMY2rMVl4Yy6dmmkYZ9ikjUVPq1SFmZM3vNNqbl5PPW5xvZX1xGdvMGXBbK5Me9WpORVifoiCJSw6n0a6md+4t5c/FGpuasY8Har0lKMM7o3JzLQpmcfkKGxv6LyCGp9GNA7pZdvJiTz0ufrWfr7gOkN0jhwl6tuTSUSSeN/ReRclT6MaS4tIwPVhbw4vx1/N/y8Nj/npmNuSzUhgt6ttLYfxFR6ceqb8b+v5iTz8rNu6iTlMC53VpwaSiTARr7LxK3VPoxzt1Zsr6QF3PyeW3henbuL6F143+N/c9sqrH/IvFEpR9Hvhn7P21+Ph/nbsU9PPZ/RN+2DOnRUkM/ReKASj9Orf96Hy/Nz2fa/HzWbt/Lqdnp3HNxD1o3rht0NBGpQir9OFdW5jw3dy2/e2s5CWb895ATuSyUqb1+kRgVbelHNejbzAab2UozyzWz2w+zzmVmtszMlprZc+WWjzKzLyK3UdG/BDkWCQnGlf2PZ8ZPT6Nb64b810tLGP3kPDYW7gs6mogEqMI9fTNLJDzn7VmE58KdB4woPwOWmWUDU4EfuvsOM2vu7lvMrCmQA4QAB+YDfdx9x+GeT3v6la+szHlm9lfc848VJCUad17QlYt7t9Zev0gMqcw9/b5ArrvnuXsRMAUYdtA6PwEe+qbM3X1LZPk5wLvuvj3y2LvA4GhfhFSOhARj1MAs3v7pqZzYoiH/8eIirnsqh8079wcdTUSqWTSl3xpYV+5+fmRZeZ2ATmb2iZnNNrPBR7CtVJPjm9Vnytj+/HpIFz5dvZWz/vQBryzIp6ad1xGRqlNZF3JJArKB04ERwCQzaxztxmY21sxyzCynoKCgkiLJoSQkGNee0o63bj6V7OPS+NkLixj7zHy27NJev0g8iKb01wOZ5e63iSwrLx+Y7u7F7r6G8DmA7Ci3xd0nunvI3UMZGRlHkl+OUvuMBky9fgC/PO9EPlhVwNn3f8hrC9drr18kxkVT+vOAbDNrZ2YpwHBg+kHrvEp4Lx8z
Syd8uCcPmAGcbWZNzKwJcHZkmdQAiQnGT05rz1s3n0pWs/rcMmUhNzz7GVt3Hwg6mohUkQpL391LgAmEy3o5MNXdl5rZXWY2NLLaDGCbmS0D3gN+7u7b3H07cDfhD455wF2RZVKDdGzegGnjBnD7uZ2ZuWILZ9//IW8u3hh0LBGpAvpylnzHqs27+I8XF7E4v5Dze7Tk7mHdaFo/JehYIlKBSv1ylsSPTsel8fINA/n5OSfwztJNnH3/B7z9+aagY4lIJVHpy79JSkxg/Bkdef2mU2jRKJVxz87n5ucXsGNPUdDRROQYqfTlsDq3aMgrNw7i1rM68daSjZx1/4e8u2xz0LFE5Bio9OV7JScmcPOZ2bw2YRAZaXX4ydM53PrCQgr3FgcdTUSOgkpfotK1VSNeGz8o/AGwaANn//kDZq7QXr9IbaPSl6ilJCVw61mdeG38IBrXTeHav+Xw8xcXUbhPe/0itYVKX45Yt9aNmH7TIMaf0YGXPsvnnPs/5P2VWyreUEQCp9KXo1InKZGfn9OZV24cRFpqEqOfnMftLy1m137t9YvUZCp9OSY9Mxvz+k2nMO4HHZias45z7v+Qj7/YGnQsETkMlb4cs9TkRG4/tzPTbhhIakoiVz4+h1+8soTdB0qCjiYiB1HpS6Xp3bYJb918KmNPa8/zc9dyzv0f8mmu9vpFahKVvlSq1OREfnHeibx4/QBSkhIYOXkOt05dyKert1JaVrOu8yQSj3TBNaky+4pKue+dlTw3dy17i0rJSKvD+d1bMqRHS3q3bUJCguboFaks0V5wTaUvVW5fUSkzV2zhjcUbmLliCwdKymjZKDX8AdCzFT3bNNIk7SLHSKUvNdLuAyX8c9lm3li8gQ9WFVBc6mQ2rcuQHq0Y0qMlXVo21AeAyFGo1NKPTHT+AJAITHb3ew56fDTwB/41FeKD7j458lgpsCSyfK27D+V7qPTjR+G+Yt5Zuok3Fm/k49zwMf/26fUZ0iP8F0Cn49KCjihSa1Ra6ZtZIuE5b88iPBfuPGCEuy8rt85oIOTuEw6x/W53bxBtcJV+fNq+p4i3P9/EG4s3MDtvG2UOnY5rwAU9WjGkZyvapdcPOqJIjRZt6SdF8bv6Arnunhf5xVOAYcCy791K5Ag0rZ/CyH5tGdmvLVt27eftzzfx+qIN3PfuKu57dxVdWzX89hBQZtN6QccVqbWiKf3WwLpy9/OBfodY72IzO43wXwU/c/dvtkk1sxygBLjH3V89lsAS+5qnpXL1gCyuHpDFxsJ9vLl4I28s3si9b6/g3rdX0DOzMRf0aMn5PVrSslHdoOOK1CrRHN65BBjs7mMi968C+pU/lGNmzYDd7n7AzK4HLnf3H0Yea+3u682sPTATONPdVx/0HGOBsQBt27bt89VXX1XeK5SYsW77Xt5cspHXF21g6YadAJyc1YQhPVpxbvcWNE9LDTihSHAq85j+AOA37n5O5P4dAO7++8Osnwhsd/dGh3jsb8Ab7j7tcM+nY/oSjTVb9/DGog28sXgjKzfvIsGgX7tmXNCzFYO7tdBk7hJ3KrP0kwgfsjmT8OicecBId19abp2W7r4x8vOFwH+5e38zawLsjfwFkA7MAoaVPwl8MJW+HKlVm3d9+wGQt3UPiQnGoI7pDOnRknO6tqBR3eSgI4pUucoesnke8GfCQzafcPffmtldQI67Tzez3wNDCR+33w7c4O4rzGwg8BhQRviSD39298e/77lU+nK03J1lG3fyxuKNvLF4A+u27yM50TgtO4MLerbi7K7HUS8lmtNYIrWPvpwlcc3dWZxfyOuLNvDmko1sLNxP+/T6PHVtX43+kZik0heJKCtzPviigJ9OWUhyovHE6JPp0aZx0LFEKlW0pa+rbErMS0gwzjihOS/dMJA6SYkMnzib91ZoekeJTyp9iRsdmzfglfEDaZ9RnzFP5/D83LVBRxKpdip9iSvN01J5YewATumYzh0vL+FP76ykph3iFKlK
Kn2JO/XrJDF5VIjLQ5n8ZWYu//HiYopLy4KOJVItNH5N4lJyYgL3XNydVo3rcv8/V7Fl134evqI3aaka0y+xTXv6ErfMjFt+lM3/XtKDWau3cfljs9m8c3/QsUSqlEpf4t5loUweH30yX23bw0UPf8qqzbuCjiRSZVT6IsAPOmXwwvUDKCot45JHPmV23ragI4lUCZW+SES31o145caBNG+YytWPz+X1RRuCjiRS6VT6IuW0aVKPaeMGcFJmY256fgGTPszTkE6JKSp9kYM0rpfC09f15fweLfntW8v5f68vo7RMxS+xQUM2RQ4hNTmRvw7vRatGqUz6aA0bC/fxwPBepCYnBh1N5JhoT1/kMBISjF+e34VfD+nCO8s2M3LSbLbvKQo6lsgxUemLVODaU9rx8MjefL5hJ5c88ilrt+0NOpLIUYuq9M1ssJmtNLNcM7v9EI+PNrMCM1sYuY0p99goM/sichtVmeFFqsu53Vvy3Jh+bN9bxEWPfMKidV8HHUnkqFRY+pE5bx8CzgW6ACPMrMshVn3B3U+K3CZHtm0K3An0A/oCd0amUBSpdUJZTXnphoGkJocvzzxzxeagI4kcsWj29PsCue6e5+5FwBRgWJS//xzgXXff7u47gHeBwUcXVSR4HTIa8PKNA+nYvAFjnsrhuTm6PLPULtGUfmtgXbn7+ZFlB7vYzBab2TQzyzzCbUVqjeZpqUwZ25/TOmXwi1eWcJ8uzyy1SGWdyH0dyHL3HoT35p86ko3NbKyZ5ZhZTkFBQSVFEqk69eskMfnqEMNPzuSvM3O57cVFFJXo8sxS80VT+uuBzHL320SWfcvdt7n7gcjdyUCfaLeNbD/R3UPuHsrIyIg2u0igkhIT+P1F3bntrE68/Nl6rv3bPHbtLw46lsj3iqb05wHZZtbOzFKA4cD08iuYWctyd4cCyyM/zwDONrMmkRO4Z0eWicQEM+OmM7P5wyU9mJ23jUsfncWmQl2eWWquCkvf3UuACYTLejkw1d2XmtldZjY0strNZrbUzBYBNwOjI9tuB+4m/MExD7grskwkplwayuSJ0SezbvteLnr4E12eWWosq2knoEKhkOfk5AQdQ+SoLN1QyDVPzmNfcSkTrwoxoEOzoCNJnDCz+e4eqmg9fSNXpBJ1bdWIV8YPokXDVEY9MZfXFv7bKSyRQKn0RSpZ68Z1mTZuICe1bcwtUxby2AerNaRTagyVvkgVaFQvmWeu68uQHi35/T9W8JvpS3V5ZqkRdGllkSpSJymRvwzvRavGdZn4YR4bC/fzlxG6PLMES3v6IlUoIcH4xXkn8psLuvDu8s2MmDSbbbsPVLyhSBVR6YtUg9GD2vHIFX1YtmEnP374E1Zu0pBOCYZKX6SaDO7WgheuH8CB4jIuevgT3l2mq3RK9VPpi1SjkzIbM33CKXRo3oCxz+Tw0Hu5Gtkj1UqlL1LNWjRKZer1A7igRyv+MGMlP31hIfuLS4OOJXFCo3dEApCanMgDw0/ihBZp/GHGSr7cuofHrgrRolFq0NEkxmlPXyQgZsb4Mzoy8ao+5G7ZzdAHP2ahpmGUKqbSFwnY2V1b8NKNA0lJSuCyx2bp0g1SpVT6IjVA5xYNmT7hFHplhi/dcO/bKyjTN3ilCqj0RWqIpvVTeOa6fozs15ZH3l/N2Gdy2H2gJOhYEmNU+iI1SEpSAr/9cTfuGtaV91YWcNHDn7B2296gY0kMiar0zWywma00s1wzu/171rvYzNzMQpH7WWa2z8wWRm6PVlZwkVhlZlw9IIunr+3L5p0HGPbQx8xavS3oWBIjKix9M0sEHgLOBboAI8ysyyHWSwNuAeYc9NBqdz8pchtXCZlF4sKgjum8Nn4QzRrU4arH5/Ds7K+CjiQxIJo9/b5ArrvnuXsRMAUYdoj17gbuBTRBqEglyUqvz8s3DuTU7HR+9ern/Pern1NcWhZ0LKnFoin91sC6cvfzI8u+ZWa9gUx3f/MQ27czswVm9oGZ
nXr0UUXiU8PUZCaPOpnrT2vPM7O/4urH57JjT1HQsaSWOuYTuWaWAPwJuO0QD28E2rp7L+BW4Dkza3iI3zHWzHLMLKegoOBYI4nEnMQE447zTuS+S3sy/6sd/PjhT/hCk6/LUYim9NcDmeXut4ks+0Ya0A1438y+BPoD080s5O4H3H0bgLvPB1YDnQ5+Anef6O4hdw9lZGQc3SsRiQMX92nDlOv7s+dAKRc+/Cn/t1xX6pQjE03pzwOyzaydmaUAw4Hp3zzo7oXunu7uWe6eBcwGhrp7jpllRE4EY2btgWwgr9JfhUgc6d22Ca/fNIis9HqMeTqHRzUHrxyBCkvf3UuACcAMYDkw1d2XmtldZja0gs1PAxab2UJgGjDO3bcfa2iReNeyUV1evH4g53VvyT3/WMFtUxfpSp0SFatpewihUMhzcnKCjiFSK7g7D87M5b53V3FSZmMmXtWH5g11pc54ZGbz3T1U0Xr6Rq5ILWZm3HRmNo9e2YdVm3cx9MFPWJyvK3XK4an0RWLA4G4tmDZuIIkJxqWPzmL6og1BR5IaSqUvEiO6tGrIaxMG0aNNI25+fgF/nLFSV+qUf6PSF4kh6Q3q8Pcx/bk8lMmD7+Uy7tn57NGVOqUclb5IjElJSuCei7tz5wVd+OfyzVz8yKes264rdUqYSl8kBpkZ1wxqx1PX9mXD1/sY9tAnzMnTlTpFpS8S007NzuDV8YNoXC+ZKybP4fm5a4OOJAFT6YvEuPYZDXjlxkEM7JjOHS8v4TfTl1KiK3XGLZW+SBxoVDeZJ0aFGHNKO/726ZeMfnIec/K2UVSi8o83SUEHEJHqkZSYwK+GdKFTizR+9ernXD5xNvVSEunbrimndEznlOx0TjguDTMLOqpUIV2GQSQOFe4rZnbeNj7J3crHuVvJK9gDQHqDFAZ1TP/21rpx3YCTSrSivQyD9vRF4lCjusmc07UF53RtAcCGr/fxSe7WyIfANl5bGP5Gb/v0+t9+AAxo34xG9ZKDjC2VQHv6IvId7s6qzbv5OPIhMDtvG3uLSkkw6N6mMad0bMagjun0Ob4JdZISg44rEdHu6av0ReR7FZWUsSj/az7+IvwhsGDd15SWOanJCZycFT4fMKhjOl1aNiQhQecDgqLSF5EqsWt/MXPXbP/2L4FVm3cD0KReMgM7podPCndMJ7NpvYCTxhcd0xeRKpGWmsyZJx7HmSceB8Dmnfv5dPVWPv5iGx/nFvDm4o0AtG1aj0GRD4CBHZrRpH5KkLElIqo9fTMbDDwAJAKT3f2ew6x3MeEZsk5295zIsjuA64BS4GZ3n/F9z6U9fZHay91ZXbDn21FBs1dvY9eBEsyga6uG334InJzVlNRknQ+oTJV2eCcyx+0q4Cwgn/CcuSPcfdlB66UBbwIpwITIHLldgOeBvkAr4J9AJ3c/7LxuKn2R2FFSWsbi9YV88sVWPsrdyoK1OygudVKSEggd34Qfdm7OqIFZJCfqe6LHqjIP7/QFct09L/KLpwDDgGUHrXc3cC/w83LLhgFT3P0AsMbMciO/b1YUzysitVxSYgK92zahd9sm3HRmNnsOlDD3y+18Ghka+j9vLmfumu38dWQvjQSqJtF8vLYG1pW7nx9Z9i0z6w1kuvubR7ptZPuxZpZjZjkFBQVRBReR2qd+nSTOOKE5vzy/C/+45VR+c0EX3lm2mXHPzNfE7tXkmP+mMrME4E/AbUf7O9x9oruH3D2UkZFxrJFEpJYYPagdv7uwO++vKuC6p+axt0gTvlS1aEp/PZBZ7n6byLJvpAHdgPfN7EugPzDdzEJRbCsicW5kv7b88ZKezFq9jdFPzGO3ZvqqUtGU/jwg28zamVkKMByY/s2D7l7o7ununuXuWcBsYGhk9M50YLiZ1TGzdkA2MLfSX4WI1GoX92nDA8N7MX/tDq56fA6F+4qDjhSzKix9dy8BJgAzgOXAVHdfamZ3mdnQCrZdCkwlfNL3bWD8943cEZH4dUHP
Vjw0sjefry/kismz2bGnKOhIMUnfyBWRGuW9FVu4/tn5tE+vzzPX9SMjrU7QkWqFaIdsanCsiNQoZ3RuzpOjT+bLbXsYPnEWmwr3Bx0ppqj0RaTGGdQxnaev7cemwv1cPnEW67/eF3SkmKHSF5EaqW+7pjwzph/b9xRx2aOz+GrbnqAjxQSVvojUWL3bNuH5n/RnT1EJlz82m9UFu4OOVOup9EWkRuvWuhFTxvanpKyMyx+bzcpNu4KOVKup9EWkxuvcoiFTxg4gwWD4xFl8vr4w6Ei1lkpfRGqFjs0bMPX6AdRLSWLkpNksXPd10JFqJZW+iNQaWen1eeH6/jSul8KVk+cw78vtQUeqdVT6IlKrtGlSj6nXD6B5Wh2ufnwun+ZuDTpSraLSF5Fap0WjVKZc35/MpnW55m/zeH/llqAj1RoqfRGplZqnpTJl7AA6Nm/A2Kfn887STUFHqhVU+iJSazWtn8JzY/pzYquG3Pj3z76dlF0OT6UvIrVao3rJPHtdX3q1bcxNz3/GKwvyg45Uo6n0RaTWS0tN5qlr+9K/fTNunbqIKXPXBh2pxlLpi0hMqJeSxBOjT+a07Axuf3kJT8/6MuhINVJUpW9mg81spZnlmtnth3h8nJktMbOFZvaxmXWJLM8ys32R5QvN7NHKfgEiIt9ITU5k4tV9OKvLcfz6taVM+jAv6Eg1ToWlb2aJwEPAuUAXYMQ3pV7Oc+7e3d1PAv6X8ETp31jt7idFbuMqK7iIyKHUSUrk4St6c373lvz2reU8OPOLoCPVKElRrNMXyHX3PAAzmwIMIzwFIgDuvrPc+vWBmjUdl4jEleTEBB4YfhJ1khL44zurOFBSxq1ndcLMgo4WuGhKvzWwrtz9fKDfwSuZ2XjgViAF+GG5h9qZ2QJgJ/Ard//oENuOBcYCtG3bNurwIiKHk5SYwB8v7UlKUgJ/nZnL/uJSfnHeiXFf/JV2ItfdH3L3DsB/Ab+KLN4ItHX3XoQ/EJ4zs4aH2Haiu4fcPZSRkVFZkUQkziUkGL+7sDujBhzPpI/WcOf0pZSVxfeBiGj29NcDmeXut4ksO5wpwCMA7n4AOBD5eb6ZrQY6AZr5XESqRUKC8ZuhXamTnMjED/MoKinjtxd2JzGbR2amAAAHUElEQVQhPvf4oyn9eUC2mbUjXPbDgZHlVzCzbHf/5mzJ+cAXkeUZwHZ3LzWz9kA2oNPpIlKtzIw7zu1MncihngMlZfzhkh4kJcbfqPUKS9/dS8xsAjADSASecPelZnYXkOPu04EJZvYjoBjYAYyKbH4acJeZFQNlwDh317VQRaTamRm3nX3Ctyd3i0rK+PPwk0iOs+I395p1fCsUCnlOjo7+iEjVmfxRHv/z5nLO6nIcD47sRZ2kxKAjHTMzm+/uoYrWi6+POBERYMyp7blrWFfeXbaZsU/PZ39xadCRqo1KX0Ti0tUDsrj34u58+EUBIyfNjptZuFT6IhK3Lj+5LQ8M78WX2/Zy6aOzuPTRT3lvxRZq2mHvyqRj+iIS9/YWlTB13jomfbSG9V/vo3OLNG44vQPnd29Za0b4RHtMX6UvIhJRXFrGaws38OgHq8ndspu2Tesx9rT2XNKnDanJNftkr0pfROQolZU57y7fzMPvr2bRuq9Jb1CH605px5X925KWmhx0vENS6YuIHCN3Z1beNh55fzUffbGVtNQkrup/PNcMakdGWp2g432HSl9EpBItyS/kkQ9y+cfnm0hJTOCyUCZjT2tPZtN6QUcDVPoiIlUir2A3j32Qx8sL8ilzuKBHS244vSMntEgLNJdKX0SkCm0q3M/jH+fx9zlr2VtUypmdm3PjGR3oc3zTQPKo9EVEqsHXe4t4etZXPPnJGnbsLaZvVlNuOKMDp3fKqNZr96v0RUSq0d6iEqbMXcekj/LYWLifE1s25IbTO3BetxbVMtZfpS8iEoCikjJeW7ieRz9YzeqCPRzfLDzW/+LeVTvWX6UvIhKgsjLnnWWb
eeT9XBblF5KRFh7rf0W/qhnrr9IXEakB3J1Zq7fx8Pur+Tg3PNb/6gHhsf7pDSpvrH+lXlrZzAab2UozyzWz2w/x+DgzW2JmC83sYzPrUu6xOyLbrTSzc47sZYiI1G5mxsCO6Tw7ph/TJwzilI7pPPz+agbdM5Nfv/Y567bvrd48Fe3pm1kisAo4C8gnPH3iCHdfVm6dhu6+M/LzUOBGdx8cKf/ngb5AK+CfQCd3P+zFq7WnLyKxbnXBbiaWG+s/tGcrxv2gwzGN9a/MPf2+QK6757l7EeGJz4eVX+Gbwo+oD3zzSTIMmOLuB9x9DZAb+X0iInGrQ0YD7r2kBx/+5xlcMzCLGUs3cc6fP2T8c59V+WWdo5kYvTWwrtz9fKDfwSuZ2XjgViAF+GG5bWcftG3ro0oqIhJjWjaqy6+GdGH8GR15etZXFJWWVvnY/mhKPyru/hDwkJmNBH7FvyZHr5CZjQXGArRt27ayIomI1ApN6qdwy4+yq+W5ojm8sx7ILHe/TWTZ4UwBfnwk27r7RHcPuXsoIyMjikgiInI0oin9eUC2mbUzsxRgODC9/ApmVv4j6nzgi8jP04HhZlbHzNoB2cDcY48tIiJHo8LDO+5eYmYTgBlAIvCEuy81s7uAHHefDkwwsx8BxcAOIod2IutNBZYBJcD47xu5IyIiVUtfzhIRiQGV+uUsERGJDSp9EZE4otIXEYkjKn0RkThS407kmlkB8NUx/Ip0YGslxant9F58l96P79L78S+x8F4c7+4VftGpxpX+sTKznGjOYMcDvRffpffju/R+/Es8vRc6vCMiEkdU+iIicSQWS39i0AFqEL0X36X347v0fvxL3LwXMXdMX0REDi8W9/RFROQwYqb0K5rHN56YWaaZvWdmy8xsqZndEnSmoJlZopktMLM3gs4SNDNrbGbTzGyFmS03swFBZwqSmf0s8u/kczN73sxSg85UlWKi9CPz+D4EnAt0AUaUn5w9DpUAt7l7F6A/MD7O3w+AW4DlQYeoIR4A3nb3zkBP4vh9MbPWwM1AyN27Eb6S8PBgU1WtmCh9opjHN564+0Z3/yzy8y7C/6jjdppKM2tDeJ6HyUFnCZqZNQJOAx4HcPcid/862FSBSwLqmlkSUA/YEHCeKhUrpX+oeXzjtuTKM7MsoBcwJ9gkgfoz8J9AWdBBaoB2QAHwZORw12Qzqx90qKC4+3rgj8BaYCNQ6O7vBJuqasVK6cshmFkD4CXgp+6+M+g8QTCzIcAWd58fdJYaIgnoDTzi7r2APUDcngMzsyaEjwq0A1oB9c3symBTVa1YKf0jncc35plZMuHC/7u7vxx0ngANAoaa2ZeED/v90MyeDTZSoPKBfHf/5i+/aYQ/BOLVj4A17l7g7sXAy8DAgDNVqVgp/Qrn8Y0nZmaEj9kud/c/BZ0nSO5+h7u3cfcswv9fzHT3mN6T+z7uvglYZ2YnRBadSXg603i1FuhvZvUi/27OJMZPbFc4R25tcLh5fAOOFaRBwFXAEjNbGFn2C3d/K8BMUnPcBPw9soOUB1wTcJ7AuPscM5sGfEZ41NsCYvzbufpGrohIHImVwzsiIhIFlb6ISBxR6YuIxBGVvohIHFHpi4jEEZW+iEgcUemLiMQRlb6ISBz5/7IGArGwghuUAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# mean training loss of each epoch (x axis: 0-based epoch index)\n",
    "plt.plot(tr_loss_hist, label = 'train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# feeds the whole padded training set in place of the iterator tensors; keep_prob defaults to 1. (no dropout)\n",
    "yhat = char_stacked_bi_gru.predict(sess = sess, X_length = X_length, X_indices = X_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training acc: 83.33%\n"
     ]
    }
   ],
   "source": [
    "# share of examples whose predicted class index equals the arg-max of the one-hot label\n",
    "print('training acc: {:.2%}'.format(np.mean(yhat == np.argmax(y, axis = -1))))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}