{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 11 : Recurrent Neural Networks\n",
    "Simple example for Many to One Classification (word sentiment classification) by Bi-directional Gated Recurrent Unit.\n",
    "\n",
    "### Many to One Classification by Bi-directional GRU\n",
    "- Creating the **data pipeline** with `tf.data`\n",
    "- Preprocessing word sequences (variable input sequence length) using `padding technique` by `user function (pad_seq)`\n",
    "- Using `tf.nn.embedding_lookup` for getting vector of tokens (eg. word, character)\n",
    "- Creating the model as **Class**\n",
    "- Reference\n",
    "    - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n",
    "    - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n",
    "    - https://pozalabs.github.io/blstm/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.8.0\n"
     ]
    }
   ],
   "source": [
    "import os, sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import string\n",
    "%matplotlib inline\n",
    "\n",
    "slim = tf.contrib.slim\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare example data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n",
    "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'abcdefghijklmnopqrstuvwxyz *'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Character quantization\n",
    "char_space = string.ascii_lowercase \n",
    "char_space = char_space + ' ' + '*'\n",
    "char_space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n"
     ]
    }
   ],
   "source": [
    "char_dic = {char : idx for idx, char in enumerate(char_space)}\n",
    "print(char_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create pad_seq function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_seq(sequences, max_len, dic):\n",
    "    seq_len, seq_indices = [], []\n",
    "    for seq in sequences:\n",
    "        seq_len.append(len(seq))\n",
    "        seq_idx = [dic.get(char) for char in seq]\n",
    "        seq_idx += (max_len - len(seq_idx)) * [dic.get('*')] # 27 is idx of meaningless token \"*\"\n",
    "        seq_indices.append(seq_idx)\n",
    "    return seq_len, seq_indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Apply pad_seq function to data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_length = 10\n",
    "X_length, X_indices = pad_seq(sequences = words, max_len = max_length, dic = char_dic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4, 3, 7, 7, 9, 7]\n",
      "(6, 10)\n"
     ]
    }
   ],
   "source": [
    "print(X_length)\n",
    "print(np.shape(X_indices))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define CharBiGRU class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharBiGRU:\n",
    "    def __init__(self, X_length, X_indices, y, n_of_classes, hidden_dim, dic):\n",
    "        \n",
    "        # data pipeline\n",
    "        with tf.variable_scope('input_layer'):\n",
    "            self._X_length = X_length\n",
    "            self._X_indices = X_indices\n",
    "            self._y = y\n",
    "            \n",
    "            one_hot = tf.eye(len(dic), dtype = tf.float32)\n",
    "            self._one_hot = tf.get_variable(name='one_hot_embedding', initializer = one_hot,\n",
    "                                            trainable = False) # embedding vector training 안할 것이기 때문\n",
    "            self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n",
    "        \n",
    "        # Bi-directional GRU\n",
    "        with tf.variable_scope('bi-directional_gru'):\n",
    "            gru_fw_cell = tf.contrib.rnn.GRUCell(num_units = hidden_dim, activation = tf.nn.tanh)\n",
    "            gru_bw_cell = tf.contrib.rnn.GRUCell(num_units = hidden_dim, activation = tf.nn.tanh)\n",
    "            _, output_states = tf.nn.bidirectional_dynamic_rnn(cell_fw = gru_fw_cell,\n",
    "                                                                    cell_bw = gru_bw_cell,\n",
    "                                                                    inputs = self._X_batch,\n",
    "                                                                    sequence_length = self._X_length,\n",
    "                                                                    dtype = tf.float32)\n",
    "\n",
    "            final_state = tf.concat([output_states[0], output_states[1]], axis = 1)\n",
    "\n",
    "        with tf.variable_scope('output_layer'):\n",
    "            self._score = slim.fully_connected(inputs = final_state, num_outputs = n_of_classes,\n",
    "                                               activation_fn = None)\n",
    "            \n",
    "        with tf.variable_scope('loss'):\n",
    "            self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n",
    "            \n",
    "        with tf.variable_scope('prediction'):\n",
    "            self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n",
    "    \n",
    "    def predict(self, sess, X_length, X_indices):\n",
    "        feed_prediction = {self._X_length : X_length, self._X_indices : X_indices}\n",
    "        return sess.run(self._prediction, feed_dict = feed_prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a model of CharBiGRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    }
   ],
   "source": [
    "# hyper-parameter#\n",
    "lr = .003\n",
    "epochs = 10\n",
    "batch_size = 2\n",
    "total_step = int(np.shape(X_indices)[0] / batch_size)\n",
    "print(total_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n"
     ]
    }
   ],
   "source": [
    "## create data pipeline with tf.data\n",
    "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n",
    "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n",
    "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n",
    "tr_iterator = tr_dataset.make_initializable_iterator()\n",
    "print(tr_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "char_bi_gru = CharBiGRU(X_length = X_length_mb, X_indices = X_indices_mb, y = y_mb,\n",
    "                        n_of_classes = 2, hidden_dim = 16, dic = char_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creat training op and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "## create training op\n",
    "opt = tf.train.AdamOptimizer(learning_rate = lr)\n",
    "training_op = opt.minimize(loss = char_bi_gru.ce_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   1, tr_loss : 0.670\n",
      "epoch :   2, tr_loss : 0.624\n",
      "epoch :   3, tr_loss : 0.598\n",
      "epoch :   4, tr_loss : 0.550\n",
      "epoch :   5, tr_loss : 0.518\n",
      "epoch :   6, tr_loss : 0.489\n",
      "epoch :   7, tr_loss : 0.456\n",
      "epoch :   8, tr_loss : 0.431\n",
      "epoch :   9, tr_loss : 0.394\n",
      "epoch :  10, tr_loss : 0.366\n"
     ]
    }
   ],
   "source": [
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "tr_loss_hist = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    avg_tr_loss = 0\n",
    "    tr_step = 0\n",
    "    \n",
    "    sess.run(tr_iterator.initializer)\n",
    "    try:\n",
    "        while True:\n",
    "            _, tr_loss = sess.run(fetches = [training_op, char_bi_gru.ce_loss])\n",
    "            avg_tr_loss += tr_loss\n",
    "            tr_step += 1\n",
    "            \n",
    "    except tf.errors.OutOfRangeError:\n",
    "        pass\n",
    "    \n",
    "    avg_tr_loss /= tr_step\n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "    \n",
    "    print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x115f56f28>]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl8VPW9//HXJzs7kUS2AIEQoLLDsMuiiKJtQQUUpC1YFUERUKu1Pu793V57b12LqAVkcasFWbWlboiyyiYJsshq2MMadpAlBL6/PzLawEUzkOVMZt7PxyMPOWfOybwzwpvDfL/fM+acQ0REwkOE1wFERKT4qPRFRMKISl9EJIyo9EVEwohKX0QkjKj0RUTCiEpfRCSMqPRFRMKISl9EJIxEeR3gUgkJCS45OdnrGCIiJUp6evpB51xifscFXeknJyeTlpbmdQwRkRLFzHYEcpze3hERCSMqfRGRMKLSFxEJIyp9EZEwotIXEQkjKn0RkTCi0hcRCSMhU/rnLzj+/PEGdh0+5XUUEZGgFTKlv+PQd0z5aid3jl3C2sxjXscREQlKIVP6dRLLMnNIe2IiI7h7/FLmbTzgdSQRkaATMqUPkFq5HB881J46iWW4/29pTF6+0+tIIiJBJaRKH+Da8nFMHdSOTqkJPP3BWl6cvRHnnNexRESCQsiVPkCZ2Cgm/MZHv9Y1GD1vC49OXcXZnPNexxIR8VzQ3WWzsERFRvDnOxqTFF+aF2dvYt/xM4z7tY8KpaK9jiYi4pmQvNL/npnx8A11efnupqTvOELvsUvYffS017FERDwT0qX/vTuaJ/HOb1uz7/gZ7hi9mG92a0qniISnsCh9gPYpCcwc0p6oCOPucUuZv0lTOkUk/IRN6QPUq1yODx7uQK1KZbjvnTSmrtCUThEJLwGVvpl1N7NNZpZhZk/9yDF3mdl6M1tnZpPz7D9vZqv8X7MKK/jVqlw+jmmD29GhbgK/n7mWkZ9t0pROEQkb+c7eMbNIYDTQDcgEVpjZLOfc+jzHpAJ/ADo4546Y2bV5vsVp51yzQs5dIGVjo3hjgI//+OAbXp2bQeaR0zzXqwkxUWH1Dx8RCUOBTNlsDWQ457YCmNkUoCewPs8xDwCjnXNHAJxzQf+GeXRkBM/1akxSfCn+Mmcz+0+cYeyvWlI+TlM6RSR0BXJpWx3YlWc7078vr3pAPTNbbGbLzKx7nsfizCzNv//2yz2BmQ3yH5OWlZV1RT9AQZgZj3RN5S99mrJ862H6jF3KHk3pFJEQVljvZ0QBqUAXoB8wwcwq+h+r5ZzzAfcAo8ws5dKTnXPjnXM+55wvMTGxkCIFrlfL3Cmde46e5o4xi1m/53ixZxARKQ6BlP5uoEae7ST/vrwygVnOuXPOuW3AZnL/EsA5t9v/363AfKB5ATMXiQ51E5g+pB0RZtw1bikLNxffvzhERIpLIKW/Akg1s9pmFgP0BS6dhfMPcq/yMbMEct/u2Wpm8WYWm2d/By4eCwgqDaqU54OHOpAUX4rfvr2CaWm78j9JRKQEybf0nXM5wFBgNrABmOacW2dmz5hZD/9hs4FDZrYemAc84Zw7BPwMSDOz1f79z+Wd9ROMqlSIY/rgdrRLqcSTM9bw8pzNmtIpIiHDgq3QfD6fS0tL8zoG585f4On31zI9PZPeLZN49s7GREdqSqeIBCczS/ePn/6kkL3LZkFFR0bwQu8mVI8vxajPv2X/8TOM6d+CcprSKSIlmC5df4KZMeKmerzYuwlLtxyiz+tL2XfsjNexRESumko/AH18NXhzYCsyj+RO6dy4T1M6RaRkUukHqFO9RKY92A7noM/YpXz57UGvI4mIXDGV/hW4rlp5Pni4PdXjSzHwra+YkZ7pdSQRkSui0r9CVSuUYtrgdrSpcw2/m76aV7/4VlM6RaTEUOlfhfJx0bw1sDV3tqjOyDmbeWrmWs6dv+B1LBGRfGnK5lWKiYrgL32akhRfmle/+Ja9/imdZWP1kopI8NKVfgGYGY91q8fzvRqzOOMgd72+lP3HNaVTRIKXSr8Q3N2qJm8ObMWOQ99xx+jFbN5/wutIIiKXpdIvJJ3rJTJtcDtyLjh6jVnCxEVbOXPuvNexREQuotIvRA2rVeCDhzvQrGZF/uejDXR6YR5/W7qdszkqfxEJDir9Qla9Yineva8NUwe1JblSGf7fP9dx40sLmLpip2b4iIjndJfNIuSc48uMg/zls82s2nWUWpVKM+KmVHo0rU5khHkdT0RCSKB32dSVfhEyMzqmJvLBQ+15Y4CPMjFRPDp1NbeMWshHa/Zy4UJw/YUrIqFPpV8MzIyuP6vMh49cz9j+LTDg4ckr+flrXzJn/X6t6BWRYqPSL0YREcatjavy6YhOjLq7Gaezc3jgb2ncPnoxCzZnqfxFpMip9D0QGWHc3rw6nz/WmRd6NeHgyWwGvPkVd41byrKth7yOJyIhTAO5QSA75wJT03bx17nfsv/4WTrUrcRj3erTsla819FEpIQIdCBXpR9Ezpw7z6TlOxk7P4ODJ7O5oX4ij99cn0bVK3gdTUSCnEq/BPvubA7vLN3OuAVbOXb6HN0bVuHRbvWoX6Wc19FEJEip9EPA8TPnePPLbbyxaBsns3P4ZZNqjLgplTqJZb2OJiJBRqUfQo6eymb8wq28tTj3lg53tkhieNdUalxT2utoIhIkCnVxlpl1N7NNZpZhZk/9yDF3mdl6M1tnZpPz7B9gZt/6vwYE/iPI9yqWjuHJ7g1Y9PsbuLdDbWat3sMNL83n6Q/WsvfYaa/jiUgJku+VvplFApuBbkAmsALo55xbn+eYVGAacKNz7oiZXeucO2Bm1wBpgA9wQDrQ0jl35MeeT1f6+dt37Ayj52UwZcVOzIx7WtfkoRtSuLZcnNfRRMQjhXml3xrIcM5tdc5lA1OAnpcc8wAw+vsyd84d8O+/BZjjnDvsf2wO0D3QH0Iur0qFOP50eyPm/a4LdzSrzrvLdtDphXk8+8kGDn+X7XU8EQligZR+dWBXnu1M/7686gH1zGyxmS0zs+5XcK5cpaT40jzfuwmfP9aZ7g2rMH7hVjo+P5eRn23i2OlzXscTkSBUWCtyo4BUoAvQD5hgZhUDPdnMBplZmpmlZWVlFVKk8FE7oQyj+jZn9ohOdK6fyKtzM+j4/FzGzt+im7qJyEUCKf3dQI0820n+fXllArOcc+ecc9vIHQNIDfBcnHPjnXM+55wvMTHxSvJLHvUql2NM/5Z8NOx6WiVfw/OfbmT41FVk5+g+/iKSK5DSXwGkmlltM4sB+gKzLjnmH+Re5WNmCeS+3bMVmA3cbGbxZhYP3OzfJ0WoYbUKTBzg46lbG/Cv1Xu4750VnDyb43UsEQkC+Za+cy4HGEpuWW8Apjnn1pnZM2bWw3/YbOCQma0H5gFPOOcOOecOA38i9y+OFcAz/n1SxMyMwZ1TeKF3E5ZsOUT/Ccs4dPKs17FExGNanBUGPl+/n4cnr6R6xVL87b7WJMVrUZdIqNEnZ8kPbrquMn+/vw0HT56l19glbNp3wutIIuIRlX6YaJV8DdMGtwOgz+tLSNuud9lEwpFKP4w0qFKeGYPbk1A2lv4Tl/PFhv1eRxKRYqbSDzM1rinN9MHtqF+lHIPeTWd62q78TxKRkKHSD0OVysYy+YG2tKtTiSdmrGHcgi1eRxKRYqLSD1NlY6N4Y6CPXzSpyrOfbOR/P1qv1bsiYSDK6wDindioSF7t25xKZWKYsGgbh05m83zvJkRH6lpAJFSp9MNcRITxxx4NqVQ2lpFzNnPkVDaj+7egdIx+a4iEIl3SCWbGsK6p/O8djViwOYv+E5dz9JRu0SwSilT68oP+bWoxpn8L1u0+Tp/Xl+pTuURCkEpfLtK9UVXe/m0r9h47Q68xS8g4oNW7IqFEpS//R/uUBKYMakv2eUfv15fy9c4f/XRLESlhVPpyWY2qV2DmkHaUj4vmngnLWbBZH24jEgpU+vKjalUqw4wh7aidUIb73l7BP1f9n8+/EZESRqUvP+nacnFMebAtvuR4hk9ZxZtfbvM6kogUgEpf8lU+Lpq3721N94ZVeObD9bzw6UaC7XMYRCQwKn0JSFx0JKP7t6Bf65qMmb+Fp2auJee8PntXpKTRsksJWGSE8ec7GpFQNobX5mZw+FQ2r/VrTlx0pNfRRCRAutKXK2JmPH5zff74y+v4fMN+fvPGVxw7fc7rWCISIJW+XJWBHWrzSt/mfL3rCHePW8qB42e8jiQiAVDpy1Xr0bQabw5sxc7Dp7hz7BK2HfzO60gikg+VvhRIx9RE3nugLaeyz9N77BLWZh7zOpKI/ASVvhRY0xoVmT64HXHRkfQdv5TFGQe9jiQiP0KlL4UiJbEsM4e0Jym+NPe+tYKP1uz1OpKIXEZApW9m3c1sk5llmNlTl3l8oJllmdkq/9f9eR47n2f/rMIML8GlSoU4pj3YjiZJFRj63kreXbrd60gicol85+mbWSQwGugGZAIrzGyWc279JYdOdc4Nvcy3OO2ca1bwqFISVCgdzbv3tWHo5JX85z/XcfBkNiNuSsXMvI4mIgR2pd8ayHDObXXOZQNTgJ5FG0tKslIxkYz7dUt6t0zilS++5fcz13DijObyiwSDQEq/OrArz3amf9+lepnZGjObYWY18uyPM7M0M1tmZrdf7gnMbJD/mLSsLN3CNxRERUbwYu8mPHxDCtPTM+k2ciGffrPP61giYa+wBnL/BSQ755oAc4B38jxWyznnA+4BRplZyqUnO+fGO+d8zjlfYmJiIUUSr5kZT9zSgPeHtKdi6WgG/z2d+99JY/dRfQyjiFcCKf3dQN4r9yT/vh845w455876NycCLfM8ttv/363AfKB5AfJKCdS8Zjz/euR6nr6tAYszDtJt5AImLtqqG7aJeCCQ0l8BpJpZbTOLAfoCF83CMbOqeTZ7ABv8++PNLNb/6wSgA3DpALCEgejICAZ1SmHOY51oW6cS//PRBnqOXsyazKNeRxMJK/mWvnMuBxgKzCa3zKc559aZ2TNm1sN/2DAzW2dmq4FhwED//p8Baf7984DnLjPrR8JIUnxp3hjgY2z/FmSdOMvtoxfzx1nrNNArUkws2D4Mw+fzubS0NK9jSDE4fuYcL83exLvLdlC5XBx/7NGQWxpW1vROkatgZun+8dOfpBW54pnycdE807PRRQO9D/wtXQO9IkVIpS+e00CvSPFR6UtQ0ECvSPFQ6UtQ0UCvSNFS6UvQMTNubVyVzx/vzK/a1uKdpdt/WNEbbBMPREoalb4ELQ30ihQ+lb4EPQ30ihQelb6UCBroFSkcKn0pUTTQK1IwKn0pcTTQK3L1VPpSYmmgV+TKqfSlxNNAr0jgVPoSEjTQKxIYlb6EFA30ivw0lb6EnMsN9N40cgEfrtmjgV4Jeyp9CVnfD/R+8FAHrikTy9DJX9Nr7BLSdxz2OpqIZ1T6EvKa1ajIh49czwu9mpB55DS9xi7loUnpbD/4ndfRRIqdPjlLwsqp7BwmLNzGuIVbOHf+Ar9qW4thN6YSXybG62giBaJPzhK5jNIxUQy/KZX5v+tC75ZJvLNkO51enMf4hVs4c+681/FEipxKX8LSteXjePbOJnw6ohO+WvH8+eON3DRyAbNWa7BXQptKX8JavcrleOve1vz9vjaUi4tm2Htfc/uYJXy1TYO9EppU+iLA9akJfPjI9bzUpyn7j53hrnFLGfS3NLZmnfQ6mkih0kCuyCVOZ5/njS+3Mnb+Fs7mXKB/m5oM65pKpbKxXkcT+VGFOpBrZt3NbJOZZZjZU5d5fKCZZZnZKv/X/XkeG2Bm3/q/BlzZjyFS/ErFRDL0xlTmP3EDfVvX4O/Ld9LlxfmMna/BXin58r3SN7NIYDPQDcgEVgD9nHPr8xwzEPA554Zecu41QBrgAxyQDrR0zh35sefTlb4Em4wDJ3juk418vuEA1SuW4olb6tOjaTUiIszraCI/KMwr/dZAhnNuq3MuG5gC9Awwxy3AHOfcYX/RzwG6B3iuSFCoe205Jg5oxeQH2hBfJpoRU1fRY/SXLN1yyOtoIlcskNKvDuzKs53p33epXma2xsxmmFmNKzxXJOi1T0lg1sPX8/LdTTl8Mpt+E5Zx/zsryDhwwutoIgErrNk7/wKSnXNNyL2af+dKTjazQWaWZmZpWVlZhRRJpPBFRBh3NE9i7u+68PvuDVi+9TC3jFrEf/xjLQdPnvU6nki+Ain93UCNPNtJ/n0/cM4dcs59/zt+ItAy0HP95493zvmcc77ExMRAs4t4Ji46kiFdUpj/RBd+1aYmU77aRZcX5zN6XganszXYK8ErkNJfAaSaWW0ziwH6ArPyHmBmVfNs9gA2+H89G7jZzOLNLB642b9PJCRUKhvLf/dsxOxHO9E+pRIvzt7EDS/NZ0Z6JhcuBNd0aBEIoPSdcznAUHLLegMwzTm3zsyeMbMe/sOGmdk6M1sNDAMG+s89DPyJ3L84VgDP+PeJhJSUxLKM/42PaQ+2o3L5WH43fTW/eO1LFmcc9DqayEW0OEukkF244Phw7V6e/2Qju4+e5ob6ifzhtp9Rr3I5r6NJCNNdNkU8EhFh9GhajS8e78zTtzUgbccRuo9ayB/eX8OBE2e8jidhTqUvUkTioiMZ1CmFhU/cwID2ycxIz6TLi/N5fcEWcs5f8DqehCmVvkgRiy8Tw3/9siFzHu1M+5QEnvtkIz1HL+ab3ce8jiZhSKUvUkySE8ow4TctGdu/BfuPn6Xn6MU8+8kG3c9HipVKX6QYmRm3Nq7KF491pneLJMYt2Er3UQt1SwcpNip9EQ9UKB3N872bMPn+Nlxw0G/CMp6auYZjp895HU1CnEpfxEPt6yYwe0QnHuxUh2lpu+g2cgGffrPP61gSwlT6Ih4rFRPJH277Gf98+HoqlY1l8N/TGfxuOgeOa3qnFD6VvkiQaJxUgVlDO/Bk9/rM3XSAriMXMOWrnfqgdilUKn2RIBIdGcFDXeoye0QnrqtanqfeX8s9E5az/eB3XkeTEKHSFwlCtRPK8N4DbXn2zsZ8s/sYt4xaqEVdUihU+iJBKiLC6Ne6Jp8/3pnO9RK1qEsKhUpfJMhVLh/HuF9rUZcUDpW+SAmgRV1SWFT6IiWIFnVJQan0RUogLeqSq6XSFymhtKhLroZKX6SE06IuuRIqfZEQoEVdEiiVvkgI0aIuyY9KXyTEaFGX/BSVvkiI0qIuuRyVvkgI06IuuZRKXyQMXG5R131vryB9xxGvo0kxC6j0zay7mW0yswwze+onjutlZs7MfP7tZDM7bWar/F+vF1ZwEbly3y/qeqxbPdJ3HqHX2CXcPW4pCzdnaYpnmLD8/kebWSSwGegGZAIrgH7OufWXHFcO+AiIAYY659LMLBn40DnXKNBAPp/PpaWlXcnPICJX4buzObz31U4mLtrGvuNnaFy9Ag91SeGWhlWIiDCv48kVMrN055wvv+MCudJvDWQ457Y657KBKUDPyxz3J+B5QMsBRUqAMrFR3N+xDgue7MJzdzbmxJlzDJm0km4vL2B62i7OaZpnSAqk9KsDu/JsZ/r3/cDMWgA1nHMfXeb82mb2tZktMLOOl3sCMxtkZmlmlpaVlRVodhEpBLFRkfRtXZMvHu/Ca/2aEx0ZwRMz1tDlxfm8vXgbp7M12yeUFHgg18wigJHA45d5eC9Q0znXHHgMmGxm5S89yDk33jnnc875EhMTCxpJRK5CZITxy6bV+GR4R94a2IqqFeL447/Wc/3zcxk9L0N38gwRgZT+bqBGnu0k/77vlQMaAfPNbDvQFphlZj7n3Fnn3CEA51w6sAWoVxjBRaRomBk3NLiWGUPaM+3BdjSqXoEXZ2/i+ufm8vynG8k6cdbriFIAgQzkRpE7kNuV3LJfAdzjnFv3I8fPB37nH8hNBA47586bWR1gEdDYOXf4x55PA7kiweeb3ccYO38LH3+zl5jICPq2qsEDneqQFF/a62jiF+hAblR+BzjncsxsKDAbiATedM6tM7NngDTn3KyfOL0T8IyZnQMuAIN/qvBFJDg1ql6B0f1bsCXrJOMWbGHS8p1MWr6THs2q8VCXFOpeW87riBKgfK/0i5uu9EWC356jp5mwaCvvfbWTszkXuPm6yjzUpS5Na1T0OlrYCvRKX6UvIlft0MmzvL1kO28v2c6JMzl0TE1gSJcU2tWphJnm+hcnlb6IFJsTZ84xaXnuQq+DJ8/SvGZFHupSl64NrtVCr2Ki0heRYnfm3Hmmp2cybsEWMo+cpn7lcgzpksIvmlQlKlK3+ipKKn0R8cy58xf4cM0exszbwrcHTlLjmlI82CmF3i2TiIuO9DpeSFLpi4jnLlxwfL5hP6Pnb2H1rqMklovl/utr079tLcrG5jt5UK6ASl9EgoZzjiVbDjFmfgaLMw5RoVQ0A9onc2/7ZOLLxHgdLySo9EUkKK3adZQx8zL4bP1+ysRE8l89GtKnZZJm+xRQYd5lU0Sk0DSrUZHxv/Hx2aOdaJxUgSdnrGHo5K85dkr39ikOKn0R8US9yuWYdH9bnuxen9nr9nHrKwtZvlUf41jUVPoi4pnICOOhLnWZOaQ9MVER9JuwjL98tkn38i9CKn0R8VzTGhX5aFhHerVI4rW5GfR5fSk7Dn3ndayQpNIXkaBQJjaKF/s05a/3NGdL1klue2UR76/M1Gf3FjKVvogElV80qcanIzrRsFoFHpu2muFTVnH8jAZ5C4tKX0SCTvWKpXhvUFse71aPj9bu5dZRi0jbrruyFwaVvogEpcgI45GuqUwf3I6ICLhr3FJenrOZHA3yFohKX0SCWoua8Xw8rCO3N6vOK198y93jl7Hr8CmvY5VYKn0RCXrl4qIZeXczXunbjM37TnDbK4v456rd+Z8o/4dKX0RKjJ7NqvPx8I7Uq1KO4VNW8djUVZzQIO8VUemLSIlS45rSTB3UlhE3pfKPVbv5+atfsnLnEa9jlRgqfREpcaIiIxhxUz2mPdiO8xccfV5fymtffMv5C5rTnx+VvoiUWL7ka/hkREd+3rgqf5mzmX7jl7H76GmvYwU1lb6IlGjl46J5pW8zRt7VlHV7jtF91EI+XLPH61hBS6UvIiWemXFniyQ+Ht6RlMSyDJ38NU9MX813Z3O8jhZ0Aip9M+tuZpvMLMPMnvqJ43qZmTMzX559f/Cft8nMbimM0CIil1OrUhmmD27HIzfWZcbKTH7+6iJW7zrqdaygkm/pm1kkMBq4FbgO6Gdm113muHLAcGB5nn3XAX2BhkB3YIz/+4mIFInoyAgev7k+Ux5oS3bOBXqNXcKY+Rka5PUL5Eq/NZDhnNvqnMsGpgA9L3Pcn4DngTN59vUEpjjnzjrntgEZ/u8nIlKk2tSpxCfDO3FLoyq88Okm+k9cxt5jGuQNpPSrA7vybGf69/3AzFoANZxzH13puSIiRaVC6Wj+2q85L/RuwprMY3QftYhP1u71OpanCjyQa2YRwEjg8QJ8j0FmlmZmaVlZWQWNJCLyAzPjLl8NPhrWkVqVSjNk0kqemrmGU9nhOcgbSOnvBmrk2U7y7/teOaARMN/MtgNtgVn+wdz8zgXAOTfeOedzzvkSExOv7CcQEQlA7YQyzBjcniFdUpiatotfvPolazOPeR2r2AVS+iuAVDOrbWYx5A7Mzvr+QefcMedcgnMu2TmXDCwDejjn0vzH9TWzWDOrDaQCXxX6TyEiEoCYqAh+370Bk+5vw6ns89w5djHjFmzhQhgN8uZb+s65HGAoMBvYAExzzq0zs2fMrEc+564DpgHrgU+Bh51z5wseW0Tk6rVPSeCT4R3p2qAyz36ykXsmLuOb3eFx1W/B9vmTPp/PpaWleR1DRMKAc45pabt49pONHD11jp83qcrj3epRJ7Gs19GumJmlO+d8+R0XVRxhRESCkZlxd6ua3Nq4KhMXbmXil9v49Jt93OWrwfCuqVSpEOd1xEKnK30REb+sE2cZPS+DSct3EGHGwA7JDOmcQsXSMV5Hy1egV/oqfRGRS+w6fIqXP9/MB1/vpmxsFIM7p3Bvh2RKxwTvmyMqfRGRAtq07wQvzt7E5xv2k1A2luFd63J3q5rERAXfvSoDLf3gSy4iEiTqVynHxAE+Zg5pT53EMvznP9dx08gF/OPr3SV2mqdKX0QkHy1rxTN1UFvevrcVZWOjGDF1Fbe9uoi5G/cTbO+W5EelLyISADOjS/1r+fCR63m1X3NOnzvPb99O465xS1mx/bDX8QKm0hcRuQIREUaPptX4/LHO/M/tjdhx6BR9Xl/Kb99ewYa9x72Oly8N5IqIFMDp7PO8vWQ7Y+dncOJsDj2bVuOxbvWpWal0sebQ7B0RkWJ07NQ5xi3cwpuLt5Fz3tGvdU0e6VqXa8sVzwIvlb6IiAcOHD/Dq3O/ZcpXu4iOjOC31yczqFMKFUpFF+nzqvRFRDy0/eB3jJyzmVmr91ChVDRDuqQwoF0ypWKK5hNjVfoiIkFg3Z5jvDR7E/M2ZVG5fCzDu9ajjy+J6MjCnUejxVkiIkGgYbUKvHVva6YOaktSfGme/mAtN7+8kH+t3uPJAi+VvohIMWhTpxIzBrdj4m98xERG8Mh7X/PLv37Jgs1ZxbrAS6UvIlJMzIybrqvMx8M78vLdTTl2+hwD3vyKfhOWsXLnkWLJoNIXESlmkRHGHc2TmPt4F/67R0MyDpzkzjFLeHjSyiK/6g/e+4SKiIS4mKgIBrRPpnfLJN5avI3T585jZkX6nCp9ERGPlYmNYuiNqcXyXHp7R0QkjKj0RUTCiEpfRCSMqPRFRMKISl9EJIyo9EVEwohKX0QkjKj0RUTCSNDdWtnMsoAdBfgWCcDBQopT0um1uJhej4vp9fi3UHgtajnnEvM7KOhKv6DMLC2Qe0qHA70WF9PrcTG9Hv8WTq+F3t4REQkjKn0RkTASiqU/3usAQUSvxcX0elxMr8e/hc1rEXLv6YuIyI8LxSt9ERH5ESFT+mbW3cw2mVmGmT3ldR4vmVkNM5tnZuvNbJ2ZDfc6k9fMLNLMvjazD73O4jUzq2hmM8xso5ltMLN2Xmfykpk96v9z8o2ZvWdmcV5nKkohUfpmFgl27bA1AAACK0lEQVSMBm4FrgP6mdl13qbyVA7wuHPuOqAt8HCYvx4Aw4ENXocIEq8AnzrnGgBNCePXxcyqA8MAn3OuERAJ9PU2VdEKidIHWgMZzrmtzrlsYArQ0+NMnnHO7XXOrfT/+gS5f6ire5vKO2aWBPwcmOh1Fq+ZWQWgE/AGgHMu2zl31NtUnosCSplZFFAa2ONxniIVKqVfHdiVZzuTMC65vMwsGWgOLPc2iadGAU8CF7wOEgRqA1nAW/63uyaaWRmvQ3nFObcbeAnYCewFjjnnPvM2VdEKldKXyzCzssBMYIRz7rjXebxgZr8ADjjn0r3OEiSigBbAWOdcc+A7IGzHwMwsntx3BWoD1YAyZvYrb1MVrVAp/d1AjTzbSf59YcvMoskt/EnOufe9zuOhDkAPM9tO7tt+N5rZ372N5KlMINM59/2//GaQ+5dAuLoJ2Oacy3LOnQPeB9p7nKlIhUrprwBSzay2mcWQOxAzy+NMnjEzI/c92w3OuZFe5/GSc+4Pzrkk51wyub8v5jrnQvpK7qc45/YBu8ysvn9XV2C9h5G8thNoa2al/X9uuhLiA9tRXgcoDM65HDMbCswmd/T9TefcOo9jeakD8GtgrZmt8u972jn3sYeZJHg8AkzyXyBtBe71OI9nnHPLzWwGsJLcWW9fE+Krc7UiV0QkjITK2zsiIhIAlb6ISBhR6YuIhBGVvohIGFHpi4iEEZW+iEgYUemLiIQRlb6ISBj5/2H04vKsXcqoAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(tr_loss_hist, label = 'train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "yhat = char_bi_gru.predict(sess = sess, X_length = X_length, X_indices = X_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training acc: 100.00%\n"
     ]
    }
   ],
   "source": [
    "print('training acc: {:.2%}'.format(np.mean(yhat == np.argmax(y, axis = -1))))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}