{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 11 : Recurrent Neural Networks\n",
    "Simple example for Many to One Classification (word sentiment classification) by Bi-directional Long Short-Term Memory.\n",
    "\n",
    "### Many to One Classification by Bi-directional LSTM\n",
    "- Creating the **data pipeline** with `tf.data`\n",
    "- Preprocessing word sequences (variable input sequence length) using `padding technique` by `user function (pad_seq)`\n",
    "- Using `tf.nn.embedding_lookup` for getting vector of tokens (eg. word, character)\n",
    "- Creating the model as **Class**\n",
    "- Reference\n",
    "    - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n",
    "    - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n",
    "    - https://pozalabs.github.io/blstm/\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.8.0\n"
     ]
    }
   ],
   "source": [
    "import os, sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import string\n",
    "%matplotlib inline\n",
    "\n",
    "slim = tf.contrib.slim\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare example data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy corpus: six short lower-case phrases, one sentiment label per phrase.\n",
    "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n",
    "# One-hot labels aligned with `words`: the positive-sounding phrases carry [1., 0.],\n",
    "# the negative-sounding ones [0., 1.].\n",
    "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'abcdefghijklmnopqrstuvwxyz *'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Character vocabulary: the 26 lower-case letters, then a space, then the '*' pad symbol\n",
    "char_space = ''.join([string.ascii_lowercase, ' ', '*'])\n",
    "char_space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n"
     ]
    }
   ],
   "source": [
    "# Map every vocabulary character to its position in char_space\n",
    "char_dic = dict(zip(char_space, range(len(char_space))))\n",
    "print(char_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create pad_seq function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_seq(sequences, max_len, dic):\n",
    "    '''Encode strings as fixed-width index lists and report their true lengths.\n",
    "\n",
    "    Args:\n",
    "        sequences: iterable of strings (or other sliceable sequences of tokens).\n",
    "        max_len: width of every returned index list.\n",
    "        dic: dict mapping a token to its integer index; the entry for '*' is\n",
    "            used as the padding index.\n",
    "\n",
    "    Returns:\n",
    "        (seq_len, seq_indices). seq_len[i] is the number of real tokens kept\n",
    "        for sequences[i] (at most max_len); seq_indices[i] holds exactly\n",
    "        max_len indices, padded on the right with the index of '*'.\n",
    "    '''\n",
    "    pad_idx = dic.get('*') # idx of the meaningless token '*'\n",
    "    seq_len, seq_indices = [], []\n",
    "    for seq in sequences:\n",
    "        # truncate over-long inputs so every row has the same width and the\n",
    "        # reported length never exceeds the padded width\n",
    "        seq = seq[:max_len]\n",
    "        # tokens missing from dic fall back to the pad index instead of None\n",
    "        seq_idx = [dic.get(char, pad_idx) for char in seq]\n",
    "        seq_len.append(len(seq_idx))\n",
    "        seq_idx += (max_len - len(seq_idx)) * [pad_idx]\n",
    "        seq_indices.append(seq_idx)\n",
    "    return seq_len, seq_indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Apply pad_seq function to data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# every phrase is padded out to this many character indices\n",
    "max_length = 10\n",
    "X_length, X_indices = pad_seq(words, max_length, char_dic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4, 3, 7, 7, 9, 7]\n",
      "(6, 10)\n"
     ]
    }
   ],
   "source": [
    "# true character count of each phrase, then the shape of the padded index matrix\n",
    "print(X_length)\n",
    "print(np.shape(X_indices))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define CharBiLSTM class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharBiLSTM:\n",
    "    '''Character-level bidirectional LSTM for many-to-one sequence classification.\n",
    "\n",
    "    The constructor builds the graph ops: a fixed one-hot embedding, a\n",
    "    bidirectional LSTM, a linear output layer, the loss and the prediction op.\n",
    "    '''\n",
    "    def __init__(self, X_length, X_indices, y, n_of_classes, hidden_dim, dic):\n",
    "        '''Build the model graph.\n",
    "\n",
    "        Args:\n",
    "            X_length: true (unpadded) length of each sequence; passed to the RNN\n",
    "                as sequence_length.\n",
    "            X_indices: padded integer character indices, looked up in the\n",
    "                one-hot table.\n",
    "            y: one-hot labels, consumed by the softmax cross-entropy loss.\n",
    "            n_of_classes: width of the output (score) layer.\n",
    "            hidden_dim: number of units in each of the two LSTM cells.\n",
    "            dic: character-to-index dict; only its size is used here.\n",
    "        '''\n",
    "        \n",
    "        # data pipeline\n",
    "        with tf.variable_scope('input_layer'):\n",
    "            self._X_length = X_length\n",
    "            self._X_indices = X_indices\n",
    "            self._y = y\n",
    "            \n",
    "            # identity matrix: row i is the one-hot vector of character index i\n",
    "            one_hot = tf.eye(len(dic), dtype = tf.float32)\n",
    "            self._one_hot = tf.get_variable(name='one_hot_embedding', initializer = one_hot,\n",
    "                                            trainable = False) # not trainable: the embedding vectors are kept fixed\n",
    "            self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n",
    "        \n",
    "        # Bi-directional LSTM\n",
    "        with tf.variable_scope('bi-directional_lstm'):\n",
    "            lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(num_units = hidden_dim, activation = tf.nn.tanh)\n",
    "            lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(num_units = hidden_dim, activation = tf.nn.tanh)\n",
    "            # the per-step outputs are discarded; only the final states are used\n",
    "            _, output_states = tf.nn.bidirectional_dynamic_rnn(cell_fw = lstm_fw_cell,\n",
    "                                                               cell_bw = lstm_bw_cell,\n",
    "                                                               inputs = self._X_batch,\n",
    "                                                               sequence_length = self._X_length,\n",
    "                                                               dtype = tf.float32)\n",
    "\n",
    "            # join the final hidden state (.h) of both directions along the feature axis\n",
    "            final_state = tf.concat([output_states[0].h, output_states[1].h], axis = 1)\n",
    "\n",
    "        with tf.variable_scope('output_layer'):\n",
    "            # linear layer (no activation) producing one score per class\n",
    "            self._score = slim.fully_connected(inputs = final_state, num_outputs = n_of_classes,\n",
    "                                               activation_fn = None)\n",
    "            \n",
    "        with tf.variable_scope('loss'):\n",
    "            self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n",
    "            \n",
    "        with tf.variable_scope('prediction'):\n",
    "            # predicted class id = position of the largest score\n",
    "            self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n",
    "    \n",
    "    def predict(self, sess, X_length, X_indices):\n",
    "        '''Return the predicted class index for each given sequence.\n",
    "\n",
    "        Args:\n",
    "            sess: session used to run the prediction op.\n",
    "            X_length: values fed for the tensor stored as self._X_length.\n",
    "            X_indices: values fed for the tensor stored as self._X_indices.\n",
    "\n",
    "        Returns:\n",
    "            The result of running the argmax (int32) prediction op.\n",
    "        '''\n",
    "        feed_prediction = {self._X_length : X_length, self._X_indices : X_indices}\n",
    "        return sess.run(self._prediction, feed_dict = feed_prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a model of CharBiLSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    }
   ],
   "source": [
    "# hyper-parameters\n",
    "lr = .003\n",
    "epochs = 10\n",
    "batch_size = 2\n",
    "# mini-batches per epoch: the tf.data pipeline batches without dropping the\n",
    "# final partial batch, so round up instead of flooring\n",
    "total_step = int(np.ceil(np.shape(X_indices)[0] / batch_size))\n",
    "print(total_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n"
     ]
    }
   ],
   "source": [
    "## create data pipeline with tf.data\n",
    "# slice the three aligned lists per example, shuffle, then group into mini-batches\n",
    "tr_dataset = (tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n",
    "              .shuffle(buffer_size = 20)\n",
    "              .batch(batch_size = batch_size))\n",
    "tr_iterator = tr_dataset.make_initializable_iterator()\n",
    "print(tr_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# symbolic tensors for one mini-batch (lengths, padded indices, labels) drawn from tr_iterator\n",
    "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the graph on the mini-batch tensors: 2 output classes, 16 units per LSTM direction\n",
    "char_bi_lstm = CharBiLSTM(X_length = X_length_mb, X_indices = X_indices_mb, y = y_mb,\n",
    "                          n_of_classes = 2, hidden_dim = 16, dic = char_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create training op and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "## create training op\n",
    "# Adam with step size lr, minimizing the model's softmax cross-entropy loss\n",
    "opt = tf.train.AdamOptimizer(learning_rate = lr)\n",
    "training_op = opt.minimize(loss = char_bi_lstm.ce_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   1, tr_loss : 0.726\n",
      "epoch :   2, tr_loss : 0.679\n",
      "epoch :   3, tr_loss : 0.643\n",
      "epoch :   4, tr_loss : 0.618\n",
      "epoch :   5, tr_loss : 0.583\n",
      "epoch :   6, tr_loss : 0.554\n",
      "epoch :   7, tr_loss : 0.518\n",
      "epoch :   8, tr_loss : 0.489\n",
      "epoch :   9, tr_loss : 0.457\n",
      "epoch :  10, tr_loss : 0.422\n"
     ]
    }
   ],
   "source": [
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "tr_loss_hist = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    # rewind the iterator so this epoch makes a fresh pass over the data\n",
    "    sess.run(tr_iterator.initializer)\n",
    "    avg_tr_loss, tr_step = 0, 0\n",
    "\n",
    "    while True:\n",
    "        try:\n",
    "            _, tr_loss = sess.run(fetches = [training_op, char_bi_lstm.ce_loss])\n",
    "        except tf.errors.OutOfRangeError:\n",
    "            break # iterator exhausted: the epoch is over\n",
    "        avg_tr_loss += tr_loss\n",
    "        tr_step += 1\n",
    "\n",
    "    # mean mini-batch loss of the epoch\n",
    "    avg_tr_loss /= tr_step\n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "\n",
    "    print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x11b42a978>]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xd8VGW+x/HPLwlJCDVAkN4kSFNaQGoQkKYuoCKCDSuiAlLWe3F3b1nd4l4VQURQUVZdaQIqgoBKC1USuvTQAygx9BrKc//IuBtZMAGSnMnM9/16zUvOmXPIN/OSb06eOfM85pxDRESCQ4jXAUREJO+o9EVEgohKX0QkiKj0RUSCiEpfRCSIqPRFRIKISl9EJIio9EVEgohKX0QkiIR5HeBSpUqVclWqVPE6hohIvrJy5cqfnHMxWR3nd6VfpUoVkpKSvI4hIpKvmNnu7Byn4R0RkSCi0hcRCSIqfRGRIKLSFxEJIip9EZEgotIXEQkiKn0RkSASMKV/4aLjL19tIuXwKa+jiIj4rYAp/d1pJ5m4Yg/3jl7K1h+Pex1HRMQvBUzpV4spzOS+zXAO7huzjJW7D3kdSUTE7wRM6QPULFOUqc80p0ShcB4c+x3zNx/0OpKIiF8JqNIHqFgiik/7NqN66cI8+VES01aleB1JRMRvBFzpA5QqHMGEp5pya9USDJ68lrGLdngdSUTELwRk6QMUiSzAuMcac8fNZfjTzE28MmszzjmvY4mIeMrvplbOSRFhoYzs1ZDiUd8zZuF2Dp08y1/uvpmw0ID9WSci8qsCuvQBQkOMP3erS6nCEbw5dxuHTp7jrQcaEFkg1OtoIiJ5Liguec2Mwe1r8McudZi7+UceeX8FR0+f8zqWiEieC4rS/1nv5lUY0bMBq/ce5v53lnHw2BmvI4mI5KmgKn2ALvXK8X7vxuw5dIruY5axO+2k15FERPJM0JU+QHyNGMY/1ZTjZ85x7+hlfL/vqNeRRETyRFCWPkD9isX5tG9zwkONXu8uZ9n2NK8jiYjkuqAtfYDqpQsz9dnm3FAskt7jVjD7+x+8jiQikquyVfpm1snMtphZspkNvczzb5jZGt9jq5kdyfRcbzPb5nv0zsnwOaFssYJ8+nQz6pQryrOfrGTiij1eRxIRyTVZlr6ZhQKjgM5AbaCXmdXOfIxzbpBzrr5zrj4wEpjmO7cE8D/ArUAT4H/MLDpnv4XrF10onE+evJVWsTEMnbaeUfOT9eldEQlI2bnSbwIkO+d2OOfSgYlA1185vhcwwffnjsA3zrlDzrnDwDdAp+sJnFuiwsMY2zuObvXL8eqcLbw0YyMXL6r4RSSwZOcTueWBvZm2U8i4cv83ZlYZqArM+5Vzy1/mvD5AH4BKlSplI1LuKBAawrAe9YkuFM64Jbs4dDKdV7vXIzwsqN/6EJEAktNt1hOY4py7cDUnOefedc7FOefiYmJicjjS1QkJMf77rtq80PEmvlizn6c+SuJU+nlPM4mI5JTslP4+oGKm7Qq+fZfTk38N7VztuX7DzHiuTXVeuedmFm1L5YH3vuPwyXSvY4mIXLfslH4iEGtmVc0snIxin37pQWZWE4gGlmXaPQfoYGbRvjdwO/j25Qs9m1Ti7QcbsfHAMe57Zxn7j5z2OpKIyHXJsvSdc+eBfmSU9SZgsnNug5m9ZGZdMh3aE5joMt324pw7BLxMxg+OROAl3758o1PdMnz4WBN+PHqG7qOXknzwhNeRRESumfnbrYlxcXEuKSnJ6xj/ZsP+o/T+IJELFy8y7rEm1K9Y3OtIIiL/ZGYrnXNxWR2n21KyqU65Ykx9phlFIgvwwHvLSdia6nUkEZGrptK/CpVLFmJK32ZULlmIJz5MZPra/V5HEhG5Kir9q1S6aCQT+zSlQaVonp+4mr8v2el1JBGRbFPpX4NiBQvw0eNNuL3WDfzvlxsZ9vUWTdsgIvmC
Sv8aRRYIZfSDDekRV4E35yXz+8+/54KmbRARPxfwC6PnprDQEP527y2ULBzB6AXbOXwyneE96xMRpkXXRcQ/6Ur/OpkZ/9mpJn+4sxazvv+Bx8YlcvyMFl0XEf+k0s8hT7aqxrAe9fhu5yF6vbeclMOnvI4kIvJvVPo56J6GFRj7SBw7U0/SafgiJiXu0Ru8IuJXVPo5rE3N0sweGE+dckX5z6nrefLDJA4eO+N1LBERQKWfKyqWiGLCU035r7tqszj5JzoMT2DGOn2QS0S8p9LPJSEhxhMtqzJzQCsql4ii3/jV9J+wWlM0i4inVPq5rHrpwkx9pjlD2tdg1voDdBiewPzNB72OJSJBSqWfB8JCQ+jfLpbPn2tBiahwHvt7IkOnruPEWa3IJSJ5S6Wfh+qWL8b0/i3o2/pGJiftpdPwBJbvSPM6logEEZV+HosIC2Vo55p82rcZoSFGr/eW8/KMjZw5d1XLCouIXBOVvkcaVS7BrOdb8dCtlXl/8U7ufHMRa/ce8TqWiAQ4lb6HosLDeLlbXT5+ogmn0i9wz+ilDPt6C+nnL3odTUQClErfD7SKjWH2wHi61i/Hm/OSufvtJWz54bjXsUQkAKn0/USxggUY1qM+7zzciB+OnuE3IxczZuF2TdcsIjlKpe9nOtYpw5xB8bSpGcMrszZz/zvL2PXTSa9jiUiAUOn7oVKFIxjzUCPeuL8eW348TucRi/h4+W5N3iYi102l76fMjLsbVODrQfHEVYnmvz7/nkc+WMGBo6e9jiYi+Vi2St/MOpnZFjNLNrOhVzimh5ltNLMNZjY+0/4LZrbG95ieU8GDRdliBfno8Sb8qVtdknYdpsMbCUxblaKrfhG5JpZVeZhZKLAVaA+kAIlAL+fcxkzHxAKTgbbOucNmVto5d9D33AnnXOHsBoqLi3NJSUlX/50Egd1pJxkyeS1Juw/Tsc4N/PnumylVOMLrWCLiB8xspXMuLqvjsnOl3wRIds7tcM6lAxOBrpcc8xQwyjl3GODnwpecVblkISY93YwXO9dk/uZUOr6RwOzvf/A6lojkI9kp/fLA3kzbKb59mdUAapjZEjNbbmadMj0XaWZJvv3drjNv0AsNMZ5ufSMzBrSkbPFI+v5jJYMnreHoaa3LKyJZy6k3csOAWOA2oBfwnpkV9z1X2fcrxwPAcDO78dKTzayP7wdDUmpqag5FCmw1bijCZ8+2YEC7WL5Yu59OwxNYtE2vnYj8uuyU/j6gYqbtCr59maUA051z55xzO8l4DyAWwDm3z/ffHcACoMGlX8A5965zLs45FxcTE3PV30SwKhAawuD2NZj2THOiwkN5+P0V/OHz9ZxK15TNInJ52Sn9RCDWzKqaWTjQE7j0LpzPybjKx8xKkTHcs8PMos0sItP+FsBGJEfVq1icmQNa8WTLqnzy3R46j1hE0q5DXscSET+UZek7584D/YA5wCZgsnNug5m9ZGZdfIfNAdLMbCMwH3jBOZcG1AKSzGytb/8rme/6kZwTWSCUP9xVmwlPNeXCRcd97yzjzzM3cuSUlmcUkX/J8pbNvKZbNq/fibPn+fPMjUxYsZfCEWE80qwyT7SsSknd3ikSsLJ7y6ZKP4BtOnCMt+Yn89X6A0SGhfJQ00o8FV+N0kUivY4mIjlMpS//lHzwOKPmb+eLNfsoEBpCryaVeLp1NcoWK+h1NBHJISp9+Tc7fzrJ2/OT+Wz1PkLMuC+uAs/cdiMVoqO8jiYi10mlL1e099ApRi/czqdJe3EO7mlYnufaVKdyyUJeRxORa6TSlyztP3KadxZuZ0LiXi5cdHStV45n21SneulsT5UkIn5CpS/ZdvDYGd5N2MEn3+3hzPkL3HlzWfq3jeWmMkW8jiYi2aTSl6uWduIsYxfv5KOluziZfoFOdcrQr2116pYv5nU0EcmCSl+u2eGT6YxbspNxS3dx/Mx52tUsTf92sdSvWDzrk0XEEyp9uW5HT5/jo6W7eH/JTo6c
Oker2FIMaBdL4yolvI4mIpdQ6UuOOXH2PB8v283YRTtIO5lOs2ol6d+uOs2qlcTMvI4nIqj0JRecSj/P+O/28E7CDlKPnyWucjT928USH1tK5S/iMZW+5Joz5y4wKXEvYxZu58DRM9SrWJwBbavTtmZplb+IR1T6kuvOnr/A1JX7eHtBMimHT1OnXFH6t61Oh9plCAlR+YvkJZW+5JlzFy7y2ep9vD0/mV1pp7jphiL0a1udO24uS6jKXyRPqPQlz52/cJEZ6w7w1vxkkg+eoFpMIfq1qU6XeuUIC82plTlF5HKyW/r6lyg5Jiw0hG4NyjNnYDyjHmhIeGgIgyevpePwBFbu1kpeIv5ApS85LjTEuPOWsnw1oBVjHmrEmXMX6T5mGX/8coPW7xXxmEpfck1IiNGpbhnmDIrn4aaVGbdkFx2HJ7Ak+Sevo4kELZW+5LrCEWG81LUuk/o0JdSMB8d+x4vT1nHszDmvo4kEHZW+5Jlbq5Vk9sB4no6vxqTEvXQYlsC8zT96HUskqKj0JU9FFgjlxTtqMe3ZFhQtGMbjf09i0KQ1HD6Z7nU0kaCg0hdP1K9YnC/7t2RAu1i+XLuf9m8s5Kv1B7yOJRLwVPrimYiwUAa3r8H0fi0pUyySZz9ZRd+PV3Lw+Bmvo4kELJW+eK52uaJ8/mwL/qPTTczbcpD2wxKYujIFf/vgoEggyFbpm1knM9tiZslmNvQKx/Qws41mtsHMxmfa39vMtvkevXMquASWsNAQnr2tOl8NaEX10oUZ8ulaHvt7IvuPnPY6mkhAyXIaBjMLBbYC7YEUIBHo5ZzbmOmYWGAy0NY5d9jMSjvnDppZCSAJiAMcsBJo5Jw7fKWvp2kY5MJFx0fLdvF/s7cQGmK8eEdNejWupEncRH5FTk7D0ARIds7tcM6lAxOBrpcc8xQw6ucyd84d9O3vCHzjnDvke+4boFN2vwkJTqEhxmMtqjJnYDy3VCjG7z/7ngfGLmd32kmvo4nke9kp/fLA3kzbKb59mdUAapjZEjNbbmadruJczKyPmSWZWVJqamr200tAq1Qyik+evJW/3nMzG/Ydo+PwBMYu2sGFixrrF7lWOfVGbhgQC9wG9ALeM7Nsr6LtnHvXORfnnIuLiYnJoUgSCMyMXk0q8fXgeJrfWIo/zdxE9zFLST543OtoIvlSdkp/H1Ax03YF377MUoDpzrlzzrmdZLwHEJvNc0WyVLZYQd7vHcfw++uz86eT3DFiMaPmJ3PuwkWvo4nkK9kp/UQg1syqmlk40BOYfskxn5NxlY+ZlSJjuGcHMAfoYGbRZhYNdPDtE7lqZka3BuX5ZlBr2te+gVfnbKHrW0vYsP+o19FE8o0sS985dx7oR0ZZbwImO+c2mNlLZtbFd9gcIM3MNgLzgRecc2nOuUPAy2T84EgEXvLtE7lmMUUiGPVgQ8Y81JCDx8/S9a0lvDZnC2fPX/A6mojf08pZkq8dOZXOyzM2MXVVCrGlC/O37rfQsFK017FE8pxWzpKgUDwqnNd71GPcY405efY8945eyp9mbOR0uq76RS5HpS8Boc1NpZkzKJ4HmlRi7OKddBqRwLLtaV7HEvE7Kn0JGEUiC/Dnu29mwlNNAej13nJ+/9l6jmuxFpF/UulLwGl2Y0lmPx/Pky2rMn7FHjq+kcCCLQezPlEkCKj0JSAVDA/lD3fVZuozzYmKCOPRcYk8P3G1JnCToKfSl4DWsFI0MwdkLNYy6/sfaPv6At74Ziun0s97HU3EEyp9CXg/L9Yyd3Brbq91AyPmbqPtawv5fPU+LmoeHwkyKn0JGhVLRPHWAw35tG8zYopEMHDSGu4ZvZRVe64407dIwFHpS9BpXKUEXzzXgtfuq8f+I6e55+2lGu+XoKHSl6AUEmJ0b1SB+b+9jf5tqzPbN94/TOP9EuBU+hLUCkWEMaTDTcwd0pr2tcvw5txttHltAdNWpWi8XwKSSl8EqBAdxcheDZj6TDPKFI1k8OS1
3P32Elbu1vyAElhU+iKZNKpcgs+ebcGwHvX44dgZ7h29jP4TVpNy+JTX0URyhEpf5BIhIcY9DTPG+we0i+XrDT/Q7vWFvP71Fk6e1Xi/5G8qfZEriAoPY3D7Gsz77W10qluGkfOSafPaAqas1Hi/5F8qfZEslC9ekBE9GzD1meaUK16Q3366lm5vLyFxl8b7Jf9R6YtkU6PK0Ux7pjnD76/PwWNnuW/MMp4bv4q9hzTeL/mHSl/kKoSEZKzTO++3rRl4eyxzN/1Iu2ELeXXOZo33S76g0he5BlHhYQy8vQbzf3sbd95cllHzt3Pbawv4NGmvxvvFr6n0Ra5D2WIFeeP++nz2bHMqRBfkhSnr6DJqMSt2arxf/JNKXyQHNKiUMd4/omd9Dp1Ip8c7y3juE433i/9R6YvkEDOja/3yzB1yW8atnpsP0m7YQv5v9mZOaLxf/IRKXySHFQwPZUC7WOb/9jbuuqUsby/Yzm2vLmBy4l4uaLxfPJat0jezTma2xcySzWzoZZ5/1MxSzWyN7/FkpucuZNo/PSfDi/izMsUiGdajPp8/14LKJaP4j6nr6PLWYlZr/n7xkDn361ceZhYKbAXaAylAItDLObcx0zGPAnHOuX6XOf+Ec65wdgPFxcW5pKSk7B4uki8455ix7gB//WoTPx4/y3NtqtO/bXUKhOqXbckZZrbSOReX1XHZ+T+uCZDsnNvhnEsHJgJdrzegSDAxM35TrxyzB8XTrX553py7jXtHL2V76gmvo0mQyU7plwf2ZtpO8e271L1mts7MpphZxUz7I80sycyWm1m36wkrkt8VjSzA6z3qMfrBhuw9dIo731zER8t2kdVv3CI5Jad+t/wSqOKcuwX4Bvgw03OVfb9yPAAMN7MbLz3ZzPr4fjAkpaam5lAkEf/V+eayzBkYT9NqJfnvLzbQe1wiPx4743UsCQLZKf19QOYr9wq+ff/knEtzzp31bY4FGmV6bp/vvzuABUCDS7+Ac+5d51yccy4uJibmqr4BkfyqdNFIxj3amJe71WXFzjQ6Dk9g5roDXseSAJed0k8EYs2sqpmFAz2BX9yFY2ZlM212ATb59kebWYTvz6WAFsBGRATIGOt/uGllvhrQisolonhu/CoGTVrD0dPnvI4mASrL0nfOnQf6AXPIKPPJzrkNZvaSmXXxHTbAzDaY2VpgAPCob38tIMm3fz7wSua7fkQkQ7WYwkx5pjkDb49l+tr9dB6ewLLtaV7HkgCU5S2beU23bEqwW7P3CIMmrWFX2kmebFmVIR1uIrJAqNexxM/l5C2bIpKH6lcszswBLXno1sq8t2gnXd9awsb9x7yOJQFCpS/ih6LCw3i5W13GPdaYQ6fS6TpqMWMWbtc0DnLdVPoifqzNTaWZMzCe22vdwCuzNtPr3eWauVOui0pfxM+VKBTO2w82ZFiPemw6cIzOIxbxadJefaBLrolKXyQfMDPuaViBWQNbUbtcUV6Yso6+/1hJ2omzWZ8skolKXyQfqRAdxYSnmvK7O2oyf3MqHYcvYt7mH72OJfmISl8knwkNMfrE38gX/VpQqnA4j/89id9/tp5T6VqoRbKm0hfJp2qVLcoX/VrwdHw1xq/Ywx0jFmmufsmSSl8kH4sIC+XFO2ox4ammnLvg6D5mGcO+2cq5Cxe9jiZ+SqUvEgCaVivJrIGtNFe/ZEmlLxIgNFe/ZIdKXyTAaK5++TUqfZEApLn65UpU+iIBSnP1y+Wo9EUCnObql8xU+iJBoEBoCANvr8GUvs2IKBDKA2OXM3jyGvYfOe11NMljKn2RINKgUjQzB7SkT3w1Zqw7QJvXFvDKrM0cO6Mhn2ChlbNEglTK4VMM+3orn63ZR/GCBejfNpaHmlYmPEzXgvmRVs4SkV9VITqKYffX58t+LalTrhgvzdjI7cMWMmPdft3bH8BU+iJBrm75Ynz8RBM+fLwJUeGh9Bu/mm5vL2XFzkNeR5NcoNIXEcyM1jVimDmgFa92
v4Ufj56hxzvLePLDJJIPajqHQKIxfRH5N2fOXeD9xTsZvWA7p89d4P7GFRl4eyyli0R6HU2uILtj+ip9EbmitBNnGTkvmX8s3014WAh94qvxVKtqFIoI8zqaXCJH38g1s05mtsXMks1s6GWef9TMUs1sje/xZKbnepvZNt+j99V9GyLipZKFI/jfLnX4dnBr2txUmuHfbuO21xYw/rs9nNf0zflSllf6ZhYKbAXaAylAItDLObcx0zGPAnHOuX6XnFsCSALiAAesBBo556640oOu9EX816o9h/nrV5tI3HWY6qULM7RTTdrVKo2ZeR0t6OXklX4TINk5t8M5lw5MBLpmM0dH4Bvn3CFf0X8DdMrmuSLiZxpWimby08149+FGXHSOJz9K4v53l7Nm7xGvo0k2Zaf0ywN7M22n+PZd6l4zW2dmU8ys4lWeKyL5hJnRoU4Z5gyM50/d6rIj9QTdRi2h3/hV7E476XU8yUJO3bL5JVDFOXcLGVfzH17NyWbWx8ySzCwpNTU1hyKJSG4qEBrCQ00rs+CFNgxoF8vcTQe5fdhC/vjlBg6fTPc6nlxBdkp/H1Ax03YF375/cs6lOefO+jbHAo2ye67v/Hedc3HOubiYmJjsZhcRP1A4IozB7Wuw8IXb6N6oAh8u3UX8q/MZvWA7Z85d8DqeXCI7pZ8IxJpZVTMLB3oC0zMfYGZlM212ATb5/jwH6GBm0WYWDXTw7RORAFO6aCR/vecW5gyM59aqJfjb7M20fW0BU1emcPGif90aHsyyLH3n3HmgHxllvQmY7JzbYGYvmVkX32EDzGyDma0FBgCP+s49BLxMxg+OROAl3z4RCVCxNxRhbO/GTHiqKaWKRDDk07XcOXIxi7Zp6NYf6MNZIpJrLl50zFh/gFfnbGbvodO0ii3Fi51rUbtcUa+jBRzNsikingsJMbrUK8e3g1vzhztrsS7lKHeOXMSQyWu1gItHdKUvInnm6KlzvL0wmXFLdmHA4y2r0q9NdU3rkAN0pS8ifqdYVAFe7FyLeUNac+fNZRm9YDu3D1vIrPUHNId/HlHpi0ie+3kBl2nPNqd4VDjPfLKKR8cl6sNdeUClLyKeaVgpmi/7teC/76rNyt2Haf9GAiO+3ab7+3ORSl9EPBUWGsLjLasyd0hrOtYpwxvfbqXziEUkbNUtnrlBpS8ifuGGopGM7NWAj59oAsAjH6zgufGr+OHoGY+TBRaVvoj4lVaxMcwe2Ioh7Wvw7cYfaff6AsYu2qH5+3OISl9E/E5EWCj928XyzaDWNK5agj/N3MRdIxezcrc+0H+9VPoi4rcqlYxi3KONGfNQI46ePse9o5cxdOo6zeJ5HVT6IuLXzIxOdcvw7eDWPB1fjSkrU2j7+gImJe7RRG7XQKUvIvlCoYgwXryjFjMHtCK2dBH+c+p67ntnGZsOHPM6Wr6i0heRfOWmMkWY9HRTXruvHjt/OsldIxfz8oyNnDh73uto+YJKX0TyHTOje6MKzBvSmvsbV+SDJTtp9/oCZqzbr+kcsqDSF5F8q3hUOH+5+2amPdOcUoUj6Dd+NY98sIKdP2k6hytR6YtIvtegUjTT+7Xkj13qsGbPETq+kcCwb7ZqOofLUOmLSEAIDTF6N6/C3CGt6XxzGd6cu42OwxNYsOWg19H8ikpfRAJK6aKRjOjZgE+evJXQEOPRcYk884+VHDiqRVtApS8iAapF9VLMer4VL3S8iXmbD9Lu9YW8l7CDc0E+nYNKX0QCVkRYKM+1qc63g1vTrFpJ/vzVJn4zcjGJu4J3OgeVvogEvIolohjbO453H27E8TPnuW/MMl74dC1pJ856HS3PqfRFJCiYGR3qlOGbwfH0bX0jn63eR9vXFzL+u+CazkGlLyJBJSo8jKGdazLr+VbULFOE3322nu5jlrI99YTX0fKESl9EglLsDUWY2Kcpw3rUY8dPJ7ljxCI+WLwz4K/6s1X6ZtbJzLaYWbKZDf2V4+41M2dmcb7t
KmZ22szW+B5jciq4iMj1MjPuaViBrwfG06J6KV6asZFe7y1n76FTXkfLNVmWvpmFAqOAzkBtoJeZ1b7McUWA54HvLnlqu3Ouvu/RNwcyi4jkqNJFI3m/dxz/1/0WNuw/RqfhCUxYsScg5/HJzpV+EyDZObfDOZcOTAS6Xua4l4G/AVrQUkTyHTOjR1xFZg9sRb2KxXlx2noe+3tiwK3Rm53SLw/szbSd4tv3T2bWEKjonJt5mfOrmtlqM1toZq2uPaqISO6rEB3FP564lT92qcPyHWl0eGMhn6/eFzBX/df9Rq6ZhQDDgCGXefoAUMk51wAYDIw3s6KX+Tv6mFmSmSWlpqZebyQRkesS4pvHZ9bz8VQvXZiBk9bw7CerAuK+/uyU/j6gYqbtCr59PysC1AUWmNkuoCkw3czinHNnnXNpAM65lcB2oMalX8A5965zLs45FxcTE3Nt34mISA6rWqoQn/ZtztDONZm76SAd3khgzoYfvI51XbJT+olArJlVNbNwoCcw/ecnnXNHnXOlnHNVnHNVgOVAF+dckpnF+N4IxsyqAbHAjhz/LkREckloiNG39Y182b8lZYpF8vTHKxk8aQ1HT53zOto1ybL0nXPngX7AHGATMNk5t8HMXjKzLlmcHg+sM7M1wBSgr3MueCe9EJF866YyRfj8uRY83y6WL9bup+PwBBZuzX/D0eZvb07ExcW5pKQkr2OIiFzRupQjDJm8lm0HT/DArZX43R21KBwR5mkmM1vpnIvL6jh9IldE5CrdUqE4X/ZvSZ/4akxYsYfOIxL4bkea17GyRaUvInINIguE8rs7ajH56WaEmNHzveW8PGOj3y/RqNIXEbkOjauUYNbzrXjo1sq8v3gnd765iDV7j3gd64pU+iIi1ykqPIyXu9Xl4yeacCr9AveOXsrrX28h/bz/rdKl0hcRySGtYmOYMyieuxuUZ+S8ZLqOWsKmA8e8jvULKn0RkRxUNLIAr91Xj/ceiSP1+Fm6vLWYUfOTOe8na/Oq9EVEckH72jfw9aB4OtQuw6tzttB9zDK/WKhFpS8ikktKFApn1IMNGdmrAbvS/GOhFpW+iEgu+029cr9YqOWBsd4t1KLSFxHJA5kXavl+n3cLtaj0RUTyyJWuBcNFAAADdUlEQVQWavnxWN4t1KLSFxHJY/++UEsCX6zJm4VaVPoiIh7IvFDLjTGFeH7iGvqNX53rb/J6Oy2ciEiQ+3mhlvcW7eDEmfOEhFiufj2VvoiIx35eqCUvaHhHRCSIqPRFRIKISl9EJIio9EVEgohKX0QkiKj0RUSCiEpfRCSIqPRFRIKI5fUMb1kxs1Rg93X8FaWAn3IoTn6n1+KX9Hr8kl6PfwmE16Kycy4mq4P8rvSvl5klOefivM7hD/Ra/JJej1/S6/EvwfRaaHhHRCSIqPRFRIJIIJb+u14H8CN6LX5Jr8cv6fX4l6B5LQJuTF9ERK4sEK/0RUTkCgKm9M2sk5ltMbNkMxvqdR4vmVlFM5tvZhvNbIOZPe91Jq+ZWaiZrTazGV5n8ZqZFTezKWa22cw2mVkzrzN5ycwG+f6dfG9mE8ws0utMuSkgSt/MQoFRQGegNtDLzGp7m8pT54EhzrnaQFPguSB/PQCeBzZ5HcJPjABmO+dqAvUI4tfFzMoDA4A451xdIBTo6W2q3BUQpQ80AZKdczucc+nARKCrx5k845w74Jxb5fvzcTL+UZf3NpV3zKwCcCcw1ussXjOzYkA88D6Acy7dOXfE21SeCwMKmlkYEAXs9zhPrgqU0i8P7M20nUIQl1xmZlYFaAB8520STw0H/gO46HUQP1AVSAXG+Ya7xppZIa9DecU5tw94DdgDHACOOue+9jZV7gqU0pfLMLPCwFRgoHPumNd5vGBmdwEHnXMrvc7iJ8KAhsBo51wD4CQQtO+BmVk0GaMCVYFyQCEze8jbVLkrUEp/H1Ax03YF376gZWYFyCj8T5xz07zO46EWQBcz20XGsF9b
M/uHt5E8lQKkOOd+/s1vChk/BILV7cBO51yqc+4cMA1o7nGmXBUopZ8IxJpZVTMLJ+ONmOkeZ/KMmRkZY7abnHPDvM7jJefci865Cs65KmT8fzHPORfQV3K/xjn3A7DXzG7y7WoHbPQwktf2AE3NLMr376YdAf7GdpjXAXKCc+68mfUD5pDx7vsHzrkNHsfyUgvgYWC9ma3x7fudc+4rDzOJ/+gPfOK7QNoBPOZxHs84574zsynAKjLueltNgH86V5/IFREJIoEyvCMiItmg0hcRCSIqfRGRIKLSFxEJIip9EZEgotIXEQkiKn0RkSCi0hcRCSL/D00DuDqE8HMFAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# average training loss per epoch; the legend makes the 'train' label visible\n",
    "plt.plot(tr_loss_hist, label = 'train')\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('average cross-entropy loss')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# feed the complete padded training set in place of the iterator's mini-batch tensors\n",
    "yhat = char_bi_lstm.predict(sess = sess, X_length = X_length, X_indices = X_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training acc: 83.33%\n"
     ]
    }
   ],
   "source": [
    "# compare predicted class ids with the label ids recovered from the one-hot rows of y\n",
    "y_true = np.argmax(y, axis = -1)\n",
    "tr_acc = np.mean(yhat == y_true)\n",
    "print('training acc: {:.2%}'.format(tr_acc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}