{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CS 20 : TensorFlow for Deep Learning Research\n",
    "## Lecture 11 : Recurrent Neural Networks\n",
    "Simple example for Many to One Classification (word sentiment classification) by Recurrent Neural Networks. \n",
    "\n",
    "### Many to One Classification by RNN\n",
    "- Creating the **data pipeline** with `tf.data`\n",
    "- Preprocessing word sequences (variable input sequence length) using `padding technique` by `tf.keras.preprocessing.sequence.pad_sequences`\n",
    "- Using `tf.nn.embedding_lookup` for getting vector of tokens (eg. word, character)\n",
    "- Creating the model as **Class**\n",
    "- Reference\n",
    "    - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n",
    "    - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.12.0\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "import os, sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
    "from tensorflow import keras\n",
    "import string\n",
    "%matplotlib inline\n",
    "\n",
    "print(tf.__version__)\n",
    "\n",
    "tf.enable_eager_execution()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare example data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n",
    "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['<pad>', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ', '*']\n"
     ]
    }
   ],
   "source": [
    "# Character quantization\n",
    "char_space = string.ascii_lowercase \n",
    "char_space = char_space + ' ' + '*'\n",
    "char_space = list(char_space)\n",
    "char_space.insert(0, '<pad>')\n",
    "print(char_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'<pad>': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, ' ': 27, '*': 28}\n"
     ]
    }
   ],
   "source": [
    "char2idx = {char : idx for idx, char in enumerate(char_space)}\n",
    "print(char2idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### padding example data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4, 3, 7, 7, 9, 7]\n",
      "(6, 10)\n"
     ]
    }
   ],
   "source": [
    "words = list(map(lambda word : [char2idx.get(char) for char in word],words))\n",
    "\n",
    "max_length = 10\n",
    "X_length = list(map(lambda word : len(word), words))\n",
    "X_indices = pad_sequences(sequences=words, maxlen=max_length, padding='post', truncating='post')\n",
    "\n",
    "print(X_length)\n",
    "print(np.shape(X_indices))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define CharRNN class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharRNN(keras.Model):\n",
    "    def __init__(self, num_classes, hidden_dim, max_length, dic):\n",
    "        super(CharRNN, self).__init__()\n",
    "\n",
    "        self.look_up = keras.layers.Embedding(input_dim=len(dic), output_dim=len(dic),\n",
    "                                         trainable=False, mask_zero=True, input_length=max_length,\n",
    "                                         embeddings_initializer=keras.initializers.Constant(np.eye(len(dic))))\n",
    "        self.rnn_cell = keras.layers.SimpleRNN(units=hidden_dim, return_sequences=True,\n",
    "                                               return_state=True)\n",
    "        self.dense = keras.layers.Dense(units=num_classes)\n",
    "        \n",
    "    def call(self, inputs):\n",
    "        token_representation = self.look_up(inputs)        \n",
    "        _, final_h = self.rnn_cell(token_representation)\n",
    "        score = self.dense(final_h)\n",
    "        return score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a model of CharRNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    }
   ],
   "source": [
    "# hyper-parameter#\n",
    "lr = .003\n",
    "epochs = 10\n",
    "batch_size = 2\n",
    "total_step = int(np.shape(X_indices)[0] / batch_size)\n",
    "print(total_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?, 10), (?, 2)), types: (tf.int32, tf.float32)>\n"
     ]
    }
   ],
   "source": [
    "## create data pipeline with tf.data\n",
    "tr_dataset = tf.data.Dataset.from_tensor_slices((X_indices, y))\n",
    "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n",
    "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n",
    "print(tr_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "char_rnn = CharRNN(num_classes=2, hidden_dim=16, dic=char2idx, max_length=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(model, x, y):\n",
    "    return tf.losses.softmax_cross_entropy(onehot_labels=y, logits=model(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = tf.train.AdamOptimizer(learning_rate=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch :   1, tr_loss : 0.739\n",
      "epoch :   2, tr_loss : 0.649\n",
      "epoch :   3, tr_loss : 0.546\n",
      "epoch :   4, tr_loss : 0.444\n",
      "epoch :   5, tr_loss : 0.367\n",
      "epoch :   6, tr_loss : 0.272\n",
      "epoch :   7, tr_loss : 0.215\n",
      "epoch :   8, tr_loss : 0.156\n",
      "epoch :   9, tr_loss : 0.113\n",
      "epoch :  10, tr_loss : 0.088\n"
     ]
    }
   ],
   "source": [
    "tr_loss_hist = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    avg_tr_loss = 0\n",
    "    tr_step = 0\n",
    "    \n",
    "    for x_mb, y_mb in tr_dataset:\n",
    "        with tf.GradientTape() as tape:\n",
    "            tr_loss = loss_fn(char_rnn, x=x_mb, y=y_mb)\n",
    "        grads = tape.gradient(target=tr_loss, sources=char_rnn.variables)\n",
    "        opt.apply_gradients(grads_and_vars=zip(grads, char_rnn.variables))\n",
    "        avg_tr_loss += tr_loss\n",
    "        tr_step += 1\n",
    "    else:\n",
    "        avg_tr_loss /= tr_step\n",
    "        tr_loss_hist.append(avg_tr_loss)\n",
    "    \n",
    "    print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f3de0704b38>]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xd4lGW+//H3N5OEkBAIJdQkhI4gfYgURV11Qd0FC9WKDURYWbfqlnN+xz17zuqeZW2RpmBbiYgNd+2LqGCAhKJ0CC0JRUKHUJLA/fsj0Q0smhEmeTKTz+u6uC6fZ25mPo7mw8NT7tucc4iISHiJ8DqAiIgEn8pdRCQMqdxFRMKQyl1EJAyp3EVEwpDKXUQkDKncRUTCkMpdRCQMqdxFRMJQpFcf3KhRI5eamurVx4uIhKSlS5fucc4lVjTOs3JPTU0lOzvbq48XEQlJZrYtkHE6LSMiEoZU7iIiYUjlLiIShlTuIiJhSOUuIhKGVO4iImFI5S4iEoZCrty37ink0ffWUXLylNdRRESqrZAr9/dX7+Lp+Zu4feYS9hcWeR1HRKRaCrlyH3tpGx4d2pWsLfsZnL6AdbsOeR1JRKTaCblyBxjuTyZjbB9OFJ/ihqc/592VO72OJCJSrYRkuQP0TKnP2z+5mPZN4hn3t2VM+mA9p045r2OJiFQLIVvuAE3qxpAxpg/DeiXxxLwcxryYzeHjxV7HEhHxXEiXO0BMlI9Hh3blvwZ35uP1BVz/9OdsLjjidSwREU+FfLkDmBm390vlxbvS2HvkBEPSFzJ//W6vY4mIeCYsyv1r/do0Yu6Ei2mRUJs7nstiyiebcE7n4UWk5gmo3M1skJmtN7McM3vwLK//1cxWlP3aYGYHgh81MMkNYnn9vn5c06UZf3p3HRMzVnCs6KRXcUREPFHhSkxm5gPSgauAfCDLzOY659Z8PcY590C58T8BelRC1oDFRkfy1KgedGpWl//7YD2bCo4w9dZeJNWP9TKWiEiVCeTIPQ3Icc5tds4VARnAkO8YPwqYFYxw58PMGH95W5693U/u3qMMfmohizbv9TqWiEiVCKTcWwB55bbzy/b9GzNrCbQC5p1/tOD4QccmvDmhPwmxUdzyzGJezNyq8/AiEvaCfUF1JDDHOXfWk9xmNsbMss0su6CgIMgf/e3aJNbhzfH9GdA+kd+/tZqHXl/JiRKdhxeR8BVIuW8HksttJ5XtO5uRfMcpGefcNOec3znnT0xMDDxlENSNiWL6bX7GX96GjKw8bpq+mN2HjldpBhGRqhJIuWcB7cyslZlFU1rgc88cZGYdgfpAZnAjBo8vwvjlwI6k39STNTsO8eOnFrAiz7Mbe0REKk2F5e6cKwEmAO8Da4HZzrnVZvawmQ0uN3QkkOFC4IT2tV2b8dq4fkT5Ihg+NZPXluZ7HUlEJKjMqy72+/0uOzvbk8/+2r7CIsb/bRmZm/dyZ/9W/OaajkT6wuq5LhEJM2a21Dnnr2hcjW6yBnHRvHBXGqP7pTJj4RYtACIiYaNGlztAlC+C/ze4sxYAEZGwUuPL/WvD/cm8ogVARCRMqNzL6VG2AEiHploARERCm8r9DF8vADLc//UCIEu1AIiIhByV+1nUivTxyI1fLwCyWwuAiEjIUbl/i68XAHnprovYV1ikBUBEJKSo3CvQt01D3hrfn+T6sVoARERChso9AMkNYnltXD+u1QIgIhIiVO4Bqh3t48lRPfjVoA68/eUOhk75nO0HjnkdS0TkrFTu34OZcd9lbZlxe29y9x1lyFMLWbNDDzyJSPWjcj8Hl3dszBv39SPaZ4yYlsmSLfu8jiQichqV+zlq2zieV8f1IzG+Frc+u5h5677yOpKIyDdU7uehRUJtXh3blw5N47nnhaW8sVxTB4tI9aByP08N69Ti5Xv6cFGrBjzwyhfMXLjF60giIir3YKhTK5IZo3szsHMT/uvtNUz6YL3uhRcRT6ncgyQmykf6TT0Z4U/miXk5/MdbqzXpmIh4JtLrAOEk0hfBn27sQkJcFFM/2cyBY8X8ZVg3oiP1Z6iIVC2Ve5CZGQ9dfQH1Y6P507vrOHismCm39CQ2Wl+1iFQdHVJWknsvbcMjN3ZhwcYCbnlmMQeOavk+Eak6KvdKNKJ3Ck/f3ItV2w8xYuoivjp03OtIIlJDqNwr2aALm/LcHb3J33+UGyd/ztY9hV5HEpEaIKByN7NBZrbezHLM7MFvGTPczNaY2Wozezm4MUNbv7aNmDWmD0eLTjJ0Siardxz0OpKIhLkKy93MfEA6cDXQCRhlZp3OGNMOeAjo75zrDPy0ErKGtK5JCcwe25donzFy6iLNRyMilSqQI/c0IMc5t9k5VwRkAEPOGHMPkO6c2w/gnNOSRWfRtnGd0vlo6pbOR/PPtZqPRkQqRyDl3gLIK7edX7avvPZAezNbaGaLzGzQ2d7IzMaYWbaZZRcUFJxb4hBXfj6aMS8u5fVlmo9GRIIvWBdUI4F2wGXAKGC6mSWcOcg5N80553fO+RMTE4P00aGn/Hw0P5v9BTMWaD4aEQmuQMp9O5BcbjupbF95+cBc51yxc24LsIHSspdvUX4+mof/rvloRCS4Ain3LKCdmbUys2hgJDD3jDFvUnrUjpk1ovQ0zeYg5gxLZ85H8/u3VnFS89GISBBU+Ey8c67EzCYA7wM+YIZzbrWZPQxkO+fmlr32QzNbA5wEfumc21uZwcPFv81Hc7SYScO7az4aETkv5tWpAL/f77Kzsz357Opq6ieb+N931zGgfaLmoxGRszKzpc45f0XjdHhYjYy9tA2P3tiVBRsLuFnz0YjIeVC5VzPDeyfz9M29WK35aETkPKjcqyHNRyMi50vlXk2dPh/N55qPRkS+F5V7Nfav+WgiGDl1EYs36wYkEQmMyr2aa9u4DnPG9aNx3VrcNmMJH63RfDQiUjGVewhonlCbV+/tR4em8Yx9SfPRiEjFVO4hokFctOajEZGAqdxDyNfz0Qzq3JSH/76Gv2g+GhH5Fir3EBMT5SP95tL5aJ6cl8Pv3tR8NCLy7/R8ewjyRdjp89EcK+avmo9GRMpRuYcoM+Ohqy+gQWw0//vuOg4dK2bKLb2Iq6X/pCKi0zIhb+ylbXh0aFcW5uzh5mcWs79Q89GIiMo9LAz3JzP5ll6s2XmIYVMz2XnwmNeRRMRjKvcwMbBzU56/I41dB48zdHImmwqOeB1JRDykcg8jfds0JGNMH06UnGTYlEy+zD/gdSQR8YjKPcxc2KIer97bj9hoH6OmLWJhzh6vI4mIB1TuYahVozheG9ePpPqx3DEzi3dX7vQ6kohUMZV7mGpSN4bZY/vSJake419exqwluV5HEpEqpHIPY/Vio3jprosY0D6Rh15fSfrHOZquQKSGULmHudrRPqbf5ue67s358/vr+e9/rOWUpisQCXsBlbuZDTKz9WaWY2YPnuX10WZWYGYryn7dHfyocq6ifBFMGt6d0f1SeXbBFn7x6hcUnzzldSwRqUQVPqtuZj4gHbgKyAeyzGyuc27NGUNfcc5NqISMEgQREcZ//rgTDeOi+cuHGzh4rJinbupJ7Wif19FEpBIEcuSeBuQ45zY754qADGBI5caSymBm/OSKdvzhuguZt343t81YzMFjxV7HEpFKEEi5twDyym3nl+07041m9qWZzTGz5KCkk0pxa5+WPDmqByvyDjBiaia7Dx33OpKIBFmwLqi+DaQ657oCHwLPn22QmY0xs2wzyy4oKAjSR8u5+FHX5swY3ZvcfUcZOiWTbXsLvY4kIkEUSLlvB8ofiSeV7fuGc26vc+5E2eYzQK+zvZFzbppzzu+c8ycmJp5LXgmiS9ol8vI9fTh0vJgbJ2eyZschryOJSJAEUu5ZQDsza2Vm0cBIYG75AWbWrNzmYGBt8CJKZeqenMCce/sS5TNGTMtkyZZ9XkcSkSCosNydcyXABOB9Skt7tnNutZk9bGaDy4bdb2arzewL4H5gdGUFluBr2zieOeP6kRhfi1ufXcxHa77yOpKInCfz6olFv9/vsrOzPflsObu9R05wx3NZrN5xiEdu7MrQXkleRxKRM5jZUuecv6JxekJVvtGwTi1evqcPfVo34BevfsEzn232OpKInCOVu5ymTq1IZozuzTVdmvLf/1jLI++t03w0IiFIqynLv6kV6ePJUT1JiF3F5Pmb2F9YxB+v74IvwryOJiIBUrnLWfkijD9edyENYqN56uMcDhwt5rGR3YmJ0nQFIqFAp2XkW5kZvxjYgd//qBPvrd7FHTOzOHxc0xWIhAKVu1TorotbMWl4N5Zs3cdN0xez58iJin+TiHhK5S4BuaFnEtNv68WGrw4zfEom+fuPeh1JRL6Dyl0C9oOOTXjp7ovYc+QEQydnsuGrw15HEpFvoXKX76V3agNeGduXk84xbEomy3L3ex1JRM5C5S7f2wXN6vLavf1IiI3i5umL+WSDZvgUqW5U7nJOUhrG8uq9fUltFMfdz2cx94sdXkcSkXJU7nLOGsfH8MrYPvRIqc/EjOU88c+NlGhtVpFqQeUu56VuTBQv3JnG4G7NmfThBoZOyWRzwRGvY4nUeCp3OW8xUT4eH9mDJ0f1YMueQq554jNeyNyqOWlEPKRyl6D5cbfmfPDAAPq0bsh/vLWa22YsYefBY17HEqmRVO4SVE3qxjBzdG/+eP2FZG/dz8C/fspbK7brKF6kiqncJejMjJsvasm7Ey+hXZN4JmasYMLLy9lfWOR1NJEaQ+UulSa1URyzx/bl14M68sGaXfzwsU+Zt05L+IlUBZW7VCpfhDHusja8Nf5iGsZFc+dz2Tz0+pccOVHidTSRsKZylyrRqXld3prQn3svbUNGVh5XP/4pS7bs8zqWSNhSuUuVqRXp48GrOzJ7bF8MY8S0TP73nbWcKDnpdTSRsKNylyrXO7UB7068hFFpKUz9dDODn1zI6h0HvY4lElYCKnczG2Rm680sx8we/I5xN5qZMzN/8CJKOIqrFcn/XN+FmaN7s+9oEdelLyT94xxNXyASJBWWu5n5gHTgaqATMMrMOp1lXDwwEVgc7JASvi7v2JgPfjqAH3Zqyp/fX8/wqZls3VPodSyRkBfIkXsakOOc2+ycKwIygCFnGfcH4BHgeBDzSQ1QPy6ap27qweMju5Oz+whXP/4ZLy3apgefRM5DIOXeAsgrt51ftu8bZtYTSHbO/eO73sjMxphZtpllFxRoDnD5FzNjSPcWfPDApfhT6/O7N1dx+8wsdh3UsYLIuTjvC6pmFgFMAn5e0Vjn3DTnnN85509MTDzfj5Yw1LReDC/cmcYfhnRmyZa9DHzsU80VL3IOAin37UByue2ksn1fiwcuBOab2VagDzBXF1XlXJkZt/ZN5d2JA2idGMf9s5Yz4eVlmr5A5HsIpNyzgHZm1srMooGRwNyvX3TOHXTONXLOpTrnUoFFwGDnXHalJJYao1WjOF4d25dfDuzAe6t2MfCxT/l4/W6vY4mEhArL3TlXAkwA3gfWArOdc6vN7GEzG1zZAaVmi/RFMP7ytrw5vj8JsVHcMTOL37yxkkJNXyDyncyrOxL8fr/LztbBvQTuePFJJn24gemfbSa5fiyThnfDn9rA61giVcrMljrnKjztrSdUJWTERPn4zTUXkHFPH045x/CpmTzy3jpNXyByFip3CTkXtW7Iez8dwHB/MpPnb2LIUwtZu/OQ17FEqhWVu4SkOrUi+dONXXn2dj97jhQx+KkFTJ6/iZOn9OCTCKjcJcRdcUETPnhgAFde0IRH3lvHiKmZ7DigdVtFVO4S8hrERfP0zT3564hurNt1mOvSF7IyX7NMSs2mcpewYGZc3yOJ18b1I8oXwfCpmby/epfXsUQ8o3KXsNKhaTxvju9Ph6bx3PvSUqZ9ukkTkEmNpHKXsJMYX4uMMX245sJm/M876/jNGysp1jzxUsNEeh1ApDLERPl4clQPUhvFkv7xJvL2HSP95p7Uqx3ldTSRKqEjdwlbERHGLwd25M9Du7J4y15unPw5efuOeh1LpEqo3CXsDfMn88KdF1Fw+ATXpS9k6bZ9XkcSqXQqd6kR+rZpyBv39SM+JpJR0xdrjngJeyp3qTFaJ9bh9fv60z0pgftnLeeJf27UnTQStlTuUqM0iIvmxbvTuKFHCyZ9uIGfz/5CE49JWNLdMlLj1Ir08Zfh3WjVKI6/fLiB/P3HmHJrLxrERXsdTSRodOQuNZKZ8ZMr2vHEqB6syD/A9U8vZFPBEa9jiQSNyl1qtMHdmjPrnj4cOV7CDU9/TuamvV5HEgkKlbvUeL1a1ufN8f1JjK/FbTMWMzs7z+tIIudN5S4CJDeI5bVx/bioVUN+NedLHn1vHac0N7yEMJW7SJl6taOYeUdvRqWl8PT8TUyYtYzjxbqTRkKTyl2knChfBP9z/YX89poLeHfVLkZMW0TB4RNexxL53gIqdzMbZGbrzSzHzB48y+v3mtlKM1thZgvMrFPwo4pUDTPjngGtmXJLLzaULf6xftdhr2OJfC8VlruZ+YB04GqgEzDqLOX9snOui3OuO/AoMCnoSUWq2MDOTZk9ti/FJ09x4+TP+WRDgdeRRAIWyJF7GpDjnNvsnCsCMoAh5Qc458ovPR8H6EqUhIUuSfV4a0J/khvEcudzWby4aJvXkUQCEki5twDK3xuWX7bvNGY23sw2UXrkfn9w4ol4r1m92rx6b18ubZ/I799cxcNvr+Gk7qSRai5oF1Sdc+nOuTbAr4HfnW2MmY0xs2wzyy4o0F9xJXTUqRXJ9Nv83NE/lRkLtzD2xWwKT5R4HUvkWwVS7tuB5HLbSWX7vk0GcN3ZXnDOTXPO+Z1z/sTExMBTilQDvgjjP3/cmYeHdGbeut0Mm5LJzoPHvI4lclaBlHsW0M7MWplZNDASmFt+gJm1K7d5LbAxeBFFqpfb+qYyY3Rvcvcd5br0hazaftDrSCL/psJyd86VABOA94G1wGzn3Goze9jMBpcNm2Bmq81sBfAz4PZKSyxSDVzWoTFzxvUlMiKCYVMy+WD1Lq8jiZzGvFqswO/3u+zsbE8+WyRYdh8+zj3PZ/Pl9oP89poLuOviVpiZ17EkjJnZUuecv6JxekJV5Dw0jo8hY0xfBnVuyn//Yy2/fXMVxSdPeR1LROUucr5qR/tIv6kn4y5rw8uLc7nzuSwOHS/2OpbUcCp3kSCIiDB+Pagjj97YlcxNe7k+fSGfbCjQGq3iGZW7SBAN753MC3elcbz4FLfPWMKwKZl8nrPH61hSA6ncRYKsX5tGzPvFpfzhugvJ33+Mm55ZzMhpmSzZss/raFKD6G4ZkUp0vPgks5bk8vT8TRQcPsHFbRvxwFXt6dWyvtfRJEQFereMyl2kChwrOsnfFm9j8vxN7C0s4rIOiTxwZXu6JSd4HU1CjMpdpBoqPFHCC5nbmPrpJg4cLebKC5rwwFXt6Ny8ntfRJESo3EWqscPHi3lu4Vamf7aZQ8dLuPrCpvz0yvZ0aBrvdTSp5lTuIiHg4LFiZizYwowFWzhSVMK1XZrx0yvb07ZxHa+jSTWlchcJIQeOFjH9s83MXLiV48UnGdK9BROvaEdqozivo0k1o3IXCUF7j5xg2qebeT5zK8UnHTf0aMH9V7QjuUGs19GkmlC5i4Sw3YePM2X+Zl5avI1TpxzD/MlM+EFbWiTU9jqaeEzlLhIGvjp0nPSPc8hYUrrS5ci0ZO67rC1N68V4nEy8onIXCSPbDxwj/eMcZmflERFh3HxRCuMua0PjeJV8TaNyFwlDefuO8uS8jby2bDtRPuO2vqmMHdCahnVqeR1NqojKXSSMbd1TyBPzNvLm8u3ERPkY3S+Vey5pTf24aK+jSSVTuYvUADm7j/D4Pzfy9y93EBcdyZ0Xt+Kui1tRr3aU19GkkqjcRWqQ9bsO8/g/N/DOyl3UjYnknktaM7p/KvExKvlwo3IXqYHW7DjEXz/awIdrviIhNoqxA9owul8qtaN9XkeTIFG5i9RgK/MPMunD9Xy8voDG8bW4/4p2jOidTJRPSziEuqAukG1mg8xsvZnlmNmDZ3n9Z2a2xsy+NLN/mlnLcwktIsHRJakeM+9IY869fWnZMJbfvbmKKyd9wlsrtnPqlJb+qwkqLHcz8wHpwNVAJ2CUmXU6Y9hywO+c6wrMAR4NdlAR+f78qQ2YPbYvM0f3pnaUj4kZK7j2yQV8vH631ncNc4EcuacBOc65zc65IiADGFJ+gHPuY+fc0bLNRUBScGOKyLkyMy7v2Jh37r+Ex0d2p/BECXfMzGLE1EVkb9XSf+EqkHJvAeSV284v2/dt7gLePZ9QIhJ8ERHGkO4t+Ohnpeu7btlbyNApmdz9fBbrdh3yOp4EWVCvrpjZLYAf+PO3vD7GzLLNLLugoCCYHy0iAYqOjODWPi355JeX8cuBHVi8ZR9XP/4ZD7yygty9Ryt+AwkJgZT7diC53HZS2b7TmNmVwG+Bwc65E2d7I+fcNOec3znnT0xMPJe8IhIksdGRjL+8LZ/96nLGDmjDOyt3csWk+fznW6soOHzWH2EJIRXeCmlmkcAG4ApKSz0LuMk5t7rcmB6UXkgd5JzbGMgH61ZIkepl18HjPDFvI69k5RHti+Cui1sx5tLW1NWDUNVKUO9zN7NrgMcAHzDDOfdHM3sYyHbOzTWzj4AuwM6y35LrnBv8Xe+pchepnrbsKWTShxt4+4sdJMRGcd9lbbitbyoxUXoQqjrQQ0wicl5WbT/In99fzycbCmhaN4aJV7ZjWK8kIvUglKeC+hCTiNQ8F7aox/N3ppExpg/NE2J46PWV/PCvn/KPL3fqQagQoHIXke/Up3VDXhvXj+m3+Yn0GeNfXsbg9AV8uqFAD0JVYyp3EamQmXFVpya8O3EAfxnWjf2Fxdw2Ywk3TV/M8tz9XseTs9A5dxH53k6UnGTW4lyenJfD3sIiftipCb8Y2IH2TeK9jhb2dEFVRCrdkRMlzFiwhWmfbuZoUQnX90jigavakVQ/1utoYUvlLiJVZl9hEZPn5/B85jZwcHOfFMZf3pZGWts16FTuIlLldhw4xuMfbeTVpXnUjvJx9yWtufuSVloRKohU7iLimZzdR5j04XreWbmL+rFR3H1Ja4b7k0mM15H8+VK5i4jnvsg7wP99sJ7PNu4hMsK48oImjEhLZkC7RHwR5nW8kKRyF5FqI2f3EV7JyuW1ZdvZV1hEi4TaDPMnMcyfTIuE2l7HCykqdxGpdk6UnOSjNbvJyMrls417MINL2ycysncKV1zQWGu8BkDlLiLVWt6+o7ySlcerS/P46tAJEuNrMbRXEiN7J9OyYZzX8aotlbuIhISSk6eYv76AjKxc5q3bzSkHfVs3ZGRaMgM7N9VslGdQuYtIyNl18DivZufxSnYe+fuPkRAbxQ09khiZlqynX8uo3EUkZJ065Vi4aQ8ZS/L4YM0uik86eqYkMDIthR91bUZsdKTXET2jcheRsLDnyAneWLadWVm5bC4oJL5WJIO7N2dk7xS6JNXzOl6VU7mLSFhxzpG1dT8ZS3L5x8qdnCg5RefmdRmZlsKQ7s1rzHKAKncRCVsHjxbz1hfbmbUkj7U7D1E7yse1XZsxKi2Znin1MQvfB6RU7iIS9pxzfJl/kIysXOau2EFh0UnaNa7DiN7J3NAziQZx0V5HDDqVu4jUKIUnSvj7lzuYtSSPFXkHiPZF8MPOTRiVlkLf1g2JCJPpDlTuIlJjrdt1iIwleby+LJ9Dx0to2TCW4f7S++ZbN4oL6aJXuYtIjXe8+CTvrdrFrCW5LN6yD4B6taPonpxAj5QEeqbUp1tyAvVqh87F2KCWu5kNAh4HfMAzzrk/nfH6AOAxoCsw0jk3p6L3VLmLSFXatreQxVv2sTx3P8tzD7D+q8N8XX9tG9ehZ0oCPVLq0zOlPm0b16m2s1YGrdzNzAdsAK4C8oEsYJRzbk25MalAXeAXwFyVu4hUd4ePF/Nl/kGW5+5nWe4BlufuZ//RYgDq1IqkW3I9eqbUp0dKAt2T61ebi7OBlnsgj3mlATnOuc1lb5wBDAG+KXfn3Nay106dU1oRkSoWHxNF/7aN6N+2EVB65822vUdZVnZkvzxvP0/P38TJU6UHwK0axdGj7HROj5T6dGwaT2Q1nsUykHJvAeSV284HLjqXDzOzMcAYgJSUlHN5CxGRSmFmpDaKI7VRHDf0TALgaFEJK/MPfnNk/+nGPby+fDsAtaN8dE2qR4+yo/seKQk0jo/x8l/hNFU6QYNzbhowDUpPy1TlZ4uIfF+x0ZFc1LohF7VuCJQe3efvP8byvAMs27af5XkHeHbBZopPltZZUv3aZeftS4/uOzWrS3SkN0f3gZT7diC53HZS2T4RkRrFzEhuEEtyg1gGd2sOlN6Rs3rHwdJTObkHyN66j7e/2AFAdGQEXVrUKzudU5+eLRNoVq9qVp4KpNyzgHZm1orSUh8J3FSpqUREQkRMlI9eLRvQq2WDb/btPHisrOxLz9+/sGgbzyzYAkDTujE8dE1HhnRvUam5Kix351yJmU0A3qf0VsgZzrnVZvYwkO2cm2tmvYE3gPrAj83sv5xznSs1uYhINdWsXm2adanNNV2aAVBUcoq1Ow99c2dOVZyb10NMIiIhJNBbIavvfTwiInLOVO4iImFI5S4iEoZU7iIiYUjlLiIShlTuIiJhSOUuIhKGVO4iImHIs4eYzKwA2HaOv70RsCeIcUKdvo/T6fv4F30XpwuH76Olcy6xokGelfv5MLPsQJ7Qqin0fZxO38e/6Ls4XU36PnRaRkQkDKncRUTCUKiW+zSvA1Qz+j5Op+/jX/RdnK7GfB8hec5dRES+W6geuYuIyHcIuXI3s0Fmtt7McszsQa/zeMXMks3sYzNbY2arzWyi15mqAzPzmdlyM/u711m8ZmYJZjbHzNaZ2Voz6+t1Jq+Y2QNlPyerzGyWmVWflawrSUiVu5n5gHTgaqATMMrMOnmbyjMlwM+dc52APsD4GvxdlDcRWOt1iGriceA951xHoBs19HsxsxbA/YDfOXchpSvKjfQ2VeULqXIH0oAc59xm51wRkAEM8TiTJ5xzO51zy8r++TClP7han/riAAABzUlEQVSVuyhjNWdmScC1wDNeZ/GamdUDBgDPAjjnipxzB7xN5alIoLaZRQKxwA6P81S6UCv3FkBeue18anihAZhZKtADWOxtEs89BvwKOOV1kGqgFVAAzCw7TfWMmcV5HcoLzrntwP8BucBO4KBz7gNvU1W+UCt3OYOZ1QFeA37qnDvkdR6vmNmPgN3OuaVeZ6kmIoGewGTnXA+gEKiR16jMrD6lf8NvBTQH4szsFm9TVb5QK/ftQHK57aSyfTWSmUVRWux/c8697nUej/UHBpvZVkpP1/3AzF7yNpKn8oF859zXf5ubQ2nZ10RXAluccwXOuWLgdaCfx5kqXaiVexbQzsxamVk0pRdF5nqcyRNmZpSeT13rnJvkdR6vOececs4lOedSKf3/Yp5zLuyPzr6Nc24XkGdmHcp2XQGs8TCSl3KBPmYWW/ZzcwU14OJypNcBvg/nXImZTQDep/SK9wzn3GqPY3mlP3ArsNLMVpTt+41z7h0PM0n18hPgb2UHQpuBOzzO4wnn3GIzmwMso/Qus+XUgCdV9YSqiEgYCrXTMiIiEgCVu4hIGFK5i4iEIZW7iEgYUrmLiIQhlbuISBhSuYuIhCGVu4hIGPr/1mKGnKp2jjkAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(tr_loss_hist, label = 'train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc : 100.00%\n"
     ]
    }
   ],
   "source": [
    "yhat = np.argmax(char_rnn(inputs=tf.convert_to_tensor(X_indices)), axis=-1)\n",
    "print('acc : {:.2%}'.format(np.mean(yhat == np.argmax(y, axis=-1))))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}