{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 11 : Recurrent Neural Networks\n",
    "Simple example for Many to One Classification (word sentiment classification) by Long Short-Term Memory. \n",
    "\n",
    "### Many to One Classification by LSTM\n",
    "- Creating the **data pipeline** with `tf.data`\n",
    "- Preprocessing word sequences (variable input sequence length) using `padding technique` by `user function (pad_seq)`\n",
    "- Using `tf.nn.embedding_lookup` for getting vector of tokens (eg. word, character)\n",
    "- Creating the model as **Class**\n",
    "- Reference\n",
    "    - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n",
    "    - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.8.0\n"
     ]
    }
   ],
   "source": [
    "import os, sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import string\n",
    "%matplotlib inline\n",
    "\n",
    "slim = tf.contrib.slim\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare example data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n",
    "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'abcdefghijklmnopqrstuvwxyz *'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Character quantization\n",
    "char_space = string.ascii_lowercase \n",
    "char_space = char_space + ' ' + '*'\n",
    "char_space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n"
     ]
    }
   ],
   "source": [
    "char_dic = {char : idx for idx, char in enumerate(char_space)}\n",
    "print(char_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create pad_seq function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_seq(sequences, max_len, dic):\n",
    "    seq_len, seq_indices = [], []\n",
    "    for seq in sequences:\n",
    "        seq_len.append(len(seq))\n",
    "        seq_idx = [dic.get(char) for char in seq]\n",
    "        seq_idx += (max_len - len(seq_idx)) * [dic.get('*')] # 27 is idx of meaningless token \"*\"\n",
    "        seq_indices.append(seq_idx)\n",
    "    return seq_len, seq_indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Apply pad_seq function to data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_length = 10\n",
    "X_length, X_indices = pad_seq(sequences = words, max_len = max_length, dic = char_dic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4, 3, 7, 7, 9, 7]\n",
      "(6, 10)\n"
     ]
    }
   ],
   "source": [
    "print(X_length)\n",
    "print(np.shape(X_indices))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define CharLSTM class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharLSTM:\n",
    "    def __init__(self, X_length, X_indices, y, n_of_classes, hidden_dim, dic):\n",
    "        \n",
    "        # data pipeline\n",
    "        with tf.variable_scope('input_layer'):\n",
    "            self._X_length = X_length\n",
    "            self._X_indices = X_indices\n",
    "            self._y = y\n",
    "            \n",
    "            one_hot = tf.eye(len(dic), dtype = tf.float32)\n",
    "            self._one_hot = tf.get_variable(name='one_hot_embedding', initializer = one_hot,\n",
    "                                            trainable = False) # embedding vector training 안할 것이기 때문\n",
    "            self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n",
    "        \n",
    "        # LSTM cell\n",
    "        with tf.variable_scope('lstm_cell'):\n",
    "            lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units = hidden_dim, activation = tf.nn.tanh)\n",
    "            outputs, states = tf.nn.dynamic_rnn(cell = lstm_cell, inputs = self._X_batch,\n",
    "                                         sequence_length = self._X_length, dtype = tf.float32)\n",
    "            \n",
    "        with tf.variable_scope('output_layer'):\n",
    "            self._score = slim.fully_connected(inputs = states.h, num_outputs = n_of_classes, activation_fn = None)\n",
    "            \n",
    "        with tf.variable_scope('loss'):\n",
    "            self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n",
    "            \n",
    "        with tf.variable_scope('prediction'):\n",
    "            self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n",
    "    \n",
    "    def predict(self, sess, X_length, X_indices):\n",
    "        feed_prediction = {self._X_length : X_length, self._X_indices : X_indices}\n",
    "        return sess.run(self._prediction, feed_dict = feed_prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a model of CharLSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    }
   ],
   "source": [
    "# hyper-parameter#\n",
    "lr = .003\n",
    "epochs = 10\n",
    "batch_size = 2\n",
    "total_step = int(np.shape(X_indices)[0] / batch_size)\n",
    "print(total_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n"
     ]
    }
   ],
   "source": [
    "## create data pipeline with tf.data\n",
    "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n",
    "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n",
    "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n",
    "tr_iterator = tr_dataset.make_initializable_iterator()\n",
    "print(tr_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "char_lstm = CharLSTM(X_length = X_length_mb, X_indices = X_indices_mb, y = y_mb, n_of_classes = 2,\n",
    "                     hidden_dim = 16, dic = char_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creat training op and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "## create training op\n",
    "opt = tf.train.AdamOptimizer(learning_rate = lr)\n",
    "training_op = opt.minimize(loss = char_lstm.ce_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   1, tr_loss : 0.686\n",
      "epoch :   2, tr_loss : 0.663\n",
      "epoch :   3, tr_loss : 0.633\n",
      "epoch :   4, tr_loss : 0.613\n",
      "epoch :   5, tr_loss : 0.586\n",
      "epoch :   6, tr_loss : 0.547\n",
      "epoch :   7, tr_loss : 0.514\n",
      "epoch :   8, tr_loss : 0.473\n",
      "epoch :   9, tr_loss : 0.429\n",
      "epoch :  10, tr_loss : 0.379\n"
     ]
    }
   ],
   "source": [
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "tr_loss_hist = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    avg_tr_loss = 0\n",
    "    tr_step = 0\n",
    "    \n",
    "    sess.run(tr_iterator.initializer)\n",
    "    try:\n",
    "        while True:\n",
    "            _, tr_loss = sess.run(fetches = [training_op, char_lstm.ce_loss])\n",
    "            avg_tr_loss += tr_loss\n",
    "            tr_step += 1\n",
    "            \n",
    "    except tf.errors.OutOfRangeError:\n",
    "        pass\n",
    "    \n",
    "    avg_tr_loss /= tr_step\n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "    \n",
    "    print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x11a517a58>]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl4VPW9x/H3N5OEhB1JQEgCYZVdlpE1gNYNN9AiCipFLYIKorW11d7e9l7trbWLIooK4r4hoFUUFFGRHSQBRHZCWJLIEvY1JITf/SOjjRTMAElOMvN5Pc88MmfOyXwyj3zm8PudxZxziIhIeIjwOoCIiJQdlb6ISBhR6YuIhBGVvohIGFHpi4iEEZW+iEgYUemLiISRoErfzPqY2TozSzezh0/x+lNmtjzwWG9m+4q8NsTMNgQeQ0oyvIiInBkr7uQsM/MB64HLgSxgCTDIObf6NOvfB3Rwzt1pZucBqYAfcEAa0Mk5t7fkfgUREQlWZBDrdAbSnXMZAGY2EegHnLL0gUHAnwJ/vhKY6ZzbE9h2JtAHeOd0bxYXF+eSk5ODCi8iIoXS0tJ2Oefii1svmNJPADKLPM8CupxqRTNrCDQCvvyJbRNOsd0wYBhAgwYNSE1NDSKWiIh8z8y2BLNeSU/kDgSmOOcKzmQj59x455zfOeePjy/2i0pERM5SMKWfDSQVeZ4YWHYqA/nx0M2ZbCsiIqUsmNJfAjQzs0ZmFk1hsU89eSUzawHUAhYWWTwDuMLMaplZLeCKwDIREfFAsWP6zrnjZjaSwrL2AS8751aZ2aNAqnPu+y+AgcBEV+RwIOfcHjN7jMIvDoBHv5/UFRGRslfsIZtlze/3O03kioicGTNLc875i1tPZ+SKiIQRlb6ISBgJmdJ3zvGX6Wv4JnNf8SuLiISpYE7OqhC27D7CO4u3Mn5OBhcl12Joz8Zc1rIuvgjzOpqISLkRUhO5h44d590lmbw8bxPZ+46SXLsyd6Y04sZOiVSODpnvNxGR/xDsRG5Ilf73jhecYMaqHbw4N4PlmfuoERvFrV0aMKR7MnWrx5RQUhGR8iOsS7+otC17eHHOJmas3k5khHHdhfUZmtKYVvWrl9h7iIh4LdjSD/kxj04Nz6PT4PPYsvswr8zfzKTUTN5fmk2PprUZmtKY3s3jidC4v4iEiZDf0z/Z/iP5vLNkK6/O38z2A7k0rVOVX6Y04oYOCcRE+UrtfUVESpOGd4qRd/wE0779jhfnbGL1tgPUrhLNbV0bMrhbQ+KqVir19xcRKUkq/SA551iUsYcJczP4Yu1OoiMj+HmHBIb2bETTOtXKLIeIyLnQmH6QzIxuTWrTrUltNuYc4qV5m3gvLYuJSzK5+IJ47urZmO5NamOmcX8RqfjCfk//VHYfOsZbi7fy+sLN7DqUR8t61Rma0ojrLqxPdGTInMQsIiFEwzslIDe/gKnLv2PCvAzW7zhEnWqVGNI9mVu7NKBm5Wiv44mI/EClX4Kcc8zZsIsJczOYu2EXsVE+BvgTubNHI5LjqngdT0REpV9a1m4/wIS5m/hweTbHTzgub1mXoT0bc1FyLY37i4hnVPqlbOeBXN5YtIU3Fm1h35F82iXWYGjPxlzd5nwifRr3F5GypdIvI0fzCpiyNIuX521i067DJNSM5fbuydzWtSGx0TrZS0TKhkq/jJ044fhi7U4mzM1g8aY91K8Rw8NXt+S6dvU07CMipU63SyxjERHG5a3q8u7wbkwc1pWalaMZ9c4ybnxhoW7sIiLlhkq/FHRtXJuP7kvhif5t2bL7MP3GzufBScvZcSDX62giEuZU+qXEF2HcfFEDZv3mYu7u3YSPv9nGxX//ime+2EBufoHX8UQkTKn0S1m1mCgevqoFMx/sRe/m8fxz5nou/edsPvrmO8rbfIqIhL6gSt/M+pjZOjNLN7OHT7POTWa22sxWmdnbRZYXmNnywGNqSQWvaBrWrsILgzvxzl1dqR4bxX3vLGPACwtZkaXxfhEpO8UevWNmPmA9cDmQBSwBBjnnVhdZpxkwCfiZc26vmdVxzu0MvHbIOVc12EAV9eidM1FwwjE5NZN/fLaOXYfyuLFTIg9deYFu5SgiZ60kj97pDKQ75zKcc3nARKDfSevcBYx1zu0F+L7w5dR8EcbAzoXj/cN7N2bq8u+45B9fMXZWusb7RaRUBVP6CUBmkedZgWVFNQeam9l8M1tkZn2KvBZjZqmB5def6g3MbFhgndScnJwz+gUqsmoxUTxyVUtmPtiLns3i+PuMdVz6z9l8vELj/SJSOkpqIjcSaAZcDAwCXjSzmoHXGgb+yXELMNrMmpy8sXNuvHPO75zzx8fHl1CkiqNh7SqMG+zn7bu6UC0mkpFvL+PmcYtYmb3f62giEmKCKf1sIKnI88TAsqKygKnOuXzn3CYK5wCaATjnsgP/zQC+AjqcY+aQ1b1JHNNG9eQvN7RlY84hrnt2Hg9N/oadOr5fREpIMKW/BGhmZo3MLBoYCJx8FM4HFO7lY2ZxFA73ZJhZLTOrVGR5D2A1clq+COOWLg2Y9dDF3NWzMR8sz9Z4v4iUmGJL3zl3HBgJzADWAJOcc6vM7FEz6xtYbQaw28xWA7OAh5xzu4GWQKqZfRNY/teiR/3I6VWPieL3V7dk5q96071p4Xj/ZU/OZvq32zTeLyJnTRdcqyDmp+/isY9Xs3b7QTo3Oo8/XtuKNgk1vI4lIuWELrgWYno0LRzv/78b2pC+s3C8/7dTvmHnQY33i0jwVPoViC/CuLVLQ2b95mKGpjTiX8uy+dk/ZvP8Vxs13i8iQVHpV0A1YqP4r2ta8dmvetO1cW2e+HQtlz81m0803i8ixVDpV2CN4qowYYifN3/ZhcpRkdzz1lIGjtfx/SJyeir9EJDSLI5po1L48/Vt2BAY73/4vRXkHDzmdTQRKWd09E6I2X80n2e+2MCrCzYT6TOublOPAf4kujQ6j4gI3bZRJFTpHrlhLiPnEBPmbeKj5d9x8Nhxks6L5caOSfTvlEBircpexxOREqbSFwBy8wuYsWo7k1IzmZ++GzPo0SSOAf5Ermx9PjFRPq8jikgJUOnLf8jae4T30rKZnJZJ1t6jVIuJ5LoL63OTP4kLE2tgpuEfkYpKpS+ndeKEY9Gm3UxJzWL6ym3k5p+gWZ2qDPAnckOHROKrVfI6ooicIZW+BOVgbj4fr9jG5NRMlm7dhy/CuOSCOgzwJ/KzFnWI8ukAL5GKQKUvZyx95yGmpGXx3tIscg4eo3aVaK7vkMAAfyItzq/udTwR+QkqfTlrxwtOMGdDDpNTs/h8zQ7yCxztEmswoFMifS9MoEblKK8jishJVPpSIvYczuODZdlMTstizbYDREdGcEWrugzwJ5HSNA6fjv0XKRdU+lLiVmbvZ0paFh8sz2bfkXzq1Yihf8dEbuyUSHJcFa/jiYQ1lb6UmmPHC/h89U4mp2UyZ30OJxx0Tj6PG/2JXNO2HlUqRXodUSTsqPSlTGzfn8t7S7OYkpbFpl2HqRzt45q2hZd+uCi5lo79FykjKn0pU845UrfsZXJqJtNWbONwXgHJtSszwJ9E/46JnF8jxuuIIiFNpS+eOXzsOJ+sLLz0w9eb9hDlMwZe1IB7L2lCvRqxXscTCUkqfSkXNu86zLg5GUxOzSTCjEGdk7j3kqbUra49f5GSpNKXciVzzxHGzkpnclpW4LaPDbjn4ibUqabyFykJKn0pl7buPsIzX27g/WXZREYYg7s2ZHjvJrrej8g5UulLubZ512HGfLmBD5ZlEx0ZwS+6JTO8V2NqV1X5i5yNYEs/qKtpmVkfM1tnZulm9vBp1rnJzFab2Soze7vI8iFmtiHwGBL8ryChLDmuCk/e1J7PH+zNVW3qMWFuBj3/Nou/frKWPYfzvI4nErKK3dM3Mx+wHrgcyAKWAIOcc6uLrNMMmAT8zDm318zqOOd2mtl5QCrgBxyQBnRyzu093ftpTz88pe88xJgvNvDRiu+oHOXj9h7J3NWzMTUrR3sdTaRCKMk9/c5AunMuwzmXB0wE+p20zl3A2O/L3Dm3M7D8SmCmc25P4LWZQJ9gfwkJH03rVGXMoA589kAvLm5Rh+e+2kjKE7P452fr2H8k3+t4IiEjmNJPADKLPM8KLCuqOdDczOab2SIz63MG24r8oFndaoy9pSOf3t+LXs3jeObLdFKe+JKnZq5n/1GVv8i5KqmLpEQCzYCLgURgjpm1DXZjMxsGDANo0KBBCUWSiuyC86vx3K2dWLPtAKM/X8/TX2zglfmb+GVKY+5ISaZ6jC7vLHI2gtnTzwaSijxPDCwrKguY6pzLd85tonAOoFmQ2+KcG++c8zvn/PHx8WeSX0Jcy3rVGTfYz7RRKXRpXJunPl9Pzydm8eyXGzh07LjX8UQqnGAmciMpLPFLKSzsJcAtzrlVRdbpQ+Hk7hAziwOWAe359+Rtx8CqSymcyN1zuvfTRK78lG+z9jP68/V8sXYnNStHMaxXY4Z0S9aVPSXsldhErnPuODASmAGsASY551aZ2aNm1jew2gxgt5mtBmYBDznndgfK/TEKvyiWAI/+VOGLFKdtYg1euv0iPhzRgw5JNfnbp+vo+bdZvDB7I0fytOcvUhydnCUV2tKtexn9+QbmrM8hrmo0w3s14bauDYmN9nkdTaRM6YxcCStpW/bw1MwNzEvfRVzVStxzcRNu7dKAmCiVv4QHlb6EpSWb9/DUzPUs2LibOtUKy39QZ5W/hD6VvoS1RRm7eWrmehZv2kPd6pUYcUlTbr4oiUqRKn8JTSV67R2RiqZr49q8O7wbb9/VhYbnVeGPH67i2jHz2LDjoNfRRDyl0peQ1r1JHO8O78rLt/vZeySPvs/O5/2lWV7HEvGMSl9CnpnxsxZ1mT6qJ+0Sa/DgpG/47ZRvOJpX4HU0kTKn0pewUad6DG8N7cLIS5oyOS2L68fOJ33nIa9jiZQplb6ElUhfBL+58gJevaMzOYeO0ffZeXy4/D+uDCISslT6EpZ6N49n+qietKlfg/snLueR91eQm6/hHgl9Kn0JW+fXiOHtu7pwz8VNeOfrTG54bgEZORrukdCm0pewFumL4Hd9WvDKHRexff9RrntmHh99853XsURKjUpfBLjkgjpMG9WTFvWqc987y/ivf32r4R4JSSp9kYD6NWOZOKwrw3s15q3FW+n//AI27zrsdSyREqXSFykiyhfBI1e35KUhfrL2HuXaZ+YxbcU2r2OJlBiVvsgpXNqyLtPv70mzulUZ8fZS/vjhSo4d13CPVHwqfZHTSKgZy7vDujE0pRGvL9zCjc8vZOvuI17HEjknKn2RnxAdGcEfrm3F+MGd2LL7MNc8M5dPV2q4Ryoulb5IEK5ofT7TRvWkcXxV7n5zKf8zdRV5x094HUvkjKn0RYKUdF5lJg/vxh09knl1wWYGvLCAzD0a7pGKRaUvcgaiIyP403WteeG2TmTsOsw1Y+by2artXscSCZpKX+Qs9GlzPtPu60nD2lUY9kYaf/54NfkFGu6R8k+lL3KWGtSuzJR7ujGkW0MmzNvETeMWkr3vqNexRH6SSl/kHFSK9PG//drw3K0dSd9xiKufnssXa3Z4HUvktIIqfTPrY2brzCzdzB4+xeu3m1mOmS0PPIYWea2gyPKpJRlepLy4um09ProvhYSasfzytVQen75Gwz1SLkUWt4KZ+YCxwOVAFrDEzKY651aftOq7zrmRp/gRR51z7c89qkj5lhxXhffv7c6fp61m3JwMUrfs5ZlBHahfM9braCI/CGZPvzOQ7pzLcM7lAROBfqUbS6Riiony8efr2zJmUAfWbjvANWPmMmvdTq9jifwgmNJPADKLPM8KLDtZfzNbYWZTzCypyPIYM0s1s0Vmdv25hBWpKPpeWJ+P7kuhbvUY7nhlCU98upbjGu6RcqCkJnI/ApKdc+2AmcBrRV5r6JzzA7cAo82syckbm9mwwBdDak5OTglFEvFW4/iqfDCiB4M6N+D5rzZyy4uL2b4/1+tYEuaCKf1soOiee2Jg2Q+cc7udc8cCTycAnYq8lh34bwbwFdDh5Ddwzo13zvmdc/74+Pgz+gVEyrOYKB+P/7wto29uz8rv9nO1TuYSjwVT+kuAZmbWyMyigYHAj47CMbN6RZ72BdYEltcys0qBP8cBPYCTJ4BFQt71HRKYOrJwuGfYG2kMfyNVe/3iiWJL3zl3HBgJzKCwzCc551aZ2aNm1jew2igzW2Vm3wCjgNsDy1sCqYHls4C/nuKoH5Gw0LROVaaO7MHv+rTgq3U5XPbkbF5bsJmCE87raBJGzLny9T+c3+93qampXscQKVVbdh/mDx+sZO6GXbRPqsnjP29Ly3rVvY4lFZiZpQXmT3+SzsgV8UDD2lV4/c7OjL65PZl7jnDtM/N4/JM1HM3T3bmkdKn0RTxiZlzfIYEvft2bGzsmMm52BleMns3s9TqCTUqPSl/EYzUrR/PEje14d1hXonwRDHn5a0a9s4ycg8eK31jkDKn0RcqJLo1r88n9PXngsmZ8unI7lz05m4lfb+WEJnqlBKn0RcqRSpE+HrisOdPv70mL86vx8PvfMnD8ItJ3HvQ6moQIlb5IOdS0TlUmDuvK3/q3Y92Og1z19FyenLme3HxN9Mq5UemLlFNmxk0XJfHFr3tzTdt6jPliA1c/PZeFG3d7HU0qMJW+SDkXV7USowd24PU7O3P8hGPQi4t4aPI37D2c53U0qYBU+iIVRK/m8cx4oBf3XNyEfy3L5tInZ/OvZVmUtxMspXxT6YtUILHRPn7XpwUfj0qhYe3K/Ordb/jFy1+zZfdhr6NJBaHSF6mAWpxfnSl3d+exfq1ZvnUfVzw1h7Gz0nWLRimWSl+kgvJFGIO7JfP5r3vzsxZ1+PuMdVw7Zh5pW/Z6HU3KMZW+SAVXt3oMz9/WiQm/8HMwN58bX1jAHz74lgO5+V5Hk3JIpS8SIi5rVZeZD/bmju6NeHvxVi7752ymf7tNE73yIyp9kRBSpVIkf7yuFR+OSCG+WiXufWspQ19LJXvfUa+jSTmh0hcJQW0Ta/DhiB784ZqWLNi4m8ufnM2EuRm6Obuo9EVCVaQvgqE9GzPzwV50bVybP09bw/XPzefbrP1eRxMPqfRFQlxircq8NMTP2Fs6suPAMfqNncejH63WDVvClEpfJAyYGde0q8fnD/ZmUOcGvDx/Ezc8N59Nu3RSV7hR6YuEkRqxUfzfDW157c7ObD+QS99n5vHpyu1ex5IypNIXCUO9m8czbVRPGtepyt1vpvGX6Ws0yRsmVPoiYSqhZiyThndlcNeGjJ+TwS0TFrPzQK7XsaSUqfRFwlilSB+PXd+Gpwe259us/Vw9Zh6LMnS9/lAWVOmbWR8zW2dm6Wb28Clev93McsxseeAxtMhrQ8xsQ+AxpCTDi0jJ6Nc+gQ9H9qB6bCS3TljMuNkbdSZviCq29M3MB4wFrgJaAYPMrNUpVn3XOdc+8JgQ2PY84E9AF6Az8Cczq1Vi6UWkxDSvW42pI1Po0/p8Hv9kLcPfSNP1e0JQMHv6nYF051yGcy4PmAj0C/LnXwnMdM7tcc7tBWYCfc4uqoiUtqqVInn2lg7897Wt+HLtTvo+M4/V3x3wOpaUoGBKPwHILPI8K7DsZP3NbIWZTTGzpDPcVkTKCTPjlymNmDisK0fzC7jhuflMTs0sfkOpEEpqIvcjINk5147CvfnXzmRjMxtmZqlmlpqTk1NCkUTkXPiTz2PaqJ50aliLh6as4JH3V5Cbr7N4K7pgSj8bSCryPDGw7AfOud3OuWOBpxOATsFuG9h+vHPO75zzx8fHB5tdREpZXNVKvPHLLoy4pAnvfJ3JjS8sIHPPEa9jyTkIpvSXAM3MrJGZRQMDgalFVzCzekWe9gXWBP48A7jCzGoFJnCvCCwTkQrCF2E8dGULJvzCz9bdR7hmzFy+XLvD61hylootfefccWAkhWW9BpjknFtlZo+aWd/AaqPMbJWZfQOMAm4PbLsHeIzCL44lwKOBZSJSwVzWqi4f39eTpPMqc+erqfx9xloKTuiwzorGytuxuH6/36WmpnodQ0ROIze/gP+ZuoqJSzLp3qQ2YwZ1IK5qJa9jhT0zS3PO+YtbT2fkisgZiYny8df+7fjbje1I27KXa8bMJXWz/gFfUaj0ReSs3ORP4v17uxMT5WPg+EW8NG+TzuKtAFT6InLWWtevwdSRKVzSog6PfbyakW8v49Cx417Hkp+g0heRc1IjNorxgzvx8FUt+GTlNvo+O4/1Ow56HUtOQ6UvIufMzLi7dxPeGtqVA0eP0+/Z+Xyw7D9OyZFyQKUvIiWmW5PaTB+VQtuEGjzw7nL++4OVHDuus3jLE5W+iJSoOtVjeOuuLgzr1Zg3Fm3hpnGLyN531OtYEqDSF5ESF+WL4PdXt+SF2zqycechrh0zl9nrdV2t8kClLyKlpk+bekwd2YO61WO4/ZWvGf35ek7oLF5PqfRFpFQ1jq/Kv+7twQ0dEhj9+QZuf3UJew7neR0rbKn0RaTUxUb7+OeAC/nLDW1ZtHE3146Zy/LMfV7HCksqfREpE2bGLV0aMOWebpgZA15YwBsLN+ss3jKm0heRMtUusSbTRqWQ0jSO//5wFY+8/y15x094HStsqPRFpMzVrBzNS0MuYuQlTZm4JJPbXlqscf4yotIXEU9ERBi/ufICnh7YnuWZ++g3dh7rtuvyDaVNpS8inurXPoFJw7uRm3+Cnz83ny/W6K5cpUmlLyKea59Uk6kje9AovgpDX09l3OyNmuAtJSp9ESkX6tWIZfLw7lzdph6Pf7KW30xeoev2lAKVvoiUG7HRPp69pQMPXNaM95ZmMWj8InIOHvM6VkhR6YtIuWJmPHBZc8be0pHV2w5w/dj5rP7ugNexQoZKX0TKpWva1WPy8O4UnHD0f34Bn67c7nWkkKDSF5Fyq21iDaaO7EHz86tx95tpjJ2Vrgnec6TSF5FyrU71GN4d1pV+7evz9xnruH/icnLzNcF7tiK9DiAiUpyYKB+jb25P87rV+PuMdWzZc4QXB3eiTvUYr6NVOEHt6ZtZHzNbZ2bpZvbwT6zX38ycmfkDz5PN7KiZLQ88Xiip4CISXsyMEZc0ZdzgTmzYcZC+z87n26z9XseqcIotfTPzAWOBq4BWwCAza3WK9aoB9wOLT3ppo3OufeBxdwlkFpEwdmXr85lyd3d8EcaAcQuYtmKb15EqlGD29DsD6c65DOdcHjAR6HeK9R4DngBySzCfiMh/aFW/Oh+M6EHr+jUY8fZSnpqpO3IFK5jSTwAyizzPCiz7gZl1BJKcc9NOsX0jM1tmZrPNrOep3sDMhplZqpml5uToPpoiUrz4apV4+64u9O+YyNNfbOC+d5ZxNE8TvMU556N3zCwCeBL49Sle3gY0cM51AB4E3jaz6iev5Jwb75zzO+f88fHx5xpJRMJEpUgf/xjQjt9f3YLpK7cxYNwCtu0/6nWsci2Y0s8Gkoo8Twws+141oA3wlZltBroCU83M75w75pzbDeCcSwM2As1LIriICBRO8A7r1YQJv/CzedcR+j47n2Vb93odq9wKpvSXAM3MrJGZRQMDganfv+ic2++ci3POJTvnkoFFQF/nXKqZxQcmgjGzxkAzIKPEfwsRCXuXtqzL+/d2JyYqgpvHL+LD5dnFbxSGii1959xxYCQwA1gDTHLOrTKzR82sbzGb9wJWmNlyYApwt3Nuz7mGFhE5leZ1q/HhiBTaJ9Xk/onL+fuMtZrgPYmVt1Oa/X6/S01N9TqGiFRgecdP8McPVzJxSSZXtKrLUze3p0ql0D4X1czSnHP+4tbTZRhEJORER0bw+M/b8sdrW/H5mh30f34BWXuPeB2rXFDpi0hIMjPuTGnEK3d0JnvfUfo9O5/UzRpdVumLSEjr3Tyef93bg2oxkdzy4mImp2YWv1EIU+mLSMhrWqcqH4zogT+5Fg9NWcFfpq+hIEwneFX6IhIWalaO5rU7OzO4a0PGz8ngrtdTOZib73WsMqfSF5GwEeWL4LHr2/BYv9bMXp9D/+cXsHV3eE3wqvRFJOwM7pbM63d2ZseBY/QbO49FGbu9jlRmVPoiEpZ6NI3jgxE9qFUlml+89DWfrQqPe/Cq9EUkbDWKq8K/7ulBq/rVueetpUz95juvI5U6lb6IhLUalaN4c2gXOjWsxf0TlzFpSWgf0qnSF5GwV7VSJK/d0ZmUpnH89r0VvLZgs9eRSo1KX0QEiI32MWGIn8ta1uVPU1fxwuyNXkcqFSp9EZGASpE+nr+tI9e2q8dfP1nLUzPXU94uSnmuQvuycyIiZyjKF8HTAzsQG+Xj6S82cDS/gEeuaoGZeR2tRKj0RURO4oswnujfjthoH+PnZHA0r4D/7duaiIiKX/wqfRGRU4iIMP63b2tio3yMm5PB0fwCnujfDl8FL36VvojIaZgZD1/VgthoH6M/30BufgFP3dyeKF/FnQ5V6YuI/AQz44HLmlM52sdfpq8lN/8Ez97SgZgon9fRzkrF/boSESlDw3o14bF+rfl8zQ7uej2Vo3kFXkc6Kyp9EZEgDe6WzN9ubMf89F0MefnrCnlpZpW+iMgZuMmfxOiBHUjbupfbXvqafUfyvI50RlT6IiJnqO+F9Xn+1o6s+e4Ag15czK5Dx7yOFDSVvojIWbii9flMGOJn065D3DxuITsO5HodKShBlb6Z9TGzdWaWbmYP/8R6/c3MmZm/yLJHAtutM7MrSyK0iEh50Kt5PK/d0Znt+3O5adxCsvaW/7twFVv6ZuYDxgJXAa2AQWbW6hTrVQPuBxYXWdYKGAi0BvoAzwV+nohISOjSuDZvDu3C3sN53PTCQjbtOux1pJ8UzJ5+ZyDdOZfhnMsDJgL9TrHeY8ATQNF/4/QDJjrnjjnnNgHpgZ8nIhIyOjSoxTvDupJ7/AQ3jVvI+h0HvY50WsGUfgJQ9K4CWYFlPzCzjkCSc27amW4rIhIKWtevwaThXTHg5nELWZm93+tIp3TOE7lmFgE8Cfz6HH7GMDNLNbPUnJycc40kIuKJpnWqMWl4NypHRzLoxUWkbdnrdaT/EEzpZwNJRZ4nBpbScuPTAAAFHUlEQVR9rxrQBvjKzDYDXYGpgcnc4rYFwDk33jnnd8754+Pjz+w3EBEpR5LjqjDp7m7UrhLN4JcWs3Djbq8j/Ugwpb8EaGZmjcwsmsKJ2anfv+ic2++ci3POJTvnkoFFQF/nXGpgvYFmVsnMGgHNgK9L/LcQESlHEmrGMml4NxJqxnL7K1/z1bqdXkf6QbGl75w7DowEZgBrgEnOuVVm9qiZ9S1m21XAJGA18CkwwjlXMS9YISJyBupUj2HisK40rVOVu15P5dOV272OBICVt1uB+f1+l5qa6nUMEZESsf9oPre/8jUrsvbz5E0X0q996RzLYmZpzjl/cevpjFwRkVJUIzaKN37ZhYuSa/HAu8t5d8lWT/Oo9EVESlnVSpG8cntnejaL53fvfcur8zd5lkWlLyJSBmKjfbz4i05c0aou//PRap7/aqMnOVT6IiJlpFKkj7G3dqTvhfV54tO1PPnZOsp6XlW3SxQRKUNRvgieurk9sVE+xnyZzpG8Av7rmpaYlc0N11X6IiJlzBdhPP7ztsRG+5gwbxNH8wt4rF8bIiJKv/hV+iIiHoiIMP50XStiony8MHsjufkneKJ/WyJ9pTvqrtIXEfGImfG7PhdQOdrHkzPXk5tfwJhBHfCV4h6/Sl9ExENmxqhLm1E52sf+o/mlWvig0hcRKReG9mxcJu+jQzZFRMKISl9EJIyo9EVEwohKX0QkjKj0RUTCiEpfRCSMqPRFRMKISl9EJIyUu9slmlkOsOUcfkQcsKuE4lR0+ix+TJ/Hj+nz+LdQ+CwaOufii1up3JX+uTKz1GDuExkO9Fn8mD6PH9Pn8W/h9FloeEdEJIyo9EVEwkgolv54rwOUI/osfkyfx4/p8/i3sPksQm5MX0RETi8U9/RFROQ0Qqb0zayPma0zs3Qze9jrPF4ysyQzm2Vmq81slZnd73Umr5mZz8yWmdnHXmfxmpnVNLMpZrbWzNaYWTevM3nJzH4V+Huy0szeMbMYrzOVppAofTPzAWOBq4BWwCAza+VtKk8dB37tnGsFdAVGhPnnAXA/sMbrEOXE08CnzrkWwIWE8ediZgnAKMDvnGsD+ICB3qYqXSFR+kBnIN05l+GcywMmAv08zuQZ59w259zSwJ8PUviXOsHbVN4xs0TgGmCC11m8ZmY1gF7ASwDOuTzn3D5vU3kuEog1s0igMvCdx3lKVaiUfgKQWeR5FmFcckWZWTLQAVjsbRJPjQZ+C5zwOkg50AjIAV4JDHdNMLMqXofyinMuG/gHsBXYBux3zn3mbarSFSqlL6dgZlWB94AHnHMHvM7jBTO7FtjpnEvzOks5EQl0BJ53znUADgNhOwdmZrUoHBVoBNQHqpjZbd6mKl2hUvrZQFKR54mBZWHLzKIoLPy3nHPve53HQz2Avma2mcJhv5+Z2ZveRvJUFpDlnPv+X35TKPwSCFeXAZuccznOuXzgfaC7x5lKVaiU/hKgmZk1MrNoCidipnqcyTNmZhSO2a5xzj3pdR4vOececc4lOueSKfz/4kvnXEjvyf0U59x2INPMLggsuhRY7WEkr20FuppZ5cDfm0sJ8YntSK8DlATn3HEzGwnMoHD2/WXn3CqPY3mpBzAY+NbMlgeW/d45N93DTFJ+3Ae8FdhBygDu8DiPZ5xzi81sCrCUwqPelhHiZ+fqjFwRkTASKsM7IiISBJW+iEgYUemLiIQRlb6ISBhR6YuIhBGVvohIGFHpi4iEEZW+iEgY+X93TYd2utQMuAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(tr_loss_hist, label = 'train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "yhat = char_lstm.predict(sess = sess, X_length = X_length, X_indices = X_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training acc: 83.33%\n"
     ]
    }
   ],
   "source": [
    "print('training acc: {:.2%}'.format(np.mean(yhat == np.argmax(y, axis = -1))))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}