{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 11 : Recurrent Neural Networks\n",
    "A simple example of many-to-one classification (word sentiment classification) using a stacked LSTM with dropout.\n",
    "\n",
    "### Many-to-One Classification by Stacked LSTM with Dropout\n",
    "- Creating the **data pipeline** with `tf.data`\n",
    "- Preprocessing variable-length word sequences by padding them with a user-defined function (`pad_seq`)\n",
    "- Using `tf.nn.embedding_lookup` to get the vector of each token (e.g. word, character)\n",
    "- Creating the model as a **class**\n",
    "- Applying **dropout** to the model with `tf.contrib.rnn.DropoutWrapper`\n",
    "- **Stacking** recurrent layers with `tf.contrib.rnn.MultiRNNCell`\n",
    "- Replacing the **RNN cell** with an **LSTM cell**\n",
    "- Reference\n",
    "    - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n",
    "    - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n",
    "    - https://danijar.com/introduction-to-recurrent-networks-in-tensorflow/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.8.0\n"
     ]
    }
   ],
   "source": [
    "import os, sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import string\n",
    "%matplotlib inline\n",
    "\n",
    "slim = tf.contrib.slim\n",
    "\n",
    "# Seed numpy and TensorFlow's graph-level seed so that weight initialization,\n",
    "# dropout masks and dataset shuffling are reproducible under Restart & Run All.\n",
    "SEED = 777\n",
    "np.random.seed(SEED)\n",
    "tf.set_random_seed(SEED)\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare example data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n",
    "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'abcdefghijklmnopqrstuvwxyz *'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Character quantization\n",
    "char_space = string.ascii_lowercase \n",
    "char_space = char_space + ' ' + '*'\n",
    "char_space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n"
     ]
    }
   ],
   "source": [
    "# map every character to its position in char_space\n",
    "char_dic = dict(zip(char_space, range(len(char_space))))\n",
    "print(char_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create pad_seq function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_seq(sequences, max_len, dic):\n",
    "    seq_len, seq_indices = [], []\n",
    "    for seq in sequences:\n",
    "        seq_len.append(len(seq))\n",
    "        seq_idx = [dic.get(char) for char in seq]\n",
    "        seq_idx += (max_len - len(seq_idx)) * [dic.get('*')] # 27 is idx of meaningless token \"*\"\n",
    "        seq_indices.append(seq_idx)\n",
    "    return seq_len, seq_indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Apply pad_seq function to data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_length = 10  # target (padded) length of every word\n",
    "X_length, X_indices = pad_seq(words, max_length, char_dic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4, 3, 7, 7, 9, 7]\n",
      "(6, 10)\n"
     ]
    }
   ],
   "source": [
    "# true lengths of the words, then the (n_words, max_length) shape of the padded indices\n",
    "for summary in (X_length, np.shape(X_indices)):\n",
    "    print(summary)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define CharStackedLSTM class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharStackedLSTM:\n",
    "    \"\"\"Many-to-one character-level classifier built from stacked LSTM layers with dropout.\n",
    "\n",
    "    The graph is built directly on the given tensors (in this notebook, the\n",
    "    mini-batch tensors of a tf.data iterator), so no input placeholders are needed.\n",
    "\n",
    "    Args:\n",
    "        X_length: tensor with the unpadded length of each sequence.\n",
    "        X_indices: tensor of padded character indices, one row per sequence.\n",
    "        y: tensor of one-hot labels.\n",
    "        n_of_classes: number of output classes.\n",
    "        dic: character -> index mapping; its size is the one-hot embedding width.\n",
    "        hidden_dims: number of units of each stacked LSTM layer, bottom to top.\n",
    "\n",
    "    Attributes:\n",
    "        ce_loss: softmax cross-entropy between `y` and the logits.\n",
    "    \"\"\"\n",
    "    def __init__(self, X_length, X_indices, y, n_of_classes, dic, hidden_dims = [32, 16]):\n",
    "        \n",
    "        # data pipeline\n",
    "        with tf.variable_scope('input_layer'):\n",
    "            self._X_length = X_length\n",
    "            self._X_indices = X_indices\n",
    "            self._y = y\n",
    "            # Defaults to 1. (no dropout) so inference ops work without feeding it;\n",
    "            # feed a value < 1. during training to enable dropout.\n",
    "            self._keep_prob = tf.placeholder_with_default(1., shape = ())\n",
    "            \n",
    "            # Row i of the identity matrix is the one-hot vector for character index i.\n",
    "            one_hot = tf.eye(len(dic), dtype = tf.float32)\n",
    "            self._one_hot = tf.get_variable(name='one_hot_embedding', initializer = one_hot,\n",
    "                                            trainable = False) # the embedding vectors are not trained\n",
    "            self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n",
    "            \n",
    "        # Stacked-LSTM\n",
    "        with tf.variable_scope('stacked_lstm'):\n",
    "            \n",
    "            cells = []\n",
    "            for hidden_dim in hidden_dims:\n",
    "                cell = tf.contrib.rnn.BasicLSTMCell(num_units = hidden_dim, activation = tf.nn.tanh)\n",
    "                cell = tf.contrib.rnn.DropoutWrapper(cell = cell, output_keep_prob = self._keep_prob)\n",
    "                cells.append(cell)\n",
    "            cells = tf.contrib.rnn.MultiRNNCell(cells = cells)\n",
    "            \n",
    "            # With sequence_length given, the returned state of each example is the one\n",
    "            # at its last valid step, so padding tokens do not reach the classifier.\n",
    "            _, states = tf.nn.dynamic_rnn(cell = cells, inputs = self._X_batch,\n",
    "                                         sequence_length = self._X_length, dtype = tf.float32)\n",
    "            \n",
    "        with tf.variable_scope('output_layer'):\n",
    "            # states[-1] is the top layer's LSTM state; classify from its hidden state h.\n",
    "            self._score = slim.fully_connected(inputs = states[-1].h, num_outputs = n_of_classes,\n",
    "                                               activation_fn = None)\n",
    "            \n",
    "        with tf.variable_scope('loss'):\n",
    "            self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n",
    "            \n",
    "        with tf.variable_scope('prediction'):\n",
    "            self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n",
    "    \n",
    "    def predict(self, sess, X_length, X_indices, keep_prob = 1.):\n",
    "        \"\"\"Return the predicted class index of each given sequence.\n",
    "\n",
    "        X_length / X_indices are fed in place of the tensors the model was built on;\n",
    "        keep_prob defaults to 1. so dropout is disabled.\n",
    "        \"\"\"\n",
    "        feed_prediction = {self._X_length : X_length, self._X_indices : X_indices, self._keep_prob : keep_prob}\n",
    "        return sess.run(self._prediction, feed_dict = feed_prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a model of CharStackedLSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    }
   ],
   "source": [
    "# hyper-parameters\n",
    "lr = .003\n",
    "epochs = 10\n",
    "batch_size = 2\n",
    "# Round up: Dataset.batch keeps the final partial batch, so this is the\n",
    "# actual number of training steps per epoch.\n",
    "total_step = int(np.ceil(len(X_indices) / batch_size))\n",
    "print(total_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n"
     ]
    }
   ],
   "source": [
    "## create data pipeline with tf.data: shuffle the examples, then group them into mini-batches\n",
    "tr_dataset = (tf.data.Dataset\n",
    "              .from_tensor_slices((X_length, X_indices, y))\n",
    "              .shuffle(buffer_size = 20)\n",
    "              .batch(batch_size = batch_size))\n",
    "tr_iterator = tr_dataset.make_initializable_iterator()\n",
    "print(tr_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Symbolic mini-batch tensors: each sess.run that evaluates them advances the iterator\n",
    "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the graph on the iterator's mini-batch tensors\n",
    "char_stacked_lstm = CharStackedLSTM(\n",
    "    X_length = X_length_mb, X_indices = X_indices_mb, y = y_mb,\n",
    "    dic = char_dic, n_of_classes = 2, hidden_dims = [32, 16])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create training op and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "## create training op: Adam on the model's cross-entropy loss\n",
    "opt = tf.train.AdamOptimizer(lr)\n",
    "training_op = opt.minimize(char_stacked_lstm.ce_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   1, tr_loss : 0.692\n",
      "epoch :   2, tr_loss : 0.665\n",
      "epoch :   3, tr_loss : 0.638\n",
      "epoch :   4, tr_loss : 0.610\n",
      "epoch :   5, tr_loss : 0.572\n",
      "epoch :   6, tr_loss : 0.503\n",
      "epoch :   7, tr_loss : 0.439\n",
      "epoch :   8, tr_loss : 0.338\n",
      "epoch :   9, tr_loss : 0.282\n",
      "epoch :  10, tr_loss : 0.229\n"
     ]
    }
   ],
   "source": [
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "tr_loss_hist = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    # rewind the one-pass iterator at the start of every epoch\n",
    "    sess.run(tr_iterator.initializer)\n",
    "    avg_tr_loss, tr_step = 0, 0\n",
    "    \n",
    "    while True:\n",
    "        try:\n",
    "            _, tr_loss = sess.run(fetches = [training_op, char_stacked_lstm.ce_loss],\n",
    "                                  feed_dict = {char_stacked_lstm._keep_prob : .5})\n",
    "        except tf.errors.OutOfRangeError:\n",
    "            break  # iterator exhausted: the epoch is over\n",
    "        avg_tr_loss += tr_loss\n",
    "        tr_step += 1\n",
    "    \n",
    "    # turn the running sum into the mean loss over this epoch's mini-batches\n",
    "    avg_tr_loss /= tr_step\n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "    print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x117ea0048>]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAH+tJREFUeJzt3Xl0VdX9/vH3JwkBwpAwhCEkEGSSKUEIM6WKEw4BxKrQYh3rBIp+7a/a1n5rtV+r1lYsg0qdBUHrSKiCVlEEBUmQMA9hDAlDGBJmMu3fHwkaLZoANzl3eF5rsRb35iTnWWfBs3bOuXtvc84hIiLBJczrACIi4nsqdxGRIKRyFxEJQip3EZEgpHIXEQlCKncRkSCkchcRCUIqdxGRIKRyFxEJQhFenbhp06YuMTHRq9OLiASkjIyMPc652MqO86zcExMTSU9P9+r0IiIBycy2VuU43ZYREQlCVSp3MxtqZuvMLMvM7j/J1580s2Xlf9abWb7vo4qISFVVelvGzMKBycCFwHZgiZnNcs6tPnGMc+6eCsffCZxTDVlFRKSKqjJy7wNkOec2OecKgZnA8B85fjQwwxfhRETk9FSl3FsB2RVeby9/77+YWRugLfDJmUcTEZHT5esHqqOAN51zJSf7opndYmbpZpael5fn41OLiMgJVSn3HCChwuv48vdOZhQ/ckvGOTfVOZfinEuJja30Y5oiInKaqlLuS4AOZtbWzCIpK/BZ3z/IzM4GGgFf+jbid2Vm5zPpkw1s2XO4Ok8jIhLQKv20jHOu2MzGAXOBcOAF59wqM3sISHfOnSj6UcBMV82bsi7atJcnPlzPEx+uJzk+mtTkOC5LaknL6LrVeVoRkYBiXm2QnZKS4k53hmpu/lFmL88lLXMHK3IKMIPeiY1JTY7j0m4taFK/to/Tioj4BzPLcM6lVHpcIJZ7RZvyDjF7+Q5mZeaStfsQ4WHGwPZNSU1qyUVdWxBdt5YP0oqI+IeQKfcTnHOs3XmQtMxc0pbnkr3vKJHhYZzbKZbU5Dgu6NycupHhPjufiIgXQq7cK3LOsSw7n7TMHcxensvug8eJigzngs7NSU2OY3DHptSOUNGLSOAJ6XKvqKTU8dXmfaQtz+WDFTvYf6SIhnUiGNqtBanJcfQ/qwkR4Vo/TUQCg8r9JIpKSlmQtYe0zFw+XLWLQ8eLaVo/kku7t2RYchw9WzciLMxqNJOIyKlQuVfiWFEJn67bTVrmDv6zZhfHi0uJi67D5clxDEuOo2tcQ8xU9CLiX1Tup+DQ8WL+s3oXaZm5fLY+j+JSR9um9UhNasmwHnG0b9bA64giIoDK/bTlHylkzsqdpC3P5cuNeyl1cHaLBqSWj+gTGkd5HVFEQpjK3Qd2HzzG++WfoV+6rWz/kR4JMQwrnxXbvGEdjxOKSKhRuftY9r4j/HvFDmYty2X1jgOYQd+2J2bFtqRRvUivI4pICFC5V6Os3Ye+mSy1Ke8wEWHGoA5NGZYcx4VdmtOgjmbFikj1ULnXAOccq3ccYFZmLrMzd5CTf5TaEWEMObsZqclxDDm7GXVqabKUiPiOyr2GOedYui2ftMxcZi/fwZ5Dx6kXGc6FXZozrEccg9rHEhmhyVIicmZU7h4qKXUs2rSXtMxcPli5k4KjRcRE1eKSbi1ITYqj71lNCNdkKRE5DSp3P1FYXMrnG/LKZsWu3sWRwhJiG9Tmsu4tSU2Oo2frGE2WEpEqU7n7oaOFJXyydjdpmbl8sm43hcWltIqpS2pyHKnJLenSUrNiReTHqdz93IFjRXy0ahdpy3P5fMMeSkod7WLrlRd9HO1i63sdUUT8kMo9gOw7XMgHK3eQlpnL4s37cA66tGzIsB5xXJ7UkvhGmhUrImVU7gFqZ8GxsslSmblkZpfNiu3VphGpSS25NKklzRpoVqxIKFO5B4Fte4+QtjyXtMxc1u48SJhB/3ZN
SE2KY2i3FsREaVasSKhRuQeZ9bvKtxDMzGXL3iPUCjcGdyjfQrBLc+rXjvA6oojUAJV7kHLOsSKn4JvJUjsKjlE7IozzOzcjNSmO8zQrViSoqdxDQGmpI33rfmYvz+X9FTvYc6iQepHhXNS1BanJLTUrViQIqdxDTHFJKYs27SufFbuDA8eKia5bi6Fdy/aK7XdWY+0VKxIEVO4hrLC4lAVZeaRl7uDDVTs5XFjyzV6xlyfFkdJGe8WKBCqVuwBle8XOW7ub2cu/3Su2RcM6XJ5UtvxBUny0ZsWKBBCVu/yXQ8eL+XjNt3vFFpU4WjeO+qboz27RQEUv4udU7vKjCo4UMXf1TtIyc/li415KSh3tm9UnNSmOy5NbavkDET+lcpcq23PoOB+sLCv6JVvKlj/oGteQy5PKlj/QpuAi/kPlLqflxPIHaZm5LCtf/uCc1jGkJmlTcBF/oHKXM5a97wizl5cV/YlNwfsklm0Kfkm3FjSpX9vriCIhR+UuPpW1+xCzy9e52Zh3mPAwY2D7pqQmteSiri2IrqtNwUVqgspdqoVzjrU7y9e5WZ5L9r6jRIaHMbhjLGP6teanHWP1iRuRaqRyl2rnnCNzewGzM3OZlZnL7oPHSYqP5s4hHbigczOVvEg1qGq5V2k+upkNNbN1ZpZlZvf/wDFXm9lqM1tlZq+damAJPGZGj4QYHri8CwvuG8KjI7uTf6SIX72SziVPfc6/l++gtNSbwYNIqKt05G5m4cB64EJgO7AEGO2cW13hmA7AG8AQ59x+M2vmnNv9Yz9XI/fgVFxSyqzMXCbNy2JT3mHaN6vP2PPakZoUp7VtRHzAlyP3PkCWc26Tc64QmAkM/94xvwImO+f2A1RW7BK8IsLDGNkzno/u+SkTR59DuBn3vJ7JBX//jDeWZFNYXOp1RJGQUJVybwVkV3i9vfy9ijoCHc1soZktMrOhvgoogSk8zEhNjuOD8T/h2Wt7Ub9OBL95aznnPfEpry7ayvHiEq8jigQ1X/2eHAF0AM4FRgP/NLOY7x9kZreYWbqZpefl5fno1OLPwsKMi7u2IG3cIF68vjfNGtbmD++uZPDj83hhwWaOFqrkRapDVco9B0io8Dq+/L2KtgOznHNFzrnNlN2j7/D9H+Scm+qcS3HOpcTGxp5uZglAZsZ5Zzfj7dsHMP3mviQ2qcdDs1fzk8c/4ZnPNnLoeLHXEUWCSlXKfQnQwczamlkkMAqY9b1j3qVs1I6ZNaXsNs0mH+aUIGFWNvnp9Vv788at/encsiGPfrCWQY99wsSPN3DgWJHXEUWCQqXl7pwrBsYBc4E1wBvOuVVm9pCZDSs/bC6w18xWA/OA/+ec21tdoSU49GnbmFdv6ss7dwygV+tG/O2j9Qx89BP+9uE69h8u9DqeSEDTJCbxGytzCpg8L4sPVu6kXmQ4Y/q34eZBZxHbQGvYiJygGaoSsNbvOsikT7KYvTyXyIgwRvdpza2D29EiWitSiqjcJeBtyjvElE838s7XOYSbcVVKPLef2474RlpfXkKXyl2CRva+I0z5dCNvZmTjHIzs2Yo7zm1PYtN6XkcTqXEqdwk6OwqO8uxnm5jx1TaKSkoZ3qMVY89rR/tmDbyOJlJjVO4StHYfPMZzn2/m1S+3cqy4hEu7tWTckPZ0btnQ62gi1U7lLkFv3+FCnl+wiZe/2Mqh48Vc2KU5dw5pT1L8f02OFgkaKncJGQVHinjpiy28sHAzBUeLGHteO+69sBNhYVpPXoKPT9dzF/Fn0VG1GH9BBxbcdx6jeicwed5Gxr++jGNFWrdGQleE1wFEfKVBnVr8ZWR32jSpx2Nz1rKz4CjPXptC43qRXkcTqXEauUtQMTNuP7cdk35+DpnbCxg5ZSGb9xz2OpZIjVO5S1C6PCmOGb/qS8HRIkZOWUj6ln1eRxKpUSp3CVq92jTmnTsGEhMVyc+fW0xaZq7XkURqjMpdglpi03q8ffsAesTHcOeMr5k8Lwuv
PiEmUpNU7hL0GtWL5NWb+zC8Rxx/nbuO+99aQVGJ9nKV4KZPy0hIqB0RzoRretC6cRQTP8kit+Aok3/Rk4Z1ankdTaRaaOQuIcPMuPeiTjz+syS+3LiXq57+kpz8o17HEqkWKncJOVenJPDyjX3IzT/KiMkLWbG9wOtIIj6ncpeQNLB9U966YwCR4WFc/eyX/Gf1Lq8jifiUyl1CVsfmDXhn7AA6NK/PLa+m89LCzV5HEvEZlbuEtGYN6jDzln6c37k5D6at5k9pqygp1UclJfCp3CXkRUVG8MyYXtw4sC0vLtzCbdMyOFJY7HUskTOichcBwsOM/03twp+GdeXjNbsYNXURuw8e8zqWyGlTuYtUcN2ARKZem8KGXYe4YvIXrN910OtIIqdF5S7yPRd0ac4bt/ansKSUK6d8wYINe7yOJHLKVO4iJ9E9Ppp3xw4kLqYu17/4FW8syfY6ksgpUbmL/IBWMXX51+396d+uCb95azlPzF2nRcckYKjcRX5Ewzq1eOH63ozqncCkeVmMn7mM48Xavk/8nxYOE6lErfAw/jKyO62bRPH4nHXsKDjK1GtTaKTt+8SPaeQuUgVmxh3ntmfi6PLt+57+gi3avk/8mMpd5BSkJsfx2s19yT9SyBVTFpKxVdv3iX9SuYucopTEb7fvG/1Pbd8n/knlLnIaTmzflxwfzZ0zvmbKp9q+T/yLyl3kNDWqF8mrN/VlWHIcj89Zx2/f1vZ94j/0aRmRM1Cn1rfb902al0VOvrbvE/9QpZG7mQ01s3VmlmVm95/k69ebWZ6ZLSv/c7Pvo4r4p7Aw49cXd+LxK7V9n/iPSsvdzMKBycAlQBdgtJl1OcmhrzvnepT/ec7HOUX83tW9E3jphm+371u6bb/XkSSEVWXk3gfIcs5tcs4VAjOB4dUbSyQwDepQtn1fnVphjHp2Ea8v2eZ1JAlRVSn3VkDFVZO2l7/3fVea2XIze9PMEk72g8zsFjNLN7P0vLy804gr4v86Nm9A2rhB9D2rMfe9tYI/vLuSwmI9aJWa5atPy6QBic65JOAj4OWTHeScm+qcS3HOpcTGxvro1CL+JyYqkhev780tg8/i1UVbGfPcYvYcOu51LAkhVSn3HKDiSDy+/L1vOOf2OudO/Mt9Dujlm3gigSsiPIzfXdqZp0b1YHlOPqkTF7B8e77XsSREVKXclwAdzKytmUUCo4BZFQ8ws5YVXg4D1vguokhgG96jFW/eNoAwM372zJe8lbHd60gSAiotd+dcMTAOmEtZab/hnFtlZg+Z2bDyw+4ys1VmlgncBVxfXYFFAlG3VtHMGjeQXq0bce+/MvlT2ipNeJJqZV5NmU5JSXHp6emenFvEK0UlpTzy/hpeXLiF/mc1YfIvetJYSwfLKTCzDOdcSmXHafkBkRpUKzyMP6Z25YmrksnYtp/UiQtYmVPgdSwJQip3EQ/8rFc8/7q1P6XO8bNnvuC9ZTmVf5PIKVC5i3gkOSGGWeMG0b1VNONnLuOR99dQrPvw4iMqdxEPxTaozfSb+3FtvzZMnb+JG15aQv6RQq9jSRBQuYt4LDIijIdHdOPRkd1ZvGkfwyYtZO3OA17HkgCnchfxE6P6tGbmrf04VlTCyClf8P6KHV5HkgCmchfxIz1bN2L2nYM4u0UD7pi+lL/OXUtJqXZ4klOnchfxM80a1mHGLf0Y1TuByfM2ctPLSyg4WuR1LAkwKncRP1Q7Ipy/jOzOn0d0Y8GGPYyYvJANuw56HUsCiMpdxE+ZGWP6tWHGLf04eKyYEZMXMnfVTq9jSYBQuYv4ud6JjUm7cyDtm9Xn1lcz+PtH6ynVfXiphMpdJAC0jK7L67f258qe8fzj4w3c8moGB4/pPrz8MJW7SICoUyucJ65K4sHULsxbt5sRkxeyMe+Q17HET6ncRQKImXH9wLZMu6kv+48UMWLSQj5Zu8vrWOKHVO4iAah/uybMGjeQ1k2iuOnldCZ+vEH34eU7VO4iASq+
URRv3jaA4clx/O2j9dwxfSmHjhd7HUv8hMpdJIDVjQznyWt68MBlnflw9U5GTlnIlj2HvY4lfkDlLhLgzIybf3IWr9zYl90HjzNs0gI+W5/ndSzxmMpdJEgM6tCUWWMHERdTlxte/IqnP92IV9toivdU7iJBpHWTKN6+YwCXdG/JY3PWcs/ry/SgNURFeB1ARHwrKjKCSaPPoUOz+kz4zwbaNq3P+As6eB1LapjKXSQImRnjz+/Atr1HmPDxepISojmvUzOvY0kN0m0ZkSBlZvzfFd3p1LwBd89cRva+I15HkhqkchcJYnUjw3n22l6UOsdt0zI4VlTidSSpISp3kSDXpkk9JlzTg1W5B3jg3ZX6BE2IULmLhIDzOzfnriHteTNjOzO+yvY6jtQAlbtIiBh/QUd+2jGWB2etYll2vtdxpJqp3EVCRHiY8dSoHjRrWJs7pmWw99BxryNJNVK5i4SQmKhInhnTiz2HC7lr5teUaIJT0FK5i4SYbq2i+fOIbizM2ssTH67zOo5UE5W7SAi6OiWB0X1a8/SnG5mzUptuByOVu0iIenBYF5Ljo/n1vzLZpO36go7KXSRE1Y4IZ8qYXtQKN26blsFhbfQRVKpU7mY21MzWmVmWmd3/I8ddaWbOzFJ8F1FEqkurmLpMHN2TrN2HuO+t5ZrgFEQqLXczCwcmA5cAXYDRZtblJMc1AMYDi30dUkSqz6AOTfn1xZ2YvXwHLy7c4nUc8ZGqjNz7AFnOuU3OuUJgJjD8JMc9DDwGHPNhPhGpAbf/tB0XdWnOI++v4avN+7yOIz5QlXJvBVScr7y9/L1vmFlPIME5928fZhORGmJmPHF1MgmNoxj72lJ2H9AYLdCd8QNVMwsD/g7cW4VjbzGzdDNLz8vTHo8i/qRhnVo8M6YXh44VM/a1pRSVlHodSc5AVco9B0io8Dq+/L0TGgDdgE/NbAvQD5h1soeqzrmpzrkU51xKbGzs6acWkWrRqUUDHr2yO0u27OeR99d4HUfOQFXKfQnQwczamlkkMAqYdeKLzrkC51xT51yicy4RWAQMc86lV0tiEalWw3u04oaBiby4cAvvLcup/BvEL1Va7s65YmAcMBdYA7zhnFtlZg+Z2bDqDigiNe93l3YmpU0j7n9rBet2HvQ6jpwG8+pzrSkpKS49XYN7EX+1+8AxLpu4gPq1I3hv3EAa1qnldSQBzCzDOVfpXCLNUBWRk2rWsA6Tf96T7H1HuPeNTEq1gmRAUbmLyA/q07Yxv7u0Mx+t3sUz8zd6HUdOgcpdRH7UDQMTSU2O44m561iwYY/XcaSKVO4i8qPMjEdHdqd9s/rcNfNrcvKPeh1JqkDlLiKVqlc7gmfG9KKwuJQ7pmVwvLjE60hSCZW7iFTJWbH1eeKqZDK3F/DgrNVex5FKqNxFpMqGdmvB7ee2Y8ZX23hjSXbl3yCeUbmLyCm598KODGzfhAfeW8nKnAKv48gPULmLyCmJCA/jH6POoUm9SG6blsH+w4VeR5KTULmLyClrUr82T4/pxe4Dxxn/+jJKNMHJ76jcReS09EiI4cFhXZm/Po+nPt7gdRz5HpW7iJy20X0SuKpXPP/4eAMfr9nldRypQOUuIqfNzHh4RDe6xjXknteXsXXvYa8jSTmVu4ickTq1wnlmTC/MjNumLeVooSY4+QOVu4icsYTGUUwY1YO1Ow/w+3dW4NVS4vItlbuI+MR5nZpx9/kdefvrHKYt2up1nJCnchcRn7lzSHvO6xTLQ7NXk7F1v9dxQprKXUR8JizMmHDNObSMrssd0zPIO3jc60ghS+UuIj4VHVWLp8f0JP9IEXfOWEpxSanXkUKSyl1EfK5rXDSPXNGdRZv28de567yOE5IivA4gIsHpyl7xLMvO59n5m+jaKpphyXFeRwopGrmLSLX5w+Vd6Nk6hvEzv+Yv76+hsFi3aGqKyl1Eqk1kRBjTb+7H6D6teXb+Jq6YspCs3Ye8jhUSVO4iUq3qRobz
yBXdmXptL3Lzj3L5xM+ZvnirJjpVM5W7iNSIi7q2YO7dg+md2Jjfv7OSX72Swd5D+qhkdVG5i0iNadawDi/f0IcHLuvM/PV5DH3qc+avz/M6VlBSuYtIjQoLM27+yVm8O3YgMXVr8csXvuLh2as5VqQFx3xJ5S4inugS15C0OwdxXf82PL9gMyMmL2T9roNexwoaKncR8UydWuH8aXg3Xrg+hT2HjpM6cQEvf7FFD1t9QOUuIp4bcnZzPhg/mP7tmvDHWau48aUlWpfmDKncRcQvxDaozYvX9+ZPw7qycONeLnlqPvPW7vY6VsBSuYuI3zAzrhuQSNq4QTStX5sbXlrCH99bqYetp0HlLiJ+p1OLBrw7diA3DmzLy19uZdikBazZccDrWAFF5S4ifqlOrXD+N7ULL9/Yh/1Hihg+aSHPL9hMaaketlZFlcrdzIaa2TozyzKz+0/y9dvMbIWZLTOzBWbWxfdRRSQU/bRjLHPG/4TBHZvy8OzVXPfiV+w+cMzrWH6v0nI3s3BgMnAJ0AUYfZLyfs0519051wN4HPi7z5OKSMhqUr82//xlCn8e0Y0lW/Yx9KnP+Wj1Lq9j+bWqjNz7AFnOuU3OuUJgJjC84gHOuYo3w+oB+r1JRHzKzBjTrw2z7xxEi4Z1+NUr6fz+nRUcLdTD1pOpSrm3ArIrvN5e/t53mNlYM9tI2cj9Lt/EExH5rvbNGvDO2AHcMvgspi/exmUTP2dlToHXsfyOzx6oOucmO+faAfcBD5zsGDO7xczSzSw9L0+LBYnI6akdEc7vLu3MtJv6cvh4MVdMWcjU+Rv1sLWCqpR7DpBQ4XV8+Xs/ZCYw4mRfcM5Ndc6lOOdSYmNjq55SROQkBnVoypzxgxlydjMeeX8tY55fzM4CPWyFqpX7EqCDmbU1s0hgFDCr4gFm1qHCy8uADb6LKCLywxrVi+SZMb14dGR3vt6Wz8UT5jNn5Q6vY3mu0nJ3zhUD44C5wBrgDefcKjN7yMyGlR82zsxWmdky4H+A66otsYjI95gZo/q05t93DaJ14yhum7aU+95czuHjxV5H84x5tfpaSkqKS09P9+TcIhK8CotLefI/63nms40kNqnHhGt6kJwQ43UsnzGzDOdcSmXHaYaqiASVyIgw7ht6Nq/d3I9jRSVc+fQXTJ6XRUmIPWxVuYtIUOrfrglzxg/m4q4t+OvcdYz+5yJy8o96HavGqNxFJGhFR9Vi0s/P4a8/S2JVTgEXPzmfGV9tC4nNQFTuIhLUzIyrUhKYc/dgurVqyG/fXsEvX/gq6EfxKncRCQkJjaN47eZ+PDyiGxlb93Pxk/N5bXHwjuJV7iISMsLCjGv7tWHu3YPp3iqa372zgmuf/4rt+494Hc3nVO4iEnISGkcx/ea+PDyiG0u37WfohM+DbhSvcheRkFRxFJ8UH3yjeJW7iIS0hMZRTLupL38e0Y2vt5Xdi5++eGvAj+JV7iIS8sLCytaKn3P3YHq0juH376xkzPOLA3oUr3IXESl3YhT/f1d0Y9m2fC5+cj7TFgXmKF7lLiJSgZnxi77fjuIfeHclv3huMdn7AmsUr3IXETmJiqP4zOx8hk4oG8UHyoYgKncRkR9wYhQ/957BnNO6EQ+8W3YvPhBG8Sp3EZFKxDeK4tWb+vCXkd1Zvr2AiyfM51U/H8Wr3EVEqsDMGN2nNXPvGUyvNo34g5/fi1e5i4icglYxdXnlxrJR/Iqc8lH8l1v8bhSvchcROUX/NYp/b5XfjeJV7iIip+nEKP5RPxzFq9xFRM7Aic25K47if/7cIrbt9XYUr3IXEfGBE6P4x67szqqcAwx9aj6vfOndKF7lLiLiI2bGNb3LRvEpiY353/dWMfqf3oziVe4iIj4WF1OXl2/ozWNXdmd17gEunjCfl7+o2VG8yl1EpBpUHMX3aduYP84qG8Vv3Xu4Rs6vchcRqUZxMXV56YbePH5lEqtzDzB0wuekZeZW+3lV7iIi1czMuLp3
Ah/+z2AGtm9C26b1qv2cEdV+BhERAaBldF2eu653jZxLI3cRkSCkchcRCUIqdxGRIKRyFxEJQip3EZEgpHIXEQlCKncRkSCkchcRCULmnDfLUZpZHrD1NL+9KbDHh3ECna7Hd+l6fEvX4ruC4Xq0cc7FVnaQZ+V+Jsws3TmX4nUOf6Hr8V26Ht/StfiuULoeui0jIhKEVO4iIkEoUMt9qtcB/Iyux3fpenxL1+K7QuZ6BOQ9dxER+XGBOnIXEZEfEXDlbmZDzWydmWWZ2f1e5/GKmSWY2TwzW21mq8xsvNeZ/IGZhZvZ12Y22+ssXjOzGDN708zWmtkaM+vvdSavmNk95f9PVprZDDOr43Wm6hZQ5W5m4cBk4BKgCzDazLp4m8ozxcC9zrkuQD9gbAhfi4rGA2u8DuEnngLmOOfOBpIJ0etiZq2Au4AU51w3IBwY5W2q6hdQ5Q70AbKcc5ucc4XATGC4x5k84Zzb4ZxbWv73g5T9x23lbSpvmVk8cBnwnNdZvGZm0cBg4HkA51yhcy7f21SeigDqmlkEEAVU/yamHgu0cm8FZFd4vZ0QLzQAM0sEzgEWe5vEcxOA3wClXgfxA22BPODF8ttUz5lZ9W/c6YeccznAE8A2YAdQ4Jz70NtU1S/Qyl2+x8zqA28BdzvnDnidxytmdjmw2zmX4XUWPxEB9ASeds6dAxwGQvIZlZk1ouw3/LZAHFDPzMZ4m6r6BVq55wAJFV7Hl78XksysFmXFPt0597bXeTw2EBhmZlsou103xMymeRvJU9uB7c65E7/NvUlZ2YeiC4DNzrk851wR8DYwwONM1S7Qyn0J0MHM2ppZJGUPRWZ5nMkTZmaU3U9d45z7u9d5vOac+61zLt45l0jZv4tPnHNBPzr7Ic65nUC2mXUqf+t8YLWHkby0DehnZlHl/2/OJwQeLkd4HeBUOOeKzWwcMJeyJ94vOOdWeRzLKwOBa4EVZras/L3fOefe9zCT+Jc7genlA6FNwA0e5/GEc26xmb0JLKXsU2ZfEwIzVTVDVUQkCAXabRkREakClbuISBBSuYuIBCGVu4hIEFK5i4gEIZW7iEgQUrmLiAQhlbuISBD6/yzdG3f0PeNDAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, ax = plt.subplots()\n",
    "# x-axis is 1-based to match the epoch numbers printed during training\n",
    "ax.plot(range(1, len(tr_loss_hist) + 1), tr_loss_hist, label = 'train')\n",
    "ax.set(title = 'Training loss per epoch', xlabel = 'epoch', ylabel = 'mean cross-entropy loss')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# predict() feeds keep_prob = 1. by default, so dropout is off at inference\n",
    "yhat = char_stacked_lstm.predict(sess, X_length, X_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training acc: 83.33%\n"
     ]
    }
   ],
   "source": [
    "y_true = np.argmax(y, axis = -1)  # class index of each one-hot label\n",
    "tr_acc = np.mean(yhat == y_true)\n",
    "print('training acc: {:.2%}'.format(tr_acc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}