{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 11 : Recurrent Neural Networks\n",
    "A simple example of many-to-one classification (word sentiment classification) using a stacked bi-directional recurrent neural network with dropout.\n",
    "\n",
    "### Many to One Classification by Stacked Bi-directional RNN with Drop out\n",
    "- Creating the **data pipeline** with `tf.data`\n",
    "- Preprocessing variable-length word sequences by padding them with a user-defined function (`pad_seq`)\n",
    "- Using `tf.nn.embedding_lookup` to get the vector for each token (e.g. word, character)\n",
    "- Creating the model as **Class**\n",
    "- Applying **Drop out** to model by `tf.contrib.rnn.DropoutWrapper`\n",
    "- Applying **Stacking** and **dynamic rnn** to model by `tf.contrib.rnn.stack_bidirectional_dynamic_rnn`\n",
    "- Reference\n",
    "    - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n",
    "    - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n",
    "    - https://pozalabs.github.io/blstm/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.8.0\n"
     ]
    }
   ],
   "source": [
    "import os, sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import string\n",
    "%matplotlib inline\n",
    "\n",
    "slim = tf.contrib.slim\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare example data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# toy sentiment data: one one-hot label per word, [1., 0.] = positive, [0., 1.] = negative\n",
    "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n",
    "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'abcdefghijklmnopqrstuvwxyz *'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Character quantization: lowercase letters, the space character and '*' (used as padding token)\n",
    "char_space = ''.join([string.ascii_lowercase, ' ', '*'])\n",
    "char_space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n"
     ]
    }
   ],
   "source": [
    "# map every character of the vocabulary to its position in char_space\n",
    "char_dic = dict(zip(char_space, range(len(char_space))))\n",
    "print(char_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create pad_seq function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_seq(sequences, max_len, dic, pad_token = '*'):\n",
    "    \"\"\"Convert character sequences to fixed-length lists of token indices.\n",
    "\n",
    "    Args:\n",
    "        sequences: iterable of strings (or other sliceable token sequences).\n",
    "        max_len: length of every returned index list; longer inputs are\n",
    "            truncated, shorter ones are right-padded.\n",
    "        dic: mapping from token to integer index; must contain pad_token.\n",
    "        pad_token: token used for padding and for tokens missing from dic.\n",
    "\n",
    "    Returns:\n",
    "        (seq_len, seq_indices): the un-padded length of each sequence, capped\n",
    "        at max_len, and the padded index lists (each of length max_len).\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if pad_token is not a key of dic.\n",
    "    \"\"\"\n",
    "    pad_idx = dic[pad_token] # fail fast instead of silently padding with None\n",
    "    seq_len, seq_indices = [], []\n",
    "    for seq in sequences:\n",
    "        # truncate so that every row has exactly max_len entries;\n",
    "        # unknown tokens fall back to the padding index instead of None\n",
    "        seq_idx = [dic.get(char, pad_idx) for char in seq[:max_len]]\n",
    "        seq_len.append(len(seq_idx))\n",
    "        seq_idx += (max_len - len(seq_idx)) * [pad_idx]\n",
    "        seq_indices.append(seq_idx)\n",
    "    return seq_len, seq_indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Apply pad_seq function to data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# turn every word into max_length character indices (right-padded)\n",
    "max_length = 10\n",
    "X_length, X_indices = pad_seq(words, max_length, char_dic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4, 3, 7, 7, 9, 7]\n",
      "(6, 10)\n"
     ]
    }
   ],
   "source": [
    "# un-padded length of each word, then the shape of the padded index matrix\n",
    "print(X_length, np.shape(X_indices), sep = '\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define CharStackedBiRNN class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharStackedBiRNN:\n",
    "    \"\"\"Character-level many-to-one classifier: stacked bi-directional RNN with dropout.\n",
    "\n",
    "    The constructor builds the whole graph: one-hot embedding lookup -> stacked\n",
    "    bi-directional BasicRNNCells (dropout on the cell outputs) -> linear output\n",
    "    layer -> softmax cross-entropy loss and argmax prediction.\n",
    "\n",
    "    Attributes:\n",
    "        ce_loss: softmax cross-entropy loss tensor to be minimized.\n",
    "    \"\"\"\n",
    "    def __init__(self, X_length, X_indices, y, n_of_classes, hidden_dims, dic):\n",
    "        \"\"\"Build the model graph.\n",
    "\n",
    "        Args:\n",
    "            X_length: tensor of sequence lengths (passed as sequence_length\n",
    "                to the dynamic RNN).\n",
    "            X_indices: tensor of padded token indices (ids of the embedding lookup).\n",
    "            y: one-hot encoded labels.\n",
    "            n_of_classes: number of output classes.\n",
    "            hidden_dims: iterable with the number of units of each stacked layer.\n",
    "            dic: token-to-index dictionary; only len(dic) is used, as the size\n",
    "                of the one-hot embedding.\n",
    "        \"\"\"\n",
    "        # data pipeline\n",
    "        with tf.variable_scope('input_layer'):\n",
    "            self._X_length = X_length\n",
    "            self._X_indices = X_indices\n",
    "            self._y = y\n",
    "\n",
    "            one_hot = tf.eye(len(dic), dtype = tf.float32)\n",
    "            # trainable = False because the embedding vectors are not going to be trained\n",
    "            self._one_hot = tf.get_variable(name = 'one_hot_embedding', initializer = one_hot,\n",
    "                                            trainable = False)\n",
    "            self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n",
    "            # dropout keep probability; has no default, so it must be fed at run time\n",
    "            self._keep_prob = tf.placeholder(dtype = tf.float32)\n",
    "\n",
    "        # Stacked Bi-directional RNN with Drop out\n",
    "        with tf.variable_scope('stacked_bi-directional_rnn'):\n",
    "            rnn_fw_cells = self._build_cells(hidden_dims) # forward\n",
    "            rnn_bw_cells = self._build_cells(hidden_dims) # backward\n",
    "\n",
    "            _, output_state_fw, output_state_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(\n",
    "                cells_fw = rnn_fw_cells, cells_bw = rnn_bw_cells,\n",
    "                inputs = self._X_batch,\n",
    "                sequence_length = self._X_length,\n",
    "                dtype = tf.float32)\n",
    "\n",
    "            # final states of the last stacked layer, forward and backward, side by side\n",
    "            final_state = tf.concat([output_state_fw[-1], output_state_bw[-1]], axis = 1)\n",
    "\n",
    "        with tf.variable_scope('output_layer'):\n",
    "            # activation_fn = None: the layer outputs raw logits\n",
    "            self._score = slim.fully_connected(inputs = final_state, num_outputs = n_of_classes,\n",
    "                                               activation_fn = None)\n",
    "\n",
    "        with tf.variable_scope('loss'):\n",
    "            self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n",
    "\n",
    "        with tf.variable_scope('prediction'):\n",
    "            self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n",
    "\n",
    "    def _build_cells(self, hidden_dims):\n",
    "        \"\"\"Return one dropout-wrapped tanh BasicRNNCell per entry of hidden_dims.\"\"\"\n",
    "        cells = []\n",
    "        for hidden_dim in hidden_dims:\n",
    "            cell = tf.contrib.rnn.BasicRNNCell(num_units = hidden_dim, activation = tf.nn.tanh)\n",
    "            cell = tf.contrib.rnn.DropoutWrapper(cell = cell, output_keep_prob = self._keep_prob)\n",
    "            cells.append(cell)\n",
    "        return cells\n",
    "\n",
    "    def predict(self, sess, X_length, X_indices, keep_prob = 1.):\n",
    "        \"\"\"Return the predicted class index of each given sequence.\n",
    "\n",
    "        Args:\n",
    "            sess: session in which the prediction tensor is evaluated.\n",
    "            X_length: sequence lengths, fed in place of the constructor's X_length.\n",
    "            X_indices: padded token indices, fed in place of the constructor's X_indices.\n",
    "            keep_prob: dropout keep probability; the default 1. disables dropout.\n",
    "        \"\"\"\n",
    "        feed_prediction = {self._X_length : X_length, self._X_indices : X_indices, self._keep_prob : keep_prob}\n",
    "        return sess.run(self._prediction, feed_dict = feed_prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a model of CharStackedBiRNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    }
   ],
   "source": [
    "# hyper-parameters\n",
    "lr = .003\n",
    "epochs = 10\n",
    "batch_size = 2\n",
    "# mini-batches per epoch: rounded up, because Dataset.batch also emits the last, smaller batch\n",
    "total_step = int(np.ceil(np.shape(X_indices)[0] / batch_size))\n",
    "print(total_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n"
     ]
    }
   ],
   "source": [
    "## tf.data input pipeline: slice per example -> shuffle -> group into mini-batches\n",
    "tr_dataset = (tf.data.Dataset\n",
    "              .from_tensor_slices((X_length, X_indices, y))\n",
    "              .shuffle(buffer_size = 20)\n",
    "              .batch(batch_size = batch_size))\n",
    "tr_iterator = tr_dataset.make_initializable_iterator()\n",
    "print(tr_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# symbolic mini-batch tensors (lengths, indices, labels) produced by the iterator\n",
    "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the model reads its inputs straight from the iterator tensors\n",
    "char_stacked_bi_rnn = CharStackedBiRNN(\n",
    "    X_length = X_length_mb,\n",
    "    X_indices = X_indices_mb,\n",
    "    y = y_mb,\n",
    "    n_of_classes = 2,\n",
    "    hidden_dims = [16, 16],\n",
    "    dic = char_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create training op and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "## create training op\n",
    "opt = tf.train.AdamOptimizer(learning_rate = lr)\n",
    "training_op = opt.minimize(loss = char_stacked_bi_rnn.ce_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   1, tr_loss : 0.869\n",
      "epoch :   2, tr_loss : 0.486\n",
      "epoch :   3, tr_loss : 0.390\n",
      "epoch :   4, tr_loss : 0.260\n",
      "epoch :   5, tr_loss : 0.311\n",
      "epoch :   6, tr_loss : 0.187\n",
      "epoch :   7, tr_loss : 0.112\n",
      "epoch :   8, tr_loss : 0.060\n",
      "epoch :   9, tr_loss : 0.043\n",
      "epoch :  10, tr_loss : 0.036\n"
     ]
    }
   ],
   "source": [
    "# the session is not closed here: the prediction cell further down reuses `sess`\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "# average training loss of every epoch, plotted afterwards\n",
    "tr_loss_hist = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    avg_tr_loss = 0\n",
    "    tr_step = 0\n",
    "    \n",
    "    # (re-)initialize the iterator at the start of every epoch\n",
    "    sess.run(tr_iterator.initializer)\n",
    "    try:\n",
    "        while True:\n",
    "            # one optimization step on the next mini-batch; keep_prob = .5 enables dropout\n",
    "            _, tr_loss = sess.run(fetches = [training_op, char_stacked_bi_rnn.ce_loss],\n",
    "                                  feed_dict = {char_stacked_bi_rnn._keep_prob : .5})\n",
    "            avg_tr_loss += tr_loss\n",
    "            tr_step += 1\n",
    "            \n",
    "    # OutOfRangeError marks the end of the epoch (iterator exhausted), so it is expected\n",
    "    except tf.errors.OutOfRangeError:\n",
    "        pass\n",
    "    \n",
    "    # mean of the mini-batch losses seen in this epoch\n",
    "    avg_tr_loss /= tr_step\n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "    \n",
    "    print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x1171a2be0>]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl0VeW9//H395zkZCKQEIIhOSAzEhlDQAFnraJVUIOttlq1zq1WW+9t6/31evuz63evnfV2Wal1aB0qKiBgCw51qAOIBAggyCQI5AAhYQiQeXh+fyRqQJADJtln+LzWcpGzz5Ocj0f5nJ29n/1sc84hIiKxxed1ABERaX8qdxGRGKRyFxGJQSp3EZEYpHIXEYlBKncRkRikchcRiUEqdxGRGKRyFxGJQQlevXCPHj1c3759vXp5EZGotGTJkgrnXPbRxnlW7n379qW4uNirlxcRiUpmtjmccTosIyISg1TuIiIxSOUuIhKDVO4iIjFI5S4iEoNU7iIiMUjlLiISg6Ku3Jdt2cP989d4HUNEJKJFXbl/GKpk2r8+ZtW2Sq+jiIhErKgr90tG5hLw+5i5JOR1FBGRiBV15Z6RGuDcoT2ZUxKioanZ6zgiIhEp6sodoKggyK6qet5aW+51FBGRiBSV5X7mkGyy0gLMXFLqdRQRkYgUleWe6Pdx6eg8Xl9Txp6qeq/jiIhEnKgsd2g5NNPQ5HhpxTavo4iIRJyoLff83K4M7dWVGTo0IyLyBVFb7gBFBXmsKK1kfdl+r6OIiESUqC73KaPy8PuMGUu19y4i0lZUl3t2ehJnD8lm9rIQTc3O6zgiIhEjqssdWk6slu2r490NFV5HERGJGFFf7ucM7Um3lETNeRcRaSOscjezSWa21sw2mNlPD/N8HzN708yWmdkKM7uo/aMeXlKCn8kjc3ll1Q721TZ01suKiES0o5a7mfmBh4ALgXzgKjPLP2TYz4DnnXOjgSuBP7Z30C9TNCZIXWMz/1ixvTNfVkQkYoWz5z4O2OCc2+icqwemA1MOGeOArq1fdwM69cqikcFuDMhO06EZEZFW4ZR7HrC1zePS1m1t/Ry42sxKgXnAHe2SLkxmxtQxvSnevIdPKqo686VFRCJSe51QvQr4i3MuCFwEPGVmX/jZZnazmRWbWXF5efuu6HjZ6Dx8BrM0511EJKxyDwG92zwOtm5r6wbgeQDn3EIgGehx6A9yzj3inCt0zhVmZ2cfX+IjyOmWzMSBPZi5NESz5ryLSJwLp9wXA4PMrJ+ZBWg5YTr3kDFbgHMBzGwoLeXe6YutTx0TJLS3hvc37erslxYRiShHLXfnXCNwO/AK8BEts2JWmdl9Zja5ddjdwE1mthx4FrjOOdfpu8/n5+fQJSlBt+ATkbiXEM4g59w8Wk6Utt12b5uvVwMT2zfasUsJ+Ll4RC/mLt/GfVNOJi0prH89EZGYE/VXqB6qaEyQ6vomXv5wh9dRREQ8E3PlXnhiJidmpTJTs2ZEJI7FXLmbGZePDrLg412U7qn2Oo6IiCdirtwBLi9oucbqxaU6sSoi8Skmy71391RO7d+dWctCeDBpR0TEczFZ7tCyzvumiiqWbtnjdRQRkU4Xs+V+4fBepCT6maE57yISh2K23LskJXDhsBz+vmIbtQ1NXscREelUMVvu0DLnfX9tI6+uLvM6iohIp4rpch/fP4vcbsla511E4k5Ml7vPZ1xeEOSd9eWU7av1Oo6ISKeJ6XKHljnvzQ5mL9OJVRGJHzFf7v2zu1DQJ4OZS0s1511E4kbMlzu0nFhdV3aAlaFKr6OIiHSKuCj3i0fkEkjw6cSqiMSNuCj3bimJnJ9/AnOXb6O+sdnrOCIiHS4uyh1aDs3sqW7gjTU7vY4iItLh4qbcTx/Yg+z0JK3zLiJxIW7KPcHv47LReby5Zie7DtR5HUdEpEPFTblDy0qRjc2OOSXbvI4i
ItKh4qrch+SkMzyvmw7NiEjMi6tyBygqyGPVtn2s2bHP6ygiIh0m7sp98qg8Ev2mOe8iEtPirty7pwU4e0hPXly2jcYmzXkXkdgUd+UOLXPeKw7U8fb6cq+jiIh0iLgs97OH9KR7WoCZugWfiMSouCz3QIKPySNzeW11GZXVDV7HERFpd3FZ7gBTxwSpb2rmpRWa8y4isSduy/3k3K4MOSFdc95FJCbFbbmbGUVj8li2ZS8flx/wOo6ISLuK23IHuHRUHj5Dc95FJObEdbn37JrMmYOzeXFZiKZm3YJPRGJHXJc7tMx5315Zy8KPd3kdRUSk3cR9uZ839AS6JifoxKqIxJS4L/fkRD8Xj8xl/ofb2V+rOe8iEhvivtyhZZ332oZm5q/c4XUUEZF2oXIHCvpk0L9HGjN0aEZEYoTKnU/nvAf5YNNutu6u9jqOiMhXFla5m9kkM1trZhvM7KdHGPMNM1ttZqvM7G/tG7PjXTY6DzN0YlVEYsJRy93M/MBDwIVAPnCVmeUfMmYQcA8w0Tl3MnBXB2TtULkZKUwYkMWspSGc05x3EYlu4ey5jwM2OOc2OufqgenAlEPG3AQ85JzbA+Cc29m+MTtHUUGQLburWfzJHq+jiIh8JeGUex6wtc3j0tZtbQ0GBpvZe2b2vplNaq+AnWnSsBzSAn5mLNl69MEiIhGsvU6oJgCDgLOAq4A/m1nGoYPM7GYzKzaz4vLyyLsLUmoggYuG92Leyh3U1Dd5HUdE5LiFU+4hoHebx8HWbW2VAnOdcw3OuU3AOlrK/iDOuUecc4XOucLs7OzjzdyhisYEOVDXyCurNOddRKJXOOW+GBhkZv3MLABcCcw9ZMxsWvbaMbMetBym2diOOTvNuL7dCWamaNaMiES1o5a7c64RuB14BfgIeN45t8rM7jOzya3DXgF2mdlq4E3g351zUbkSl89nXF4Q5N0NFWyvrPE6jojIcQnrmLtzbp5zbrBzboBz7v+1brvXOTe39WvnnPuRcy7fOTfcOTe9I0N3tKKCPJyDWUt1A20RiU66QvUwTsxKY1zf7sxcWqo57yISlVTuR1A0Jo+N5VWUbN3rdRQRkWOmcj+Ci4b3IjnRpxOrIhKVVO5HkJ6cyAUn5/DS8u3UNWrOu4hEF5X7lygqCFJZ08DrH0XlagoiEsdU7l9i4sAe5HRNZsYSHZoRkeiicv8Sfp9xWUEe/1pXTvn+Oq/jiIiETeV+FEUFQZqaHXNKNOddRKKHyv0oBvbswsjeGTo0IyJRReUehqkFeazZsZ9V2yq9jiIiEhaVexguGZlLwO9j5hIdmhGR6KByD0NGaoDz8nsypyREQ1Oz13FERI5K5R6mooIgu6rqeWtt5N1kRETkUCr3MJ0xOJseXQLM1IlVEYkCKvcwJfp9TBmVx+trythTVe91HBGRL6VyPwZFBUEamhwvrdjmdRQRkS+lcj8G+bldye/VVXPeRSTiqdyPUdGYICtKK1lftt/rKCIiR6RyP0ZTRuWS4DNmaJ13EYlgKvdj1KNLEmcNyWb2shBNzboFn4hEJpX7cSgqCFK2r453N1R4HUVE5LBU7sfhnKE9yUhN5On3N2vvXUQiksr9OCQl+PnOqSfy2uoyrnxkIZt3VXkdSUTkICr34/TDrw3mN1eMZM2O/Ux64B2eXPgJzdqLF5EIoXI/TmbG1DFBXv3hGYzt151756zi6scWsXV3tdfRRERU7l9Vr24p/PX6sfzP5cNZvnUvkx54m2c/2IJz2osXEe+o3NuBmXHVuD68fNcZjAhmcM+slVz7xGK2V9Z4HU1E4pTKvR317p7KMzeewn1TTmbxpt2c//u3mbGkVHvxItLpVO7tzOczvjO+L/PvPJ2TctL5txeWc9OTxezcV+t1NBGJIyr3DtK3RxrTbx7Pz74+lHfWV3D+A28zpySkvXgR6RQq9w7k9xk3nt6feXeeTt+sNO6cXsL3nllKxYE6r6OJSIxTuXeCAdldmHHreH4y
6SRe/2gnF/z+beav3O51LBGJYSr3TpLg93HbWQN46Y7T6JWRzG3PLOUHzy7TXZ1EpEOo3DvZkJx0XvzeRH70tcHMW7md8x94m3+uLvM6lojEGJW7BxL9Pn5w7iDm3D6RrLQANz5ZzN3PL6eypsHraCISI1TuHjo5txtzbz+NO84ZyOySEBf8/m3eWrvT61giEgNU7h4LJPi4+/whzLptAunJCVz3xGLumbWC/bXaixeR46dyjxAje2fw0h2nccuZ/Xlu8VYmPfAOC3QzEBE5Tir3CJKc6OeeC4fywq0TCCT4+Naji7h3zodU1TV6HU1EokxY5W5mk8xsrZltMLOffsm4IjNzZlbYfhHjz5gTM5n3g9P57sR+PPX+Zi588B0+2LTb61giEkWOWu5m5gceAi4E8oGrzCz/MOPSgTuBRe0dMh6lBPzce0k+0286FYBvPrKQX/x9NbUNTR4nE5FoEM6e+zhgg3Nuo3OuHpgOTDnMuF8AvwS0QlY7OqV/FvPvPJ2rTzmRx97dxEUPvsPSLXu8jiUiES6ccs8DtrZ5XNq67TNmVgD0ds79ox2zSau0pAR+cekwnr7hFOoam5n68ALun79Ge/EickRf+YSqmfmA3wF3hzH2ZjMrNrPi8vLyr/rScee0QT14+a7T+UZhb6b962Mu+cO7rCyt9DqWiESgcMo9BPRu8zjYuu1T6cAw4C0z+wQ4FZh7uJOqzrlHnHOFzrnC7Ozs408dx9KTE7m/aARPXD+WfbUNXPrH93j4rY+1lLCIHCSccl8MDDKzfmYWAK4E5n76pHOu0jnXwznX1znXF3gfmOycK+6QxALA2UN68updZzJpWA6/fHkNP5+7iuZmFbyItEg42gDnXKOZ3Q68AviBx51zq8zsPqDYOTf3y3+CdJRuqYn84crR9OqazKPvbmJ3dQO/vWIkgQRdviAS745a7gDOuXnAvEO23XuEsWd99VgSLp/P+NnF+WSnJ/E/89ewp6qeadeMoUtSWP9pRSRGaRcvRtxy5gB+PXUECzfu4lt/fp9dutuTSFxTuceQKwp786erx7B2x36mTlvI1t3VXkcSEY+o3GPMefkn8MyNp7DrQB1FDy9gzY59XkcSEQ+o3GNQYd/uvHDrBMzgG9MWsvgTrUsjEm9U7jFqSE46M2+bQI8uSVz96CJe0638ROKKyj2GBTNTeeHW8ZyUk86tTy/h+cVbj/5NIhITVO4xLqtLEn+76VQmDMjixzNX8Me3NuhqVpE4oHKPA2lJCTx27Vgmj8zlVy+v5Rd//0hXs4rEOF3pEicCCT4e+OYouqcFePy9TeyuquNXU3U1q0isUrnHEZ/P+K9LWq5m/fUra9ld3cC0qwtIDeh/A5FYo922OGNmfP/sgdx/+XDeXV/Ot/68iD1V9V7HEpF2pnKPU1eO68PDV49h9fZ9TJ22gNDeGq8jiUg7UrnHsQtOzuHJ745j5746pj68gHVl+72OJCLtROUe507tn8Vzt4ynsdlxxbSFLNmsq1lFYoHKXcjP7cqs2yaQmZrItx9dxBtrdDWrSLRTuQsAvbunMuO2CQzs2YWbnlzCzCWlXkcKW1Oz470NFfz7C8s557dv8XyxrsQV0Rw4+UyPLkk8e9Op3PLUEu5+YTm7quq4+YwBXsc6LOccH4b2MackxEsrtlG2r44uSQnkZiTz4xkrKN1dzQ+/Nhgz8zqqiCdU7nKQ9OREnrh+LD96bjn/PW8Nuw7U89MLT4qYkty8q4o5JduYXRJiY3kViX7jrCE9mTIql/OGnoDfZ9wzayX/+8YGSvfUcH/RCF2oJXFJ5S5fkJTg53+vGk33tAB/ensjFQfqub9oOIl+b0qy4kAd/1ixndklIZZt2QvAuH7dufG0/lw0PIeM1MBB4389dQR9uqfyu9fWsb2ylmnXjKFbSqIX0UU8o3KXw/L7jPumnExWlwAP/HM9e6rreehbBaQE/J3y+lV1jby6egdzSrbxzvoKmpodJ+Wk85NJ
JzF5VC55GSlH/F4z4wfnDiKYmcJPZq5g6sMLeOL6sQQzUzslu0gkMK9WCCwsLHTFxcWevLYcm6ff38x/zvmQ0b0zePy6sV/YU24vDU3NvLO+nNnLtvHa6jJqGprIy0hh8qhcpozK5aScrsf8MxdsqOCWp5eQnOjn8WvHMjzYrQOSi3QeM1vinCs86jiVu4Rj3srt3DW9hBOzUnnyhnH06nbkPedj4Zxj6ZY9zF62jX+s3M7uqnoyUhO5aHgvLh2VR+GJmfh8X+14/7qy/Vz/xGJ2V9Xz0LdHc85JJ7RLdhEvqNyl3S34uIKbn1xC1+QEnrzhFAb27HLcP2t92X5ml4SYU7KN0j01JCX4OC//BC4dlceZg7Pb/STozn21fPevi1m9bR//d8owrjn1xHb9+SKdReUuHeLDUCXXPfEBTc2Ox68by+g+mWF/747KWuYuDzF72TZWb9+Hz2DiwB5cOiqPC4bl0CWpY08BVdU1csezy3hjzU5uOaM/P5l00lf+rUCks6ncpcN8UlHFdx7/gPL9dUy7ZgxnDs4+4tjKmgbmr2yZ6bJo026cg5HBbkwZlcfFI3vRMz25E5NDY1MzP39pFU+/v4Wvj+jFb68YSXJi55wkFmkPKnfpUDv313Lt44tZX7af31wxkktH5332XG1DE2+u2cnskhBvrimnvqmZfj3SmDIqlymj8ujXI83D5C3H+f/09kbun7+GwhMz+fN3CslM65iTxCLtTeUuHW5fbQM3/bWYRZt287OvD2Vor67MKQkx/8Md7K9tpEeXJC4Z2XJidESwW8RcCPWpv6/Yxo+eX05eRgp/uX4sJ2Z5+6EjEg6Vu3SK2oYm7py+jFdWtSw2lhbwc8GwHC4dlceEAVkkeHThU7gWf7Kbm54sxmfGo9cWUnAM5xBEvKByl07T1Ox4vngr6ckJnDf0hKg7hr2x/ADXPbGYsn21PHjlKCYN6+V1JJEjCrfcI3u3SqKC32dcNa4PF4/IjbpiB+if3YUXvzeB/Nyu3PbMUh57d5PXkUS+MpW7CJDVuiLmBfk5/OLvq/n53FU0NXvzW61Ie1C5i7RKTvTz0LcLuOG0fvxlwSfc+vQSauqbvI4lclxU7iJt+H3Gf16cz39dks8/PyrjykcWUr6/zutYIsdM5S5yGNdP7Me0q8ewtmw/lz/8Hht2HvA6ksgxUbmLHMEFJ+cw/ebx1NQ3UfTwAhZt3OV1JJGwqdxFvsSo3hnMum0iWV0CXPPYB8wpCXkdSSQsKneRo+iTlcqs2yYwqk8Gd04v4Y9vbcCr60NEwqVyFwlDRmqAp24Yx+SRufzq5bX8x4sf0tjU7HUskSMKq9zNbJKZrTWzDWb208M8/yMzW21mK8zsdTPTYtkSc5IS/DzwzVF876wBPPvBFm74azEH6hq9jiVyWEctdzPzAw8BFwL5wFVmln/IsGVAoXNuBDAD+FV7BxWJBD6f8eNJJ/Hflw3n3Q0VfGPaQsr21XodS+QLwtlzHwdscM5tdM7VA9OBKW0HOOfedM5Vtz58Hwi2b0yRyPKtU/rw6LWFbN5VxaUPvceaHfu8jiRykHDKPQ/Y2uZxaeu2I7kBmP9VQolEg7OH9OS5W8bT1Oy44uGFvLu+wutIIp9p1xOqZnY1UAj8+gjP32xmxWZWXF5e3p4vLeKJYXndePH7E8nNSOG6Jz7gheKtR/8mkU4QTrmHgN5tHgdbtx3EzM4D/g8w2Tl32Ou1nXOPOOcKnXOF2dlHvjWbSDTJy0jhhdvGc2r/LP59xgp+/9o6TZUUz4VT7ouBQWbWz8wCwJXA3LYDzGw08Cdain1n+8cUiWxdkxN5/LqxTB0T5MHX13P388uprG7wOpbEsaPebt4512hmtwOvAH7gcefcKjO7Dyh2zs2l5TBMF+CF1lupbXHOTe7A3CIRJ5Dg49dTR9A7M5UHXl/Hax+VccsZ/bl+Yj/Sko76V02kXelOTCId4KPt+/jtq+v450dlZKUF+N7ZA/n2
KX2i8mYmEll0mz2RCLB0yx5+++pa3tuwi17dkvnBuYOYOiZIYoTfW1Yil26zJxIBCvpk8syNp/K3m06hV7dk7pm1kvN+9y/mlIRo1p2epAOp3EU6wYQBPZh52wQeu7aQ1EACd04v4cIH3+GVVTs0s0Y6hMpdpJOYGecOPYF/3HEaf7hqNA1Nzdzy1BIufeg93llfrpKXdqVyF+lkPp9xychcXv3hGfyqaAQVB+q55rEPuPKR9yn+ZLfX8SRG6ISqiMfqGpuY/sFW/vDGBioO1HH2kGzuPn8Iw/K6eR1NIpBmy4hEmer6Rv66YDPT/vUxlTUNfH14L374tcEM7NnF62gSQVTuIlGqsqaBx97ZyGPvbqKmoYnLC4Lcee4gendP9TqaRACVu0iU23Wgjoff+pgn39+Mc44rx/bhjnMG0rNrstfRxEMqd5EYsaOylj+8sZ7nFm8lwW9cO74vt545gMy0gNfRxAMqd5EYs3lXFQ/+cz0vloRICyRw4+n9uOG0fqQnJ3odTTqRyl0kRq0r28/vXl3Hy6t2kJmayG1nDeA74/tq3Zo4oXIXiXErSvfym1fX8fa6cnqmJ3HHOQP55tg+BBJ0+UosU7mLxIlFG3fxm1fXsviTPQQzU7jrvMFcNjoPv8+8jiYdQAuHicSJU/pn8fwt4/nL9WPJSE3k315YzgUPvM28ldu1OFkc0x0ERGKAmXHWkJ6cOTiblz/cwW9fW8f3nlnKSTnpXFHYm8kjc8lOT/I6pnQiHZYRiUFNzY7Zy0L8ZcEnrAxV4vcZZw3O5vKCIOcO7amTr1Es3MMy2nMXiUF+n1E0JkjRmCDryvYza2mIF5eV8vqanXRNTuDikbkUFeRR0CeT1ltjSozRnrtInGhqdiz4uIJZS0O8/OEOahqa6JuVyuUFQS4bnaflDaKEZsuIyBEdqGtk/srtzFoaYuHGXQCM69edqQVBLhyeowujIpjKXUTCUrqnmtnLQsxaGmJjRRVJCT4uODmHojFBThvYQ1MqI4zKXUSOiXOOkq17mbm0lJeWb6eypoGe6UlcOjqPooIgQ3LSvY4oqNxF5Cuoa2zizTU7mbEkxFtrd9LY7Dg5tyuXFwSZMiqXHl00rdIrKncRaRe7DtTx0vJtzFwa0rTKCKByF5F213ZaZdm+Ok2r9IDKXUQ6jKZVekflLiKdQtMqO5fKXUQ63afTKmcuDbGpoorkxJZplWcNySaYmUowM4We6cmaXvkVqNxFxDPOOZZt3cusNtMqP5XgM3plJJOXkUJeRip5mSkEM1LIy0whLyOFXhnJJCXoJO2RqNxFJCLUNzazeVcVpXtrCO2pIXTIn2X7a2lbQ2aQ3SXps7LPy0xp2etv8wGQlhS/y2Jp4TARiQiBBB+DTkhn0AmHvwiqvrGZHZW1lO6t/kL5rwxV8sqqHTQ0HbwTmpGa2Lrn/3nhBzM//00gMzUx7mfuqNxFxFOBBB99slLpk3X4GTbNzY6d++sI7a2m9JDy31RRxbsbKqiubzroe1IDfnI/K/zPPwAyUwOkBvykBPykBhI+/zrRT4I/tu5dpHIXkYjm8xk53ZLJ6ZbMmBO/+Lxzjr3VDYT21hxS/tWE9tawfOte9lQ3fPEbDxHw+0gJ+ElrU/4tf7b8k5KY8PnXn21vu631wyLx8+c+HZfowQeHyl1EopqZkZkWIDMtwLC8bocdU1XXyLa9NVTWNFBd30R1fRM1DY0tf9Y3UVXXRHVDIzWfPlffRHV9y/O7q+op3dM6rnVbfWPzMWVM9BspiX7SkloK/67zBjN5ZG57/OsfkcpdRGJeWlLCEY/5H4/GpmZqGpo++zBo+2FR3eaD4aDnW7dVNzSRmdrxc/9V7iIixyjB7yPd74voC7Ri6wyCiIgAKncRkZikchcRiUFhlbuZTTKztWa2wcx+epjnk8zsudbnF5lZ3/YOKiIi4TtquZuZH3gIuBDIB64ys/xDht0A7HHODQR+D/yyvYOKiEj4wtlzHwds
cM5tdM7VA9OBKYeMmQL8tfXrGcC5Fu/X/oqIeCiccs8DtrZ5XNq67bBjnHONQCWQ1R4BRUTk2HXqCVUzu9nMis2suLy8vDNfWkQkroRzEVMI6N3mcbB12+HGlJpZAtAN2HXoD3LOPQI8AmBm5Wa2+XhCAz2AiuP83lik9+Ngej8+p/fiYLHwfhxmhZ0vCqfcFwODzKwfLSV+JfCtQ8bMBa4FFgJTgTfcURaKd85lhxPwcMysOJz1jOOF3o+D6f34nN6Lg8XT+3HUcnfONZrZ7cArgB943Dm3yszuA4qdc3OBx4CnzGwDsJuWDwAREfFIWGvLOOfmAfMO2XZvm69rgSvaN5qIiByvaL1C9RGvA0QYvR8H0/vxOb0XB4ub98Oze6iKiEjHidY9dxER+RJRV+5HW+cmXphZbzN708xWm9kqM7vT60yRwMz8ZrbMzP7udRavmVmGmc0wszVm9pGZjfc6k1fM7Ietf08+NLNnzSzZ60wdLarKPcx1buJFI3C3cy4fOBX4fhy/F23dCXzkdYgI8SDwsnPuJGAkcfq+mFke8AOg0Dk3jJZZfzE/oy+qyp3w1rmJC8657c65pa1f76flL+6hy0LEFTMLAl8HHvU6i9fMrBtwBi3TlHHO1Tvn9nqbylMJQErrRZapwDaP83S4aCv3cNa5iTutSyyPBhZ5m8RzDwA/Bo7t7sWxqR9QDjzRepjqUTNL8zqUF5xzIeA3wBZgO1DpnHvV21QdL9rKXQ5hZl2AmcBdzrl9XufxipldDOx0zi3xOkuESAAKgIedc6OBKiAuz1GZWSYtv+H3A3KBNDO72ttUHS/ayj2cdW7ihpkl0lLszzjnZnmdx2MTgclm9gkth+vOMbOnvY3kqVKg1Dn36W9zM2gp+3h0HrDJOVfunGsAZgETPM7U4aKt3D9b58bMArScFJnrcSZPtK6X/xjwkXPud17n8Zpz7h7nXNA515eW/y/ecM7F/N7ZkTjndgBbzWxI66ZzgdUeRvLSFuBUM0tt/XtzLnFwcjms5QcixZHWufE4llcmAtcAK82spHXbf7QuFSECcAfwTOuO0Ebgeo/zeMI5t8jMZgBLaZlltow4uFJVV6iKiMSgaDssIyIiYVC5i4jEIJW7iEgMUrliuv6yAAAAIElEQVSLiMQglbuISAxSuYuIxCCVu4hIDFK5i4jEoP8PbI2wXWDjLIgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# training loss per epoch; legend() is needed for the label to show,\n",
    "# and the trailing ';' keeps the return value out of the cell output\n",
    "plt.plot(tr_loss_hist, label = 'train')\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('average cross-entropy loss')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# predict on the full training set; keep_prob is left at predict's default of 1. (no dropout)\n",
    "yhat = char_stacked_bi_rnn.predict(sess = sess, X_length = X_length, X_indices = X_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training acc: 100.00%\n"
     ]
    }
   ],
   "source": [
    "# fraction of words whose predicted class equals the argmax of its one-hot label\n",
    "tr_acc = np.mean(yhat == np.argmax(y, axis = -1))\n",
    "print('training acc: {:.2%}'.format(tr_acc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}