{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 11 : Recurrent Neural Networks\n",
    "A simple example of many-to-one classification (word sentiment classification) with a Gated Recurrent Unit (GRU).\n",
    "\n",
    "### Many to One Classification by GRU\n",
    "- Creating the **data pipeline** with `tf.data`\n",
    "- Preprocessing word sequences of variable length by padding them with a user-defined function (`pad_seq`)\n",
    "- Using `tf.nn.embedding_lookup` to get the vector of each token (e.g. word, character)\n",
    "- Creating the model as a **class**\n",
    "- Reference\n",
    "    - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n",
    "    - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.8.0\n"
     ]
    }
   ],
   "source": [
    "import os, sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import string\n",
    "%matplotlib inline\n",
    "\n",
    "slim = tf.contrib.slim\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare example data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy dataset: each entry of `words` is paired with the one-hot label at the same position in `y`.\n",
    "# Judging by the words, [1., 0.] marks positive and [0., 1.] marks negative sentiment.\n",
    "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n",
    "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'abcdefghijklmnopqrstuvwxyz *'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Character quantization: the vocabulary is the 26 lowercase letters, a space and '*'\n",
    "# ('*' is the token pad_seq uses to fill sequences up to a fixed length)\n",
    "char_space = string.ascii_lowercase \n",
    "char_space = char_space + ' ' + '*'\n",
    "char_space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n"
     ]
    }
   ],
   "source": [
    "# Map every character to its position in char_space; that position is the character's integer id\n",
    "char_dic = {char : idx for idx, char in enumerate(char_space)}\n",
    "print(char_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create pad_seq function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_seq(sequences, max_len, dic):\n",
    "    \"\"\"Encode token sequences as fixed-length rows of vocabulary indices.\n",
    "\n",
    "    Each sequence is cut to at most `max_len` tokens, mapped through `dic`,\n",
    "    and right-padded with the index of the padding token '*'. Tokens missing\n",
    "    from `dic` are mapped to the padding index rather than to None.\n",
    "\n",
    "    Args:\n",
    "        sequences: iterable of sliceable token sequences (e.g. strings).\n",
    "        max_len: number of indices in every returned row.\n",
    "        dic: mapping from token to integer index; must contain the key '*'.\n",
    "\n",
    "    Returns:\n",
    "        (seq_len, seq_indices): the number of real (unpadded) tokens kept for\n",
    "        each sequence, and the padded index rows, each of length `max_len`.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if `dic` has no padding token '*'.\n",
    "    \"\"\"\n",
    "    pad_idx = dic['*'] # idx of meaningless token \"*\" (27 in char_dic)\n",
    "    seq_len, seq_indices = [], []\n",
    "    for seq in sequences:\n",
    "        # Truncate first so that no row is longer than max_len and the reported\n",
    "        # length never exceeds the padded width\n",
    "        seq_idx = [dic.get(char, pad_idx) for char in seq[:max_len]]\n",
    "        seq_len.append(len(seq_idx))\n",
    "        seq_idx += (max_len - len(seq_idx)) * [pad_idx]\n",
    "        seq_indices.append(seq_idx)\n",
    "    return seq_len, seq_indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Apply pad_seq function to data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pad every word to max_length characters (the longest word here, 'bull shit', has 9)\n",
    "# X_length receives the first value returned by pad_seq, X_indices the second\n",
    "max_length = 10\n",
    "X_length, X_indices = pad_seq(sequences = words, max_len = max_length, dic = char_dic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4, 3, 7, 7, 9, 7]\n",
      "(6, 10)\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: the sequence lengths and the (num_words, max_length) shape of the padded matrix\n",
    "print(X_length)\n",
    "print(np.shape(X_indices))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define CharGRU class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharGRU:\n",
    "    \"\"\"Many-to-one GRU classifier over sequences of token indices (TensorFlow graph mode).\n",
    "\n",
    "    Constructing an instance builds the input, GRU, output, loss and prediction\n",
    "    ops. The cross-entropy loss is exposed as `ce_loss`.\n",
    "    \"\"\"\n",
    "    def __init__(self, X_length, X_indices, y, n_of_classes, hidden_dim, dic):\n",
    "        \"\"\"Build the model graph on top of the given input tensors.\n",
    "\n",
    "        Args:\n",
    "            X_length: tensor of sequence lengths, passed to dynamic_rnn as\n",
    "                `sequence_length`.\n",
    "            X_indices: tensor of padded token indices, used as embedding ids.\n",
    "            y: one-hot label tensor used by the cross-entropy loss.\n",
    "            n_of_classes: number of output logits.\n",
    "            hidden_dim: number of GRU units.\n",
    "            dic: token vocabulary; only its size (`len(dic)`) is used here.\n",
    "        \"\"\"\n",
    "        \n",
    "        # data pipeline\n",
    "        with tf.variable_scope('input_layer'):\n",
    "            self._X_length = X_length\n",
    "            self._X_indices = X_indices\n",
    "            self._y = y\n",
    "            \n",
    "            # Identity matrix used as a fixed one-hot embedding table: row i is the one-hot vector of token i\n",
    "            one_hot = tf.eye(len(dic), dtype = tf.float32)\n",
    "            self._one_hot = tf.get_variable(name='one_hot_embedding', initializer = one_hot,\n",
    "                                            trainable = False) # not trainable: the embedding vectors are not going to be learned\n",
    "            self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n",
    "        \n",
    "        # GRU cell\n",
    "        with tf.variable_scope('gru_cell'):\n",
    "            gru_cell = tf.contrib.rnn.GRUCell(num_units = hidden_dim, activation = tf.nn.tanh)\n",
    "            # Many-to-one: the per-step outputs are discarded and only the final state is kept\n",
    "            _, state = tf.nn.dynamic_rnn(cell = gru_cell, inputs = self._X_batch,\n",
    "                                         sequence_length = self._X_length, dtype = tf.float32)\n",
    "            \n",
    "        with tf.variable_scope('output_layer'):\n",
    "            # Linear layer (no activation) producing one logit per class\n",
    "            self._score = slim.fully_connected(inputs = state, num_outputs = n_of_classes, activation_fn = None)\n",
    "            \n",
    "        with tf.variable_scope('loss'):\n",
    "            self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n",
    "            \n",
    "        with tf.variable_scope('prediction'):\n",
    "            self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n",
    "    \n",
    "    def predict(self, sess, X_length, X_indices):\n",
    "        \"\"\"Return the predicted class index for each given sequence.\n",
    "\n",
    "        The supplied values are fed in place of the input tensors the model was\n",
    "        built on; only the prediction op is run, so no labels are passed.\n",
    "\n",
    "        Args:\n",
    "            sess: session in which to run the prediction op.\n",
    "            X_length: sequence lengths to feed for `self._X_length`.\n",
    "            X_indices: padded token indices to feed for `self._X_indices`.\n",
    "\n",
    "        Returns:\n",
    "            The result of running the argmax-over-logits op (int32 class ids).\n",
    "        \"\"\"\n",
    "        feed_prediction = {self._X_length : X_length, self._X_indices : X_indices}\n",
    "        return sess.run(self._prediction, feed_dict = feed_prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a model of CharGRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    }
   ],
   "source": [
    "# hyper-parameters\n",
    "lr = .003\n",
    "epochs = 10\n",
    "batch_size = 2\n",
    "# Number of full mini-batches per epoch (6 examples / 2 = 3); it is only printed here,\n",
    "# the training loop below instead runs until the iterator is exhausted\n",
    "total_step = int(np.shape(X_indices)[0] / batch_size)\n",
    "print(total_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n"
     ]
    }
   ],
   "source": [
    "## create data pipeline with tf.data\n",
    "# Slice the three parallel lists into per-example tuples, shuffle them\n",
    "# (the buffer of 20 is larger than the 6 examples), then group them into mini-batches\n",
    "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n",
    "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n",
    "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n",
    "# Initializable iterator: the training loop re-runs its initializer at the start of every epoch\n",
    "tr_iterator = tr_dataset.make_initializable_iterator()\n",
    "print(tr_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Symbolic tensors for the iterator's next mini-batch (lengths, indices, labels)\n",
    "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model directly on the iterator's output tensors\n",
    "char_gru = CharGRU(X_length = X_length_mb, X_indices = X_indices_mb, y = y_mb, n_of_classes = 2,\n",
    "                   hidden_dim = 16, dic = char_dic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create training op and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "## create training op\n",
    "# Adam minimizes the softmax cross-entropy loss exposed by the model\n",
    "opt = tf.train.AdamOptimizer(learning_rate = lr)\n",
    "training_op = opt.minimize(loss = char_gru.ce_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   1, tr_loss : 0.711\n",
      "epoch :   2, tr_loss : 0.673\n",
      "epoch :   3, tr_loss : 0.633\n",
      "epoch :   4, tr_loss : 0.599\n",
      "epoch :   5, tr_loss : 0.573\n",
      "epoch :   6, tr_loss : 0.538\n",
      "epoch :   7, tr_loss : 0.514\n",
      "epoch :   8, tr_loss : 0.479\n",
      "epoch :   9, tr_loss : 0.455\n",
      "epoch :  10, tr_loss : 0.424\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): this session is never closed in the visible cells; it is reused below for prediction\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "# Mean training loss of every epoch, plotted afterwards\n",
    "tr_loss_hist = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    avg_tr_loss = 0\n",
    "    tr_step = 0\n",
    "    \n",
    "    # Restart the dataset iterator at the beginning of each epoch\n",
    "    sess.run(tr_iterator.initializer)\n",
    "    try:\n",
    "        while True:\n",
    "            # One optimization step on the next mini-batch; its loss is fetched in the same run\n",
    "            _, tr_loss = sess.run(fetches = [training_op, char_gru.ce_loss])\n",
    "            avg_tr_loss += tr_loss\n",
    "            tr_step += 1\n",
    "            \n",
    "    # Raised once the iterator has no more batches, i.e. the epoch is finished\n",
    "    except tf.errors.OutOfRangeError:\n",
    "        pass\n",
    "    \n",
    "    # Turn the summed batch losses into the mean over the batches seen this epoch\n",
    "    avg_tr_loss /= tr_step\n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "    \n",
    "    print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x111e92d68>]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xd8VHW+//HXZ1LpNaAQOgFF6aEFAoKCqEgRQcACllU6qHf3uvfevbs/9bqu7iIgKILo2hERFUFBFKS30ER6CC3U0Dsk8P39kdndyCIMkOQkM+/n4zEPcr5zTuadecA7h3O+c4455xARkdDg8zqAiIjkHpW+iEgIUemLiIQQlb6ISAhR6YuIhBCVvohICFHpi4iEEJW+iEgIUemLiISQcK8DXKx06dKucuXKXscQEclXli9ffsA5F3Ol9fJc6VeuXJmkpCSvY4iI5Ctmtj2Q9XR4R0QkhKj0RURCiEpfRCSEqPRFREKISl9EJISo9EVEQohKX0QkhARN6V+44Hjpm/XsOHjK6ygiInlW0JT+toMnmbB0Bx1en8esDfu8jiMikicFTelXjSnM1EGJxJYoyGN/T2LYdxs5f0E3fRcRySpoSh+gYqmCTO6fQLeGsYyclUyfd5dy6OQ5r2OJiOQZQVX6ANERYbzarS4v31ebJVsPce/r81m984jXsURE8oSASt/M2pvZRjNLNrPnLvH8a2a2yv/YZGZHsjzX28w2+x+9szP85fRoXJFJfZsB0G3MIj5ash3ndLhHREKbXakIzSwM2AS0BVKBZUBP59y6X1l/EFDfOfeYmZUEkoB4wAHLgYbOucO/9nrx8fEuO6+yefjkOYZ+uoo5m9Lo2iCWFzvfSoHIsGz7/iIieYGZLXfOxV9pvUD29BsDyc65FOfcOWAC0Oky6/cEPvF/fScw0zl3yF/0M4H2AbxmtilRKJJ3+zRiyO1xTF6Zyn1vLmT7wZO5GUFEJM8IpPTLAzuzLKf6x/6NmVUCqgCzrnbbnOTzGU+3rcE7fRqx+8hpOrw+n+/XaVqniISe7D6R2wOY5Jw7fzUbmdmTZpZkZklpaWnZHOlfWtcsw9RBLahUqiBPvJ/EqzM2aFqniISUQEp/F1Ahy3Ksf+xSevCvQzsBb+ucG+uci3fOxcfEXPFuX9elQsmCTOqbQI9GFRg9ewu931nKwRNnc/Q1RUTyikBKfxkQZ2ZVzCySzGKfcvFKZnYTUAJYlGV4BtDOzEqYWQmgnX/MU9ERYbzctQ6vdK3D0m2H6PD6fFbu+NVzyyIiQeOKpe+cywAGklnW64GJzrm1Zva8mXXMsmoPYILLMh3IOXcIeIHMXxzLgOf9Y3lC90YVmNwvgTCf0f2tRXywaJumdYpIULvilM3clt1TNgNx5NQ5nv50FbM3ptGlfnle6lJb0zpFJF/JzimbQa94wUjG927EM21r8OWqXXR5YwFbD2hap4gEH5W+n89nDL49jr8/2pi9x87Q8fX5fLd2r9exRESylUr/Iq1qxDB1UAuqxBTiyQ+W8/K3G8g4f8HrWCIi2UKlfwmxJQryWd9m9GpSkTFztvDw+KUc0LROEQkCKv1fERUexktdavPq/XVYseMwHUbOZ/l2TesUkfxNpX8F3eIrMLl/ApHhPh54axF/X7BV0zpFJN9S6QfglnLF+HpgC1rViOFPX69jyIRVnDqX4XUsEZGrptIPULGCEYx7JJ7f3lmTqT/tpvPoBaSknfA6lojIVVHpXwWfzxjQujrvP9aEAyfO0XHUAqb/vMfrWCIiAVPpX4MWcaWZOqgF1coUpu+HK3jpm/Wa1iki+YJK/xqVK16AiU815aGmFRk7N4UH317C/uNnvI4lInJZKv3rEBUexoudazOse11Wpx6hw8j5JG3LM9eTExH5Nyr9bHBfg1i+6N+cgpFh9Bi7mPHzNa1TRPImlX42ufnGonw1sAWtbyrDC1PX
8czE1ZxJv6obiImI5DiVfjYqViCCtx5qyLNta/DFyl10G7OI3UdOex1LROSfVPrZzOczBt0ex7hH4tl64CQdR+k4v4jkHSr9HNK2Vlm+6J9A4ahweo5bzCdLd3gdSUREpZ+T4soW4asBLWhWrTS/n7yG//lyDecyNJ9fRLyj0s9hxQpG8G6fRjzVsiofLt7BQ+OX6DLNIuIZlX4uCPMZv7/7ZoY/UI/VO4/QadQCft511OtYIhKCVPq5qHP98kzqm8AF57h/zEKmrN7tdSQRCTEq/VxWO7YYUwa2oHb5Ygz+ZCV/mb6B8xf0QS4RyR0qfQ/EFInioyea0qtJRd78cQuPv7eMo6fTvY4lIiFApe+RyHAfL3WpzYudb2X+5gN0Gb2A5P26Pr+I5CyVvscealqJj3/TlKOn0+kyegE/rN/ndSQRCWIq/TygcZWSTBnUgkqlC/LE+0mMnp2sC7aJSI4IqPTNrL2ZbTSzZDN77lfW6W5m68xsrZl9nGX8vJmt8j+mZFfwYFO+eAE+eyqBe+uU49UZGxn4yUrdh1dEsl34lVYwszBgNNAWSAWWmdkU59y6LOvEAb8HmjvnDptZmSzf4rRzrl425w5KBSLDGNGjHreUK8rL0zeQknaSsQ83pELJgl5HE5EgEciefmMg2TmX4pw7B0wAOl20zm+A0c65wwDOuf3ZGzN0mBlPtarGu30akXr4FB1HzWfRloNexxKRIBFI6ZcHdmZZTvWPZVUDqGFmC8xssZm1z/JctJkl+cc7X+oFzOxJ/zpJaWlpV/UDBKvbapbhqwHNKVkokofGL+G9hdt0nF9Erlt2ncgNB+KA24CewDgzK+5/rpJzLh7oBQw3s2oXb+ycG+uci3fOxcfExGRTpPyvakxhvhzQnNY1Y/jjlLU89/kazmboxiwicu0CKf1dQIUsy7H+saxSgSnOuXTn3FZgE5m/BHDO7fL/mQL8CNS/zswhpUh0BGMfjmdQm+p8mrSTnmMXs/+YbsAuItcmkNJfBsSZWRUziwR6ABfPwvmSzL18zKw0mYd7UsyshJlFZRlvDqxDrorPZzzbriZvPNiA9XuOc++o+azaecTrWCKSD12x9J1zGcBAYAawHpjonFtrZs+bWUf/ajOAg2a2DpgN/NY5dxC4GUgys9X+8ZezzvqRq3N37Rv5vF8CEWE+ur+1iM+Xp3odSUTyGctrJwfj4+NdUlKS1zHytEMnzzHgoxUsSjnI4y2q8Pu7biI8TJ+zEwllZrbcf/70stQU+VDJQpG8/3hj+iRUZvz8rfR5dxlHTp3zOpaI5AMq/XwqIszHnzrewiv312Hp1kN0HLWAjXuPex1LRPI4lX4+1z2+AhOeasqZ9PN0eWMB03/e63UkEcnDVPpBoEHFEnw9qAU1yhah74fLeW3mJi7oxiwicgkq/SBRtmg0E55sStcGsYz4YTN9P1zOibO6YJuI/JJKP4hER4Tx1251+N8Otfhhw366jF7A+j3HvI4lInmISj/ImBmPtajC+4815vCpdDqNWsC4uSk63CMigEo/aDWvXpoZQxO5rWYM//fNeh58ewm7j5z2OpaIeEylH8RKFY7irYcb8peutVmdeoT2w+cyZfVur2OJiIdU+kHOzHigUUW+HZJI9TKFGfzJSoZMWMnR0+leRxMRD6j0Q0SlUoWY+FQznmlbg6k/7eGu4XNZuOWA17FEJJep9ENIeJiPwbfH8Xm/BKIiwnjw7SW89M16XaNfJISo9ENQvQrFmTa4Bb0aV2Ts3BQ66RIOIiFDpR+iCkaG839dajO+dzwHTpzl3lHzGT9/q6Z2igQ5lX6Iu/3mskwf2pKWcaV5Yeo6Hn5nCXuOamqnSLBS6QulC0cx7pF4/nxfbVZsP0L74fOY+pOmdooEI5W+AJlTO3s2rsg3QxKpUroQAz9eydOfruLYGU3tFAkmKn35hSqlCzGpbzOG3hHHlNW7uWv4PJakHPQ6lohkE5W+/JvwMB9D76jBpL7N
iAgzeoxbzMvfbuBcxgWvo4nIdVLpy6+qX7EE0wYn0qNRBcbM2ULn0QvYvE9TO0XyM5W+XFahqHD+fF8dxj0Sz75jZ+jw+nzeXaCpnSL5lUpfAtK2VubUzubVS/P/vl5H73eXsu/YGa9jichVUulLwGKKRDG+dzwvdr6VZdsOcefwuXyzZo/XsUTkKqj05aqYGQ81rcS0wYlULFmQ/h+t4NmJqzmuqZ0i+YJKX65JtZjCfN4vgcFtqvPFylTuGjGPpVsPeR1LRK4goNI3s/ZmttHMks3suV9Zp7uZrTOztWb2cZbx3ma22f/onV3BxXsRYT6eaVeTz/o2w2fGA2MX8cp0Te0UycvMucvPwjCzMGAT0BZIBZYBPZ1z67KsEwdMBNo45w6bWRnn3H4zKwkkAfGAA5YDDZ1zh3/t9eLj411SUtJ1/liS206czeD5r9cyMSmVW8sXZfgD9ahepojXsURChpktd87FX2m9QPb0GwPJzrkU59w5YALQ6aJ1fgOM/keZO+f2+8fvBGY65w75n5sJtA/0h5D8o3BUOK/cX5e3Hm7IrsOnuWfkfN5ftI0r7VSISO4KpPTLAzuzLKf6x7KqAdQwswVmttjM2l/FthJE7rzlBmYMbUnTqqX436/W0ufdZezX1E6RPCO7TuSGA3HAbUBPYJyZFQ90YzN70sySzCwpLS0tmyKJV8oUjebvjzbihU63sDjlIG3+NofXZm7SxdtE8oBASn8XUCHLcqx/LKtUYIpzLt05t5XMcwBxAW6Lc26scy7eORcfExNzNfkljzIzHm5WmW+HJNKyRmlG/LCZxL/M5o0fkzl1LsPreCIhK5DSXwbEmVkVM4sEegBTLlrnSzL38jGz0mQe7kkBZgDtzKyEmZUA2vnHJERUjSnMGw82ZOqgFsRXKsEr0zfS8pXZvD0vhTPpujevSG67Yuk75zKAgWSW9XpgonNurZk9b2Yd/avNAA6a2TpgNvBb59xB59wh4AUyf3EsA573j0mIubV8Mcb3acTk/gncdENRXpy2nlavzuaDxds1xVMkF11xymZu05TN0LBoy0GGzdzIsm2HKV+8AEPuiOO++uUJD9PnBUWuRXZO2RTJds2qlWLiU81477HGlCocye8m/UTb1+by1apdnNcVPEVyjEpfPGNmtKoRw1cDmjPukXiiwn0MmbCKu0bMZfrPezTHXyQHqPTFc2ZG21pl+WZwIqN61SfjgqPvhyu4d9R8Zm/Yr/IXyUYqfckzfD6jQ51yfDe0JX/rVpejp9N59O/L6PrmQhYmH/A6nkhQ0IlcybPSz1/gs6RUXp+1mT1Hz9C0akn+o11N4iuX9DqaSJ4T6Ilclb7keWfSzzNh6Q5Gzd7CgRNnaVUjhmfb1aBObMAf+hYJeip9CTqnz53n/UXbGDNnC4dPpdOuVlmeaVeDm24o6nU0Ec+p9CVonTibwbvztzJ2XgonzmbQoU45ht4RR7WYwl5HE/GMSl+C3tFT6Yybl8I7C7ZyJv08XerHMuT2OCqWKuh1NJFcp9KXkHHwxFnGzNnC+4u2c/6Co3ujCgxqU50bixXwOppIrlHpS8jZd+wMb8xO5uOlOzAzHmxSkX63VaNMkWivo4nkOJW+hKzUw6cYNSuZz5anEhnm45GESvRtWY0ShSK9jiaSY1T6EvK2HTjJiB828+WqXRSKDGdQm+r8JrEqPp95HU0k2+mCaxLyKpcuxGsP1OM7/+0b//ztBh5+Zwn7dPtGCWEqfQl6cWWLMO6RhrzStQ4rth/hrhHz+GH9Pq9jiXhCpS8hwczo3qgCXw9qQdmi0Tz+XhJ/mrKWsxm6e5eEFpW+hJTqZQrzRf8EHm1emb8v3EaX0QtJ3n/C61giuUalLyEnOiKMP957C+N7x7P32BnufX0+ny7boUs4S0hQ6UvIuv3msnw7JJH6FYvzn5+vYdAnKzl6Ot3rWCI5SqUvIa1s0Wg+eLwJv2tfk29/3svdI+axfPshr2OJ5BiVvoS8
MJ/R/7bqfNa3GT4fdH9rMaNmbda9eiUoqfRF/BpULMG0wYncXftG/vrdJh56ewl7j2pOvwQXlb5IFkWjIxjZox6v3l+HVTuPcNeIuXy/TnP6JXio9EUuYmZ0i6/A1MEtKFe8AE+8n8Qfv/qZM+ma0y/5n0pf5FdUiynM5P4JPN6iCu8t2k7n0QtI3n/c61gi10WlL3IZUeFh/KFDLd7t04i042fp8Pp8PlmqOf2SfwVU+mbW3sw2mlmymT13ief7mFmama3yP57I8tz5LONTsjO8SG5pfVMZvh2SSHylkvx+8hoGfLyCo6c0p1/yn/ArrWBmYcBooC2QCiwzsynOuXUXrfqpc27gJb7FaedcveuPKuKtMkWjef+xxoydl8JfZ2xk9c55jOxZj4aVSnodTSRggezpNwaSnXMpzrlzwASgU87GEsmbfD6jb6tqTOqXQJjP6P7WYkb+oDn9kn8EUvrlgZ1ZllP9YxframY/mdkkM6uQZTzazJLMbLGZdb7UC5jZk/51ktLS0gJPL+KRehWKM21wCzrUuZFhMzfRa9xi9hw97XUskSvKrhO5XwOVnXN1gJnAe1meq+S/m0svYLiZVbt4Y+fcWOdcvHMuPiYmJpsiieSsItERDH+gHn/rVpc1u45y14h5zFi71+tYIpcVSOnvArLuucf6x/7JOXfQOXfWv/g20DDLc7v8f6YAPwL1ryOvSJ5iZnRtGMvUQS2ILVGApz5Yzh++1Jx+ybsCKf1lQJyZVTGzSKAH8ItZOGZ2Y5bFjsB6/3gJM4vyf10aaA5cfAJYJN+rGlOYyf2a85vEKnyweDudRi1g0z7N6Ze854ql75zLAAYCM8gs84nOubVm9ryZdfSvNtjM1prZamAw0Mc/fjOQ5B+fDbx8iVk/IkEhMtzHf99Ti78/2oiDJ89y7+vz+WjJds3plzzF8tpfyPj4eJeUlOR1DJHrsv/4GZ6duJp5mw/Q/pYbeLlrbYoXjPQ6lgQxM1vuP396WfpErkgOKFMkmvcebcx/3X0T36/fx90j5rF0q67TL95T6YvkEJ/PeLJlNT7vl0BEuI8eYxcx/PtNZJy/4HU0CWEqfZEcVrdCcaYNTqRzvfIM/34z7YbP5YuVqSp/8YRKXyQXFI4KZ9gD9Xjr4YZEhvl4+tPV3DFsDp8l7SRd5S+5SCdyRXLZhQuO79fvY+Sszfy86xixJQowoHV1ujaIJTJc+2FybQI9kavSF/GIc47ZG/cz4odkVu88Qrli0fRrXZ3u8bFEhYd5HU/yGZW+SD7hnGPu5gOM+H4TK3Yc4Yai0fRtVZUejSsSHaHyl8Co9EXyGeccC7ccZMQPm1m69RAxRaJ4qmVVHmxSiQKRKn+5PJW+SD62OOUgI3/YzMItByldOJInEqvycNNKFIq64i0wJESp9EWCQNK2Q4yclczcTWmUKBjBE4lVeaRZJYpER3gdTfIYlb5IEFm54zCvz0pm1ob9FI0O5/EWVenTvDLFCqj8JZNKXyQIrUk9yshZm5m5bh9FosJ5tHllHmtRRdf1EZW+SDBbu/soo2Yl8+3PeykUGUbvhMo8kViVkoVU/qFKpS8SAjbuPc6o2clM/Wk3BSLCeKhpJX6TWJWYIlFeR5NcptIXCSHJ+48zevYWvlq1i8hwH70aV+KpVlUpWzTa62iSS1T6IiFo64GTjJ6dzBcrdxHmM3o2qkDf26pxY7ECXkeTHKbSFwlhOw6e4o0fk5m0PBWfGd3iY+l3WzViSxT0OprkEJW+iJB6+BRv/riFiUk7cQ66Noilf+tqVCpVyOtoks1U+iLyT7uPnOatOVv4ZNlOzl9wdK5Xnv9sX5MyOuYfNFT6IvJv9h87w1tzU/hw8XYKRYXz5/tqc+ctN3gdS7KB7pErIv+mTNFo/tChFtMGJ1KueDRPfbCc5z7/iZNnM7yOJrlEpS8SgqqXKczkfs0Z0Loanybt5O6R81ix47DXsSQXqPRFQlRkuI/f3nkTnz7ZjIzz
jm5jFvHaTN24Pdip9EVCXOMqJfl2aCKd6pVjxA+buX/MIrYdOOl1LMkhKn0RoWh0BMO612NUr/psPXCSu0fOY8LSHeS1iR5y/QIqfTNrb2YbzSzZzJ67xPN9zCzNzFb5H09kea63mW32P3pnZ3gRyV4d6pRj+tBE6lcsznOT1/DkB8s5eOKs17EkG12x9M0sDBgN3AXUAnqaWa1LrPqpc66e//G2f9uSwB+BJkBj4I9mViLb0otItruxWAE+eKwJ/3PPzczZmMadw+cxe+N+r2NJNglkT78xkOycS3HOnQMmAJ0C/P53AjOdc4ecc4eBmUD7a4sqIrnF5zOeSKzKVwObU6pQJI++u4z//epnTp8773U0uU6BlH55YGeW5VT/2MW6mtlPZjbJzCpc5bYikgfdfGNRvhrYnMdbVOH9Rdu5d9R8ft511OtYch2y60Tu10Bl51wdMvfm37uajc3sSTNLMrOktLS0bIokItkhOiKMP3SoxYePN+H4mXS6vLGAN3/cwvkLOsmbHwVS+ruAClmWY/1j/+ScO+ic+8fZnreBhoFu699+rHMu3jkXHxMTE2h2EclFLeJKM2NoS9rWKstfpm+g57jFpB4+5XUsuUqBlP4yIM7MqphZJNADmJJ1BTO7MctiR2C9/+sZQDszK+E/gdvOPyYi+VDxgpGM7tWAv3Wry7rdx7hr+Dy+XLlLUzvzkSuWvnMuAxhIZlmvByY659aa2fNm1tG/2mAzW2tmq4HBQB//toeAF8j8xbEMeN4/JiL5lJnRtWEs3w5JpOYNRRj66SoGT1jF0VPpXkeTAOgqmyJyzc5fcIyZs4XXZm4ipkgUf+tel4Rqpb2OFZJ0lU0RyXFhPmNA6+pM7p9AgYgwHnx7CS99s56zGZramVep9EXkutWJLc7UwS14sElFxs5NodOoBWzce9zrWHIJKn0RyRYFI8N5sXNt3ukTz4ETZ7l31Hzemb+VC5ramaeo9EUkW7W5qSzTh7akZVxpnp+6jt7vLmXfsTNexxI/lb6IZLvShaMY90g8L3WpTdK2w9w5fC7frNnjdSxBpS8iOcTM6NWkItMGt6BSyYL0/2gF//HZao6f0dROL6n0RSRHVY0pzKR+CQxuU53JK1K5e+Q8krbp4zpeUemLSI6LCPPxTLuafNa3GQDd31rE377bSLpuzZjrVPoikmsaVirJN4MT6dogltdnJXPfGwuZv/mALuOQi1T6IpKrikRH8Gq3urz5YAP2Hz/DQ+OX0PmNhcxct0/lnwt0GQYR8czZjPN8vnwXb85JZueh09x0QxEGtK7O3bVvJMxnXsfLVwK9DINKX0Q8l3H+AlNW72b07GS2pJ2kaulC9L2tGl3qlyciTAckAqHSF5F858IFx/S1exk1K5l1e45RvngB+raqSrf4CkRHhHkdL09T6YtIvuWc48eNabw+azMrdhwhpkgUTyZWpVeTihSKCvc6Xp6k0heRfM85x6KUg4yencyC5IMULxjBY82r0DuhMsUKRHgdL09R6YtIUFmx4zCjZyXzw4b9FIkK5+FmlXi8RRVKFY7yOlqeoNIXkaC0bvcxRv+YzDdr9hAV7qNX40o82bIqNxSL9jqap1T6IhLUkvef4M0ft/Dlql2E+W/h2K9VNSqWKuh1NE+o9EUkJOw8dIoxc7bwWVIq552jU91y9G9djeplingdLVep9EUkpOw7doZxc1P4aMkOzmScp/0tNzCgdXVuLV/M62i5QqUvIiHp0MlzvDN/K+8t3Mbxsxm0rhnDwDbVaVippNfRcpRKX0RC2tHT6XywaBvj52/l8Kl0mlUtxcA21UmoVgqz4LvEg0pfRAQ4dS6Dj5fsYOzcFPYfP0u9CsUZ1KY6bW4qE1Tlr9IXEcniTPp5Ji1PZcycLaQePs3NNxZlQOtq3HVrcFzcTaUvInIJ6ecvMGXVbkb/mExK2kmqxhRiUJvqdK5XPl/v+Qda+gFdvs7M2pvZRjNLNrPnLrNeVzNzZhbv
X65sZqfNbJX/MSbwH0FEJPtFhPno2jCWmU+3YnSvBkSFh/H0p6vpNW4JWw+c9Dpejrti6ZtZGDAauAuoBfQ0s1qXWK8IMARYctFTW5xz9fyPvtmQWUTkuoX5jHvq3Mi0QS14qUttft59lPbD5/LGj8lBfRvHQPb0GwPJzrkU59w5YALQ6RLrvQD8BTiTjflERHKUz2f0alKR759pReuaZXhl+kY6jlrA6p1HvI6WIwIp/fLAzizLqf6xfzKzBkAF59y0S2xfxcxWmtkcM0u89qgiIjmnbNFoxjzckDEPNeTgibN0eWMBL0xdx6lzGV5Hy1bXfUsaM/MBw4BnL/H0HqCic64+8AzwsZkVvcT3eNLMkswsKS0t7XojiYhcs/a33sD3z7aiZ+OKjJ+/lXavzWXOpuDppUBKfxdQIctyrH/sH4oAtwI/mtk2oCkwxczinXNnnXMHAZxzy4EtQI2LX8A5N9Y5F++ci4+Jibm2n0REJJsUjY7g/7rUZuJTzYgK99H7naUMnbCSgyfOeh3tugVS+suAODOrYmaRQA9gyj+edM4ddc6Vds5Vds5VBhYDHZ1zSWYW4z8RjJlVBeKAlGz/KUREckDjKiX5Zkgig2+PY9qaPdwxbA6TV6SS16a6X40rlr5zLgMYCMwA1gMTnXNrzex5M+t4hc1bAj+Z2SpgEtDXOXfoekOLiOSWqPAwnmlbg2mDE6lcuhDPTFzNI+8sZeehU15Huyb6cJaISIDOX3B8uHg7r0zfwAUHz7arQZ+EyoSHXffp0euWrR/OEhGRzLn9vRMqM/OZViRUK8WL09bT5Y2FrN191OtoAVPpi4hcpXLFC/B273hG9arPnqOn6ThqAS9/u4Ez6ee9jnZFKn0RkWtgZnSoU47vn2lF1wblGTNnC3cOn8vC5ANeR7sslb6IyHUoXjCSV+6vy8dPNAGg19tL+N2k1Rw5dc7jZJem0hcRyQYJ1UszY2hL+t1Wjc9X7OKOYXOY+tPuPDe9U6UvIpJNoiPC+M/2NzFlYHPKFS/AwI9X8sR7Sew+ctrraP+k0hcRyWa3lCvG5H4J/M89N7Nwy0HaDpvDewu3cf6C93v9Kn0RkRwQHubjicSqfPd0SxpWLskfp6zl/jEL2bTvuKcU2AUXAAAD8ElEQVS5VPoiIjmoQsmCvPdoI157oC7bDpzknpHzGPbdRs5meDO9U6UvIpLDzIwu9WP5/plWdKhTjpGzkrl7xDyWbcv9q9Ko9EVEckmpwlG89kA93nusMWfSL9BtzCL++4s1HDuTnmsZVPoiIrmsVY0Yvnu6JY+3qMInS3fQdtgcpv+8N1deW6UvIuKBQlHh/KFDLb7o35yShaLo++FyBny0ggs5PMMnPEe/u4iIXFbdCsWZMrA54+alcPJsBj6f5ejrqfRFRDwWEeaj/23Vc+W1dHhHRCSEqPRFREKISl9EJISo9EVEQohKX0QkhKj0RURCiEpfRCSEqPRFREKI5bVbeZlZGrD9Or5FaSBv35k49+i9+CW9H7+k9+NfguG9qOSci7nSSnmu9K+XmSU55+K9zpEX6L34Jb0fv6T3419C6b3Q4R0RkRCi0hcRCSHBWPpjvQ6Qh+i9+CW9H7+k9+NfQua9CLpj+iIi8uuCcU9fRER+RdCUvpm1N7ONZpZsZs95ncdLZlbBzGab2TozW2tmQ7zO5DUzCzOzlWY21essXjOz4mY2ycw2mNl6M2vmdSYvmdnT/n8nP5vZJ2YW7XWmnBQUpW9mYcBo4C6gFtDTzGp5m8pTGcCzzrlaQFNgQIi/HwBDgPVeh8gjRgDTnXM3AXUJ4ffFzMoDg4F459ytQBjQw9tUOSsoSh9oDCQ751Kcc+eACUAnjzN5xjm3xzm3wv/1cTL/UZf3NpV3zCwWuAd42+ssXjOzYkBLYDyAc+6cc+6It6k8Fw4UMLNwoCCw2+M8OSpYSr88sDPLciohXHJZmVlloD6wxNsknhoO/A644HWQPKAKkAa86z/c
9baZFfI6lFecc7uAvwI7gD3AUefcd96mylnBUvpyCWZWGPgcGOqcO+Z1Hi+YWQdgv3NuuddZ8ohwoAHwpnOuPnASCNlzYGZWgsyjAlWAckAhM3vI21Q5K1hKfxdQIctyrH8sZJlZBJmF/5FzbrLXeTzUHOhoZtvIPOzXxsw+9DaSp1KBVOfcP/7nN4nMXwKh6g5gq3MuzTmXDkwGEjzOlKOCpfSXAXFmVsXMIsk8ETPF40yeMTMj85jteufcMK/zeMk593vnXKxzrjKZfy9mOeeCek/ucpxze4GdZlbTP3Q7sM7DSF7bATQ1s4L+fze3E+QntsO9DpAdnHMZZjYQmEHm2fd3nHNrPY7lpebAw8AaM1vlH/sv59w3HmaSvGMQ8JF/BykFeNTjPJ5xzi0xs0nACjJnva0kyD+dq0/kioiEkGA5vCMiIgFQ6YuIhBCVvohICFHpi4iEEJW+iEgIUemLiIQQlb6ISAhR6YuIhJD/D2WmcLcIySslAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(tr_loss_hist, label = 'train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on the whole training set: the padded data is fed in place of the iterator tensors\n",
    "yhat = char_gru.predict(sess = sess, X_length = X_length, X_indices = X_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training acc: 100.00%\n"
     ]
    }
   ],
   "source": [
    "# Accuracy: fraction of predicted class ids that equal the argmax of the one-hot labels\n",
    "print('training acc: {:.2%}'.format(np.mean(yhat == np.argmax(y, axis = -1))))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}