{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 11 : Recurrent Neural Networks\n",
    "Simple example for Many to One Classification (word sentiment classification) by Bi-directional Recurrent Neural Networks.\n",
    "\n",
    "### Many to One Classification by Bi-directional RNN\n",
    "- Creating the **data pipeline** with `tf.data`\n",
    "- Preprocessing word sequences (variable input sequence length) using `padding technique` by `user function (pad_seq)`\n",
    "- Using `tf.nn.embedding_lookup` for getting vector of tokens (eg. word, character)\n",
    "- Creating the model as **Class**\n",
    "- Reference\n",
    "    - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n",
    "    - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n",
    "    - https://pozalabs.github.io/blstm/\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.8.0\n"
     ]
    }
   ],
   "source": [
    "import os, sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import string\n",
    "%matplotlib inline\n",
    "\n",
    "slim = tf.contrib.slim\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare example data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n",
    "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'abcdefghijklmnopqrstuvwxyz *'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Character quantization\n",
    "char_space = string.ascii_lowercase \n",
    "char_space = char_space + ' ' + '*'\n",
    "char_space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n"
     ]
    }
   ],
   "source": [
    "char_dic = {char : idx for idx, char in enumerate(char_space)}\n",
    "print(char_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create pad_seq function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_seq(sequences, max_len, dic):\n",
    "    seq_len, seq_indices = [], []\n",
    "    for seq in sequences:\n",
    "        seq_len.append(len(seq))\n",
    "        seq_idx = [dic.get(char) for char in seq]\n",
    "        seq_idx += (max_len - len(seq_idx)) * [dic.get('*')] # 27 is idx of meaningless token \"*\"\n",
    "        seq_indices.append(seq_idx)\n",
    "    return seq_len, seq_indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Apply pad_seq function to data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_length = 10\n",
    "X_length, X_indices = pad_seq(sequences = words, max_len = max_length, dic = char_dic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4, 3, 7, 7, 9, 7]\n",
      "(6, 10)\n"
     ]
    }
   ],
   "source": [
    "print(X_length)\n",
    "print(np.shape(X_indices))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define CharBiRNN class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharBiRNN:\n",
    "    def __init__(self, X_length, X_indices, y, n_of_classes, hidden_dim, dic):\n",
    "        \n",
    "        # data pipeline\n",
    "        with tf.variable_scope('input_layer'):\n",
    "            self._X_length = X_length\n",
    "            self._X_indices = X_indices\n",
    "            self._y = y\n",
    "            \n",
    "            one_hot = tf.eye(len(dic), dtype = tf.float32)\n",
    "            self._one_hot = tf.get_variable(name='one_hot_embedding', initializer = one_hot,\n",
    "                                            trainable = False) # embedding vector training 안할 것이기 때문\n",
    "            self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n",
    "        \n",
    "        # Bi-directional RNN\n",
    "        with tf.variable_scope('bi-directional_rnn'):\n",
    "            rnn_fw_cell = tf.contrib.rnn.BasicRNNCell(num_units = hidden_dim, activation = tf.nn.tanh)\n",
    "            rnn_bw_cell = tf.contrib.rnn.BasicRNNCell(num_units = hidden_dim, activation = tf.nn.tanh)\n",
    "            _, output_states = tf.nn.bidirectional_dynamic_rnn(cell_fw = rnn_fw_cell,\n",
    "                                                                    cell_bw = rnn_bw_cell,\n",
    "                                                                    inputs = self._X_batch,\n",
    "                                                                    sequence_length = self._X_length,\n",
    "                                                                    dtype = tf.float32)\n",
    "\n",
    "            final_state = tf.concat([output_states[0], output_states[1]], axis = 1)\n",
    "\n",
    "        with tf.variable_scope('output_layer'):\n",
    "            self._score = slim.fully_connected(inputs = final_state, num_outputs = n_of_classes,\n",
    "                                               activation_fn = None)\n",
    "            \n",
    "        with tf.variable_scope('loss'):\n",
    "            self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n",
    "            \n",
    "        with tf.variable_scope('prediction'):\n",
    "            self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n",
    "    \n",
    "    def predict(self, sess, X_length, X_indices):\n",
    "        feed_prediction = {self._X_length : X_length, self._X_indices : X_indices}\n",
    "        return sess.run(self._prediction, feed_dict = feed_prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a model of CharBiRNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    }
   ],
   "source": [
    "# hyper-parameter#\n",
    "lr = .003\n",
    "epochs = 10\n",
    "batch_size = 2\n",
    "total_step = int(np.shape(X_indices)[0] / batch_size)\n",
    "print(total_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n"
     ]
    }
   ],
   "source": [
    "## create data pipeline with tf.data\n",
    "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n",
    "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n",
    "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n",
    "tr_iterator = tr_dataset.make_initializable_iterator()\n",
    "print(tr_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "char_bi_rnn = CharBiRNN(X_length = X_length_mb, X_indices = X_indices_mb, y = y_mb,\n",
    "                        n_of_classes = 2, hidden_dim = 16, dic = char_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creat training op and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "## create training op\n",
    "opt = tf.train.AdamOptimizer(learning_rate = lr)\n",
    "training_op = opt.minimize(loss = char_bi_rnn.ce_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   1, tr_loss : 0.712\n",
      "epoch :   2, tr_loss : 0.554\n",
      "epoch :   3, tr_loss : 0.454\n",
      "epoch :   4, tr_loss : 0.350\n",
      "epoch :   5, tr_loss : 0.278\n",
      "epoch :   6, tr_loss : 0.213\n",
      "epoch :   7, tr_loss : 0.161\n",
      "epoch :   8, tr_loss : 0.121\n",
      "epoch :   9, tr_loss : 0.092\n",
      "epoch :  10, tr_loss : 0.072\n"
     ]
    }
   ],
   "source": [
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "tr_loss_hist = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    avg_tr_loss = 0\n",
    "    tr_step = 0\n",
    "    \n",
    "    sess.run(tr_iterator.initializer)\n",
    "    try:\n",
    "        while True:\n",
    "            _, tr_loss = sess.run(fetches = [training_op, char_bi_rnn.ce_loss])\n",
    "            avg_tr_loss += tr_loss\n",
    "            tr_step += 1\n",
    "            \n",
    "    except tf.errors.OutOfRangeError:\n",
    "        pass\n",
    "    \n",
    "    avg_tr_loss /= tr_step\n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "    \n",
    "    print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x119b0d828>]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xd8VfX9x/HXJzcJIRBmEkYS9gwyApEtgqs4EUEEW6t14Kh1t7WtXfanta0LlSpIrVtERMWqxQWKLAl7hBHCSlhhBggQknx/fyRitEgucJOTe+/7+XjwqOfcLzlv76O8PXzPOd9jzjlERCS0RHgdQEREAk/lLiISglTuIiIhSOUuIhKCVO4iIiFI5S4iEoJU7iIiIUjlLiISglTuIiIhKNKrA8fHx7sWLVp4dXgRkaC0YMGCnc65hIrGeVbuLVq0ICMjw6vDi4gEJTPb6M84TcuIiIQglbuISAhSuYuIhCCVu4hICFK5i4iEIJW7iEgIUrmLiISgoCv3rB37eeSjVej1gCIiP8yvcjezwWa22syyzOz+43z+hJktLvu1xsz2Bj5qqRmr83jui3W8lZFTWYcQEQl6FZa7mfmAscCFQCowysxSy49xzt3tnOvmnOsGPA1MqYywANf3a0nvVg348/sr2Ly7oLIOIyIS1Pw5c+8JZDnnsp1zhcBEYMgJxo8C3ghEuOOJiDAevbIrEWbcO2kJxSWanhER+T5/yj0J2FxuO6ds3/8ws+ZAS+Dz04/2w5Lrx/LHyzrx9YbdTJiZXZmHEhEJSoG+oDoSmOycKz7eh2Y22swyzCwjLy/vtA40rHsSP+rUiMc+XkPm1vzT+lkiIqHGn3LPBVLKbSeX7TuekZxgSsY5N945l+6cS09IqHDFyhMyMx4e2pk6NSO5+83FHCk67n9PRETCkj/lPh9oa2YtzSya0gKf+v1BZtYBqA/MCWzEH9awdg0euaILq7bt54lP1lbVYUVEqr0Ky905VwTcDkwDMoFJzrkVZvagmV1WbuhIYKKr4hvQz0ttxFXpKYz7ch3zN+yuykOLiFRb5tXDQOnp6S5QL+s4cKSIC8d8CcBHdw6gdg3P3kEiIlKpzGyBcy69onFB94Tq8dSuEcljV3YjZ88h/u8/K72OIyLiuZAod4CeLRswekArJs7fzGeZ272OIyLiqZApd4B7zm9Hh8Zx/PrtZew6cMTrOCIingmpcq8R6eOJq7qRf+gov3tnuRYXE5GwFVLlDtCxSR3uuaAd/12xjSkLf+h2fBGR0BZy5Q5w01mt6NmiAX+auoLcvYe8jiMiUuVCstx9EcZjI7pS4hz3TVpCiRYXE5EwE5LlDpDSIJY/XJrKnOxdvDBrvddxRESqVMiWO8CI9BTO65jI36etZs32/V7HERGpMiFd7mbGX6/oQlyN0sXFCotKvI4kIlIlQrrcARLiavDwFZ1ZsSWfpz7T4mIiEh5CvtwBftSpMcN7JPPPGVks2LjH6zgiIpUuLMod4I+XptKkbk3unbSYgsIir+OIiFSqsCn3uJgoHhvRlY27C3jog0yv44iIVKqwKXeA3q0acmP/lrw2bxPTV+/wOo6ISKUJq3IHuPeC9rRrVJtfT17KnoOFXscREakUYVfuMVE+Hh/RjT0FhTzwnhYXE5HQFHblDnBGUl3uOq8dHyzdytQlW7yOIyIScGFZ7gA3D2hF92b1+P27y9m6T4uLiUhoCdtyj/RF8PiIbhSVOH751lItLiYiISVsyx2gRXwtfndxR77K2snLczZ4HUdEJGDCutwBru7ZjEHtE/jrR6vI2nHA6zgiIgHhV7mb2WAzW21mWWZ2/w+MGWFmK81shZm9HtiYlcfM+NuwLsRG+7hn0mKOFmtxMREJfhWWu5n5gLHAhUAqMMrMUr83pi3wG6Cfc64TcFclZK00iXVieGhoZ5bm7OOZz7O8jiMictr8OXPvCWQ557Kdc4XARGDI98bcBIx1zu0BcM4F3eOfF3VuwtC0JJ6ZnsXizXu9jiMiclr8KfckYHO57ZyyfeW1A9qZ2Swzm2tmgwMVsCr96bJOJMbV4J43F3OosNjrOCIipyxQF1QjgbbAQGAU8LyZ1fv+IDMbbWYZZpaRl5cXoEMHTt2aUTx6ZVeydx7kkY+0uJiIBC9/yj0XSCm3nVy2r7wcYKpz7qhzbj2whtKy/w7n3HjnXLpzLj0hIeFUM1eqfm3i+Vm/Frw0ZyMz11a//wCJiPjDn3KfD7Q1s5ZmFg2MBKZ+b8y7lJ61Y2bxlE7TZAcwZ5X69eAOtEmszS/fWsq+gqNexxEROWkVlrtzrgi4HZgGZAKTnHMrzOxBM7usbNg0YJeZrQSmA790zu2qrNCVLSbKxxMjurHzwBF+/95yr+OIiJw082pVxPT0dJeRkeHJsf311GdrefyTNTw9Ko1Luzb1Oo6ICGa2wDmXXtG4sH9C9URuG9iabin1eODd5WzPP+x1HBERv6ncT6B0cbGuHCkq5peTl2rtdxEJGir3CrRKqM1vL+rIl2vyeHXeJq/jiIj4ReXuh2t6N+estvE8/EEm63ce9DqOiEiFVO5+MDP+Mbwr0ZER3P3mYoq0uJiIVHMqdz81rhvDXy4/g8Wb9/LsjHVexxEROSGV+0m4rGtTLu3alDGfrWVZzj6v44iI/CCV+0n6y5BONKwdzd2TFnP4qBYXE5HqSeV+kurFRvOP4V3J2nGAhz7I1O2RIlItqdxPwYB2CdzQvyWvzN3IH6euoFgv1xaRaibS6wDB6ncXdcQXYYz/MptdBwp5/Kqu1Ij0eR1LRARQuZ+yiAjjtxd1JL52NA9/uIo9BYWMu6YHcTFRXkcTEdG0zOkaPaA1j4/oytfrdzNy/Fzy9h/xOpKIiMo9EK7onszz16aTnXeQ4c/NZuMuPcUqIt5SuQfIoPaJvHZTL/YdOsqwZ+ewPFf3wYuId1TuAdS9WX0m39KHaJ8xcvxcZq/b6XUkEQlTKvcAa5MYx9u39aVJ3Riue2E+Hyzd6nUkEQlDKvdK0KRuTd66pQ+dk+ty+xsLeWXOBq8jiUiYUblXknqx0bx6Qy/OaZ/I799bweMfr9bTrCJSZVTulahmtI9x1/Tgyh7JPPV5Fr99Z7meZhWRKqGHmCpZpC+Cvw/vQkJcDf45Yx27Dx5hzMg0YqL0NKuIVB6duVcBM+NXgzvwh0tSmbZiOz994Wv2HTrqdSwRCWF+lbuZDTaz1WaWZWb3H+fz68wsz8wWl/26MfBRg9/1/VsyZmQ3Fm3aw1Xj5rAj/7DXkUQkRFVY7mbmA8YCFwKpwCgzSz3O0Dedc93Kfk0IcM6QMaRbEv+69kw27S7gimdn652sIlIp/Dlz7wlkOeeynXOFwERgSOXGCm0D2iXwxk29KSgsZvizs1mas9frSCISYvwp9yRgc7ntnLJ93zfMzJaa2WQzSwlIuhDWNaUek2/pQ0yUj1Hj5zJzbZ7XkUQkhATqgur7QAvnXBfgE+Cl4w0ys9FmlmFmGXl5KrNWCbWZcltfUhrEcv2L85m6ZIvXkUQkRPhT7rlA+TPx5LJ9xzjndjnnvlnrdgLQ43g/yDk33jmX7pxLT0hIOJW8IadRnRjevLkPaSn1ueONRfx71nqvI4lICPCn3OcDbc2spZlFAyOBqeUHmFmTcpuXAZmBixj66taM4uUbenJBaiP+/P5K/jFtlZ5mFZHTUmG5O+eKgNuBaZSW9iTn3Aoze9DMLisbdoeZrTCzJcAdwHWVFThUxUT5ePYnPRjVsxljp6/j128vpai4xOtYIhKkzKszxPT0dJeRkeHJsasz5xxPfLKGpz7P4ryOiTw9qjs1o/U0q4iUMrMFzrn0isbpCdVqxsy454L2PDikE5+t2sE1/5rHvgI9zSoiJ0flXk39tE8LnhnVnaU5+7hy3Gy27dPTrCLiP5V7NXZxlya8+LMz2bL3MMOenU3WjgNeRxKRIKFyr+b6toln4ujeHCkq5srnZrNo0x6vI4lIEFC5B4Ezkuoy+Za+xMVEcfXz85ixeofXkUSkmlO5B4kW8bWYfGsfWsbX4saXMnhnUY7XkUSkGlO5B5HEuBjevLk3Z7ZowN1vLmHCzGyvI4lINaVyDzJxMVG8eP2ZXNS5Mf/3QSZ//TBTT7OKyP/Qa/aCUI1IH0+P6k7DWisY92U2uw4W8rdhXfBFmNfRRKSaULkHKV+E8eCQTjSoFc2Yz9Zy6GgxT17VjSif/jImIir3oGZm3H1+O2rV8PHwh6s4crSYZ67urpdvi4jm3EPB6AGt+cuQTnyauYObXs6goLDI60gi4jGVe4i4pk8LHr2yK7OydnLtC1+z/7DWoxEJZyr3EDK8RzJPjUpj0aa9/HjCPPYWFHodSUQ8onIPMZd0acpzP+nBqq37GTl+Lnn7j1T8m0Qk5KjcQ9B5qY144boz2birgKvGz2HrvkNeRxKRKqZyD1H928bz8g092ZF/hBHj5rB5d4HXkUSkCqncQ9iZLRrw2o29yD9UxJXPzWFdnpYMFgkXKvcQ1zWlHhNH96aopISrxs0hc2u+15FEpAqo3MNAxyZ1mDi6D5EREYwcP5elOXu9jiQilUzlHibaJNbmrVv6EBcTydXPz2P+ht1eRxKRSqRyDyMpDWJ565Y+JMbV4Kf/+pqv1u70OpKIVBK/yt3MBpvZajPLMrP7TzBumJk5M0sPXEQJpCZ1a/LmzX1o3jCW61+az2eZ272OJCKVoMJyNzMfMBa4EEgFRplZ6nHGxQF3AvMCHVICKyGuBhNH96ZD4zhufmUBHyzd6nUkEQkwf87cewJZzrls51whMBEYcpxxfwH+BhwOYD6pJPVio3n1xl6kNavHL95YyOQFem2fSCjxp9yTgM3ltnPK9h1jZt2BFOfcBwHMJpWsTkwUL13fk76t47nvrSW8Mnej15FEJEBO+4KqmUUAjwP3+jF2tJllmFlGXl7e6R5aAiA2OpIJ16ZzbodEfv/ucp7/Uu9lFQkF/pR7LpBSbju5bN834oAzgBlmtgHoDUw93kVV59x451y6cy49ISHh1FNLQMVE+Xjumh5c3LkJD32YyZhP1+q9rCJBzp83Mc0H2ppZS0pLfSRw9TcfOuf2AfHfbJvZDOA+51xGYKNKZYryRTBmZDdionw88ekaCo4Wcf/gDpjpvawiwajCcnfOFZnZ7cA0wAe84JxbYWYPAhnOuamVHVKqRqQvgn8M70LN6AjGfZHN4cJi/nhpJyL04m2RoOPXO1Sdcx8CH35v3x9+YOzA048lXomIMP4y5AxqRvl4fuZ6CgqLeWRYF3wqeJGgohdky/8wM357UUdioyMZ89laDheV8PiIrkT59ECzSLBQuctxmRl3n9+O2Ggff/1oFYePFvPM1WnUiPR5HU1E/KBTMTmhm89uzYNDOvHJyu3c+FIGhwqLvY4kIn5QuUuFftqnBX8f3oVZWTu59t9fc+BIkdeRRKQCKnfxy4j0FMaMTGPhxj38eMI89hUc9TqSiJyAyl38dmnXpjz7kx5kbsln5PNz2XngiNeRROQHqNzlpJyf2ogJ16azfucBrho3h237tE6cSHWkcpeTNqBdAi9f34vt+UcYMW4Om3cXeB1JRL5H5S6npGfLBrx6Yy/2FhQyYtwcsvMOeB1JRMpRucsp65ZSj4mj+1BYVMKIcXNZslkv3hapLlTuclpSm9bhzZv7EO0zrnh2Nn/9KJPDR3UvvIjXVO5y2tok1uajuwZwZY9kxn2RzYVjZjIve5fXsUTCmspdAqJuzSgeGdaF127sRXGJ46rxc3ng3WXsP6z74UW8oHKXgOrXJp7/3nUWN/ZvyevzNnHBE18yfdUOr2OJhB2VuwRcbHQkD1ySytu39iUuJpKfvTifuyYuYvfBQq+jiYQNlbtUmrRm9fnPL87iznPb8sGyrZz/+Be8v2SLXuEnUgVU7lKpoiMjuPv8drz/i/4k16/JL95YxE0vL9CTrSKVTOUuVaJD4zpMua0fD1zcka+y8jj/8S944+tNOosXqSQqd6kyvgjjxrNaMe2uAZyRVJffTFnG1c/PY+Oug15HEwk5Knepcs0b1uL1m3rx1ys6szx3Hz968kue/zKb4hKdxYsEispdPGFmjOrZjE/uOZv+beJ56MNMrvjnLFZv2+91NJGQoHIXTzWuG8PzP03n6VFp5Ow5xCVPz+SJT9ZwpEhLGIicDr/K3cwGm9lqM8sys/uP8/ktZrbMzBab2Vdmlhr4qBKqzIxLuzblk3vO5pIuTRnz2VoufforFm3a43U0kaBVYbmbmQ8YC1wIpAKjjlPerzvnOjvnugF/Bx4PeFIJeQ1qRfPEVd144bp09h8u4opnZ/OX/6ykoFDvbBU5Wf6cufcEspxz2c65QmAiMKT8AOdcfrnNWoCujMkpO6dDIz6+ewA/7tWMf321nsFPzmR21k6vY4kEFX/KPQnYXG47p2zfd5jZz81sHaVn7ncEJp6Eq7iYKP7v8s68Obo3vgjj6gnzuP/tpew7pIXIRPwRsAuqzrmxzrnWwK+BB443xsxGm1mGmWXk5eUF6tASwnq1ashHd57FLWe35q0FOZz/+Bd8vGKb17FEqj1/yj0XSCm3nVy274dMBC4/3gfOufHOuXTnXHpCQoL/KSWsxUT5uP/CDrx7Wz8a1q7B6FcW8PPXF5K3/4jX0USqLX/KfT7Q1sxamlk0MBKYWn6AmbUtt3kxsDZwEUVKdU6uy9Tb+3HfBe34ZMV2zn/iC6YszNESBiLHUWG5O+eKgNuBaUAmMMk5t8LMHjSzy8qG3W5mK8xsMXAPcG2lJZawFuWL4PZz2vLhnf1pFV+LeyYt4Wcvzid37yGvo4lUK+bVWU96errLyMjw5NgSGopLHK/M2cDfp63GgPsv7MCPezUnIsK8jiZSacxsgXMuvaJxekJVgpYvwriuX0um3TWA7s3r8/v3VjBy/FzWbNcSBiIqdwl6KQ1iefn6nvxjeBdWb9/PRWNm8tAHKzlwRA8/SfhSuUtIMDOuTE9h+n0DGd4jmQlfreecR2fw3uJcXXCVsKRyl5DSoFY0jwzrwju39aNx3RjunLiYq8bPZdW2/Ip/s0gIUblLSOqWUo93buvHw0M7s2b7fi5+6iv+/P4K8g/rCVcJDyp3CVm+COPqXs2Yfu9ARp6ZwouzN3DOo1/w9gLdGy+hT+UuIa9+rWgeGtqZ937ej+T6Nbn3rSVc+dwcVm7RVI2ELpW7hI0uyfWYcmtf/j6sC9k7D3LJ0zP543vLtRiZhCSVu4SViAhjxJkpTL93ID/p3ZxX5m7knEdnMCljMyV6h6uEEJW7hKW6sVE8OOQMpt7enxbxtfjV5KUMe242y3P3eR1NJCBU7hLWzkiqy1s39+HRK7uyeXcBlz7zFQ+8u4y9BYVeRxM5LSp3CXsREcbwHsl8du9Aru3TgtfnbWLQozOY+PUmTdVI0FK5i5SpWzOKP13WiQ/uOIs2ibW5f8oyhj47m6U5e72OJnLSVO4i39OxSR0m3dyHJ67qypa9hxgydha/mbKMPQc1VSPBQ+UuchxmxtC0ZD6/92yu79eSSRmbGfTYDF6du5FiTdVIEFC5i5xAXEwUv78klQ/vOIv2jeJ44N3lXD52Fos27fE6msgJqdxF/NC+cRwTR/fmqVFp7Nh/mKH/nM2vJi9h1wG9x1WqJ5W7iJ/MjMu6NuWzewdy84BWTFmYy6BHZ/DynA2aqpFqR+UucpJq14jkNxd15L93nUXn5Lr84b0VXPr0VyzYuNvraCLHqNxFTlGbxDhevaEXY6/uzp6CQoY9O4d7Jy0hb7+masR7KneR02BmXNylCZ/ecza3DmzN1CW5nPPYDF74aj2HjxZ7HU/CmHm1rnV6errLyMjw5NgilWVd3gH+NHUFM9fuJDGuBjed1YqrezWjVo1Ir6NJiDCzBc659IrG+XXmbmaDzWy1mWWZ2f3H+fweM1tpZkvN7DMza34qoUWCXeuE2rx8fU9eu7EXbRvV5qEPM+n3t8958tM1Wq9GqlSFZ+5m5gPWAOcDOcB8YJRzbmW5MYOAec65AjO7FRjonLvqRD9XZ+4SDhZt2sM/Z6zjk5XbqRXt48e9m3Nj/5Yk1onxOpoEqUCeufcEspxz2c65QmAiMKT8AOfcdOdcQdnmXCD5ZAOLhKK0ZvV5/qfpTLtrAOenNmLCzGz6/206v3tnGZt2FVT8A0ROkT/lngRsLredU7bvh9wAfHS8D8xstJllmFlGXl6e/ylFglz7xnE8OTKNGfcNYnh6Mm9l5DDosRncNXERq7ft9zqehKCA3i1jZj8B0oF/HO9z59x451y6cy49ISEhkIcWCQrNGsby8NDOzPz1IG7o35KPV27nR09+yU0vZ7B4s1aflMDx5xJ+LpBSbju5bN93mNl5wO+As51zutFX5AQa1Ynhtxd15NazW/PSnA38e9YGPlk5i35tGvLzgW3o07ohZuZ1TAli/lxQjaT0guq5lJb6fOBq59yKcmPSgMnAYOfcWn8OrAuqIt86cKSIN+Zt4vmZ2ezYf4RuKfX4+aA2nNshkYgIlbx8y98Lqn7d525mFwFPAj7gBefcQ2b2IJDhnJtqZp8CnYGtZb9lk3PushP9TJW7yP86fLSYtxfm8NwX69i8+xDtG8Vx68DWXNKlCZE+PXMoAS73yqByF/lhRcUl/GfpVv45I4s12w/QrEEsN5/dimHdk4mJ8nkdTzykchcJASUljk8ztzN2xjqWbN6rp15F5S4SSpxzzF63i7HTs5i9bhd1a0bxs34tuK5vC+rFRnsdT6qQyl0kRJV/6jU22sePezXjxrNa0UhPvYYFlbtIiFu9bT/Pzshi6pItREZEMDw9mVsGtKZZw1ivo0klUrmLhImNuw4y7stsJmfkUFRSwmVdm3LrwDa0bxzndTSpBCp3kTCzPf8wE2Zm89q8TRQUFnN+aiNuObsV3ZvV1wNRIUTlLhKm9hws5MXZG3hx9gb2HTpKq/haXJ6WxNC0JFIaaMom2KncRcLcgSNFfLB0C1MW5jJvfen7Xc9sUZ+haclc3LkJdWOjPE4op0LlLiLH5Owp4L3FW5iyMId1eQeJ9kVwbsdEhqYlMbB9ItGRevo1WKjcReR/OOdYlruPKQtzeX/JFnYdLKR+bBSXdGnK0O5JpKXU0/x8NadyF5ETOlpcwsy1eUxZmMsnK7dzpKiElvG1uLxb6fy8bqmsnlTuIuK3/MNH+e+ybUxZlMPc7NL5+fTm9RnaPYlLOjfV/Hw1onIXkVOSu/cQ7y7K5Z1FuWTtOEC0L4JzOiQytHsSgzQ/7zmVu4icFuccy3PzmbIoh/eXbGHngULqxUZxSZcmDE1Lpnszzc97QeUuIgFTVFzCzLU7mbIol49XbONIUQktGsYeu3++ecNaXkcMGyp3EakU+w8f5aPl23hnYS5z1+/COejRvD5D05K4pEsTrVJZyVTuIlLptuw9xHuLt/DOohzWbC+dnx/UIYGhackM6pBAjUi9WCTQVO4iUmWcc6zYks+UhblMXbKFnQeOULdm6fz8Fd2TtL5NAKncRcQTRcUlzMzayTsLc/l45TYOHy2hcZ0Y+rWJp1+bhvRrE6+150+Dv+Wu93SJSEBF+iIY1D6RQe0TOXCkiP8u38b0VTv4bNV23l6YA0CbxNr0bxNP39YN6d26IXVidB99oOnMXUSqREmJY+XWfGZl7WTWul18vX4Xh4+WEGHQJbleadm3aUiP5vU1V38CmpYRkWrtSFExizbtLS37rJ0sydlHcYkjJiqCM1s0KJ3GaR1PatM6+CI0X/+NgJa7mQ0GxgA+YIJz7pHvfT4AeBLoAox0zk2u6Geq3EWkvP2HjzIvezez1pWW/ZrtBwCoFxtFn1YNy+bs42nRMDasL84GbM7dzHzAWOB8IAeYb2ZTnXMryw3bBFwH3HdqcUUk3MXFRHFeaiPOS20EwI78w8xet+vYmf1Hy7cBkFSvJn1bN6R/23j6tG5IYpwuzh6PPxdUewJZzrlsADObCAwBjpW7c25D2WcllZBRRMJQYp0YLk9L4vK0JJxzbNhVwFdZO5mdtZOPV27nrQWlF2fbN4qjb5uG9G8TT8+WDYjTxVnAv3JPAjaX284Bep3KwcxsNDAaoFmzZqfyI0QkDJkZLeNr0TK+Ftf0bk5xiWPllvzSsl+3k9fnbeLfszbgizC6pdSjX+vSaZy0ZvXDdqGzKr0V0jk3HhgPpXPuVXlsEQkdvgijc3JdOifX5daBrTl8tJiFm/aUTeHs4pnpWTz1eRY1o3yc2bIB/ds0pG/reDo2CZ+Ls/6Uey6QUm47uWyfiEi1EBPlo2/rePq2jueXP4J9h44yL3vXsdsuH/5wFQA1o3y0bxxHxyZ1SG0SR2rTOrRvXIfaNULvkR9//o3mA23NrCWlpT4SuLpSU4mInIa6NaO4oFNjLujUGIBt+w4zJ3snS3P2kbk1nw+XbeWNrzcdG9+8YSwdG9chtWkdOjapQ8cmcSTVqxnUd+X4eyvkRZTe6ugDXnDOPWRmDwIZzrmpZnYm8A5QHzgMbHPOdTrRz9StkCLiFeccW/YdJnNLPplb88nclk/m1v1s2HWQbyqxTkxkWdHXIbXsf9s2qk1MlLcPWOkhJhGRk3TwSBGrtu0vLfyt+azcms/qbfspKCwGSuf6WyfUOlb63xR/QlyNKsuotWVERE5SrRqR9Ghenx7N6x/bV1Li2Li74NvC35LP/PW7eW/xlmNj4mvXoGOTOFKbfDu10yq+FpE+7+7UUbmLiJxARMS3t2Fe1LnJsf17CwrJ3Lr/2Bl+5tZ8/j1rA4XFpY/7REdG0K5R7WNTOt/8qluzau7D17SMiEiAHC0uITvvICu37vu2+Lfks+tg4bExSfVq8qvB7RnSLemUjqFpGRGRKhbli6B94zjaN45jaFrpPuccefuPlJ3dlxZ+VczRq9xFRCqRmZFYJ4bEOjEMbJ9YZccNz+dyRURCnMpdRCQEqdxFREKQyl1EJASp3EVEQpDKXUQkBKncRURCkMpdRCQEebb8gJnlARtP8bfHAzsDGCfY6fv4Ln0f39J38V2h8H00d84lVDTIs3I/HWaW4c/aCuFC38d36fv4lr6L7wqn70PTMiIiIUhSoh6AAAACrElEQVTlLiISgoK13Md7HaCa0ffxXfo+vqXv4rvC5vsIyjl3ERE5sWA9cxcRkRMIunI3s8FmttrMsszsfq/zeMXMUsxsupmtNLMVZnan15mqAzPzmdkiM/uP11m8Zmb1zGyyma0ys0wz6+N1Jq+Y2d1lf06Wm9kbZhbjdabKFlTlbmY+YCxwIZAKjDKzVG9TeaYIuNc5lwr0Bn4ext9FeXcCmV6HqCbGAP91znUAuhKm34uZJQF3AOnOuTMAHzDS21SVL6jKHegJZDnnsp1zhcBEYIjHmTzhnNvqnFtY9s/7Kf2De2ovZQwRZpYMXAxM8DqL18ysLjAA+BeAc67QObfX21SeigRqmlkkEAts8ThPpQu2ck8CNpfbziHMCw3AzFoAacA8b5N47kngV0CJ10GqgZZAHvDvsmmqCWZWy+tQXnDO5QKPApuArcA+59zH3qaqfMFW7vI9ZlYbeBu4yzmX73Uer5jZJcAO59wCr7NUE5FAd+BZ51wacBAIy2tUZlaf0r/htwSaArXM7Cfepqp8wVbuuUBKue3ksn1hycyiKC3215xzU7zO47F+wGVmtoHS6bpzzOxVbyN5KgfIcc5987e5yZSWfTg6D1jvnMtzzh0FpgB9Pc5U6YKt3OcDbc2spZlFU3pRZKrHmTxhZkbpfGqmc+5xr/N4zTn3G+dcsnOuBaX/v/jcORfyZ2c/xDm3DdhsZu3Ldp0LrPQwkpc2Ab3NLLbsz825hMHF5UivA5wM51yRmd0OTKP0ivcLzrkVHsfySj/gGmCZmS0u2/db59yHHmaS6uUXwGtlJ0LZwM88zuMJ59w8M5sMLKT0LrNFhMGTqnpCVUQkBAXbtIyIiPhB5S4iEoJU7iIiIUjlLiISglTuIiIhSOUuIhKCVO4iIiFI5S4iEoL+H8XSClRWZC7iAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(tr_loss_hist, label = 'train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "yhat = char_bi_rnn.predict(sess = sess, X_length = X_length, X_indices = X_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training acc: 100.00%\n"
     ]
    }
   ],
   "source": [
    "print('training acc: {:.2%}'.format(np.mean(yhat == np.argmax(y, axis = -1))))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}