{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CS 20 : TensorFlow for Deep Learning Research\n", "## Lecture 11 : Recurrent Neural Networks\n", "A simple example of many-to-one classification (word sentiment classification) with a bi-directional recurrent neural network.\n", "\n", "### Many to One Classification by Bi-directional RNN\n", "- Creating the **data pipeline** with `tf.data`\n", "- Preprocessing word sequences of variable length by padding them with a user-defined function (`pad_seq`)\n", "- Using `tf.nn.embedding_lookup` to get the vector of each token (e.g. word, character)\n", "- Creating the model as a **class**\n", "- References\n", " - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n", " - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n", " - https://pozalabs.github.io/blstm/\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.8.0\n" ] } ], "source": [ "import os, sys\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "import string\n", "%matplotlib inline\n", "\n", "slim = tf.contrib.slim\n", "print(tf.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare example data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n", "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'abcdefghijklmnopqrstuvwxyz *'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" }
], "source": [ "# Character quantization\n", "char_space = string.ascii_lowercase \n", "char_space = char_space + ' ' + '*'\n", "char_space" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n" ] } ], "source": [ "char_dic = {char : idx for idx, char in enumerate(char_space)}\n", "print(char_dic)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create pad_seq function" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def pad_seq(sequences, max_len, dic):\n", " seq_len, seq_indices = [], []\n", " for seq in sequences:\n", " seq_len.append(len(seq))\n", " seq_idx = [dic.get(char) for char in seq]\n", " seq_idx += (max_len - len(seq_idx)) * [dic.get('*')] # 27 is idx of meaningless token \"*\"\n", " seq_indices.append(seq_idx)\n", " return seq_len, seq_indices" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pad_seq function to data" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "max_length = 10\n", "X_length, X_indices = pad_seq(sequences = words, max_len = max_length, dic = char_dic)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[4, 3, 7, 7, 9, 7]\n", "(6, 10)\n" ] } ], "source": [ "print(X_length)\n", "print(np.shape(X_indices))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define CharBiRNN class" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class CharBiRNN:\n", " def __init__(self, X_length, X_indices, y, n_of_classes, hidden_dim, dic):\n", " \n", " # data 
pipeline\n", " with tf.variable_scope('input_layer'):\n", " self._X_length = X_length\n", " self._X_indices = X_indices\n", " self._y = y\n", " \n", " one_hot = tf.eye(len(dic), dtype = tf.float32)\n", " self._one_hot = tf.get_variable(name='one_hot_embedding', initializer = one_hot,\n", " trainable = False) # not trainable: the one-hot embedding vectors are kept fixed\n", " self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n", " \n", " # Bi-directional RNN\n", " with tf.variable_scope('bi-directional_rnn'):\n", " rnn_fw_cell = tf.contrib.rnn.BasicRNNCell(num_units = hidden_dim, activation = tf.nn.tanh)\n", " rnn_bw_cell = tf.contrib.rnn.BasicRNNCell(num_units = hidden_dim, activation = tf.nn.tanh)\n", " _, output_states = tf.nn.bidirectional_dynamic_rnn(cell_fw = rnn_fw_cell,\n", " cell_bw = rnn_bw_cell,\n", " inputs = self._X_batch,\n", " sequence_length = self._X_length,\n", " dtype = tf.float32)\n", "\n", " final_state = tf.concat([output_states[0], output_states[1]], axis = 1)\n", "\n", " with tf.variable_scope('output_layer'):\n", " self._score = slim.fully_connected(inputs = final_state, num_outputs = n_of_classes,\n", " activation_fn = None)\n", " \n", " with tf.variable_scope('loss'):\n", " self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n", " \n", " with tf.variable_scope('prediction'):\n", " self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n", " \n", " def predict(self, sess, X_length, X_indices):\n", " feed_prediction = {self._X_length : X_length, self._X_indices : X_indices}\n", " return sess.run(self._prediction, feed_dict = feed_prediction)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a model of CharBiRNN" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3\n" ] } ], "source": [ "# hyper-parameters\n", "lr = .003\n", "epochs = 10\n", "batch_size =
2\n", "total_step = int(np.shape(X_indices)[0] / batch_size)\n", "print(total_step)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n" ] } ], "source": [ "## create data pipeline with tf.data\n", "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n", "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n", "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n", "tr_iterator = tr_dataset.make_initializable_iterator()\n", "print(tr_dataset)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "char_bi_rnn = CharBiRNN(X_length = X_length_mb, X_indices = X_indices_mb, y = y_mb,\n", " n_of_classes = 2, hidden_dim = 16, dic = char_dic)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create the training op and train the model" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "## create training op\n", "opt = tf.train.AdamOptimizer(learning_rate = lr)\n", "training_op = opt.minimize(loss = char_bi_rnn.ce_loss)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch : 1, tr_loss : 0.712\n", "epoch : 2, tr_loss : 0.554\n", "epoch : 3, tr_loss : 0.454\n", "epoch : 4, tr_loss : 0.350\n", "epoch : 5, tr_loss : 0.278\n", "epoch : 6, tr_loss : 0.213\n", "epoch : 7, tr_loss : 0.161\n", "epoch : 8, tr_loss : 0.121\n", "epoch : 9, tr_loss : 0.092\n", "epoch : 10, tr_loss : 0.072\n" ] } ], "source": [ "sess = tf.Session()\n", "sess.run(tf.global_variables_initializer())\n", "\n", "tr_loss_hist = []\n", "\n", "for epoch in range(epochs):\n", " avg_tr_loss
= 0\n", " tr_step = 0\n", " \n", " sess.run(tr_iterator.initializer)\n", " try:\n", " while True:\n", " _, tr_loss = sess.run(fetches = [training_op, char_bi_rnn.ce_loss])\n", " avg_tr_loss += tr_loss\n", " tr_step += 1\n", " \n", " except tf.errors.OutOfRangeError:\n", " pass\n", " \n", " avg_tr_loss /= tr_step\n", " tr_loss_hist.append(avg_tr_loss)\n", " \n", " print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[<matplotlib.lines.Line2D at 0x119b0d828>]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xd8VfX9x/HXJzcJIRBmEkYS9gwyApEtgqs4EUEEW6t14Kh1t7WtXfanta0LlSpIrVtERMWqxQWKLAl7hBHCSlhhBggQknx/fyRitEgucJOTe+/7+XjwqOfcLzlv76O8PXzPOd9jzjlERCS0RHgdQEREAk/lLiISglTuIiIhSOUuIhKCVO4iIiFI5S4iEoJU7iIiIUjlLiISglTuIiIhKNKrA8fHx7sWLVp4dXgRkaC0YMGCnc65hIrGeVbuLVq0ICMjw6vDi4gEJTPb6M84TcuIiIQglbuISAhSuYuIhCCVu4hICFK5i4iEIJW7iEgIUrmLiISgoCv3rB37eeSjVej1gCIiP8yvcjezwWa22syyzOz+43z+hJktLvu1xsz2Bj5qqRmr83jui3W8lZFTWYcQEQl6FZa7mfmAscCFQCowysxSy49xzt3tnOvmnOsGPA1MqYywANf3a0nvVg348/sr2Ly7oLIOIyIS1Pw5c+8JZDnnsp1zhcBEYMgJxo8C3ghEuOOJiDAevbIrEWbcO2kJxSWanhER+T5/yj0J2FxuO6ds3/8ws+ZAS+Dz04/2w5Lrx/LHyzrx9YbdTJiZXZmHEhEJSoG+oDoSmOycKz7eh2Y22swyzCwjLy/vtA40rHsSP+rUiMc+XkPm1vzT+lkiIqHGn3LPBVLKbSeX7TuekZxgSsY5N945l+6cS09IqHDFyhMyMx4e2pk6NSO5+83FHCk67n9PRETCkj/lPh9oa2YtzSya0gKf+v1BZtYBqA/MCWzEH9awdg0euaILq7bt54lP1lbVYUVEqr0Ky905VwTcDkwDMoFJzrkVZvagmV1WbuhIYKKr4hvQz0ttxFXpKYz7ch3zN+yuykOLiFRb5tXDQOnp6S5QL+s4cKSIC8d8CcBHdw6gdg3P3kEiIlKpzGyBcy69onFB94Tq8dSuEcljV3YjZ88h/u8/K72OIyLiuZAod4CeLRswekArJs7fzGeZ272OIyLiqZApd4B7zm9Hh8Zx/PrtZew6cMTrOCIingmpcq8R6eOJq7qRf+gov3tnuRYXE5GwFVLlDtCxSR3uuaAd/12xjSkLf+h2fBGR0BZy5Q5w01mt
6NmiAX+auoLcvYe8jiMiUuVCstx9EcZjI7pS4hz3TVpCiRYXE5EwE5LlDpDSIJY/XJrKnOxdvDBrvddxRESqVMiWO8CI9BTO65jI36etZs32/V7HERGpMiFd7mbGX6/oQlyN0sXFCotKvI4kIlIlQrrcARLiavDwFZ1ZsSWfpz7T4mIiEh5CvtwBftSpMcN7JPPPGVks2LjH6zgiIpUuLMod4I+XptKkbk3unbSYgsIir+OIiFSqsCn3uJgoHhvRlY27C3jog0yv44iIVKqwKXeA3q0acmP/lrw2bxPTV+/wOo6ISKUJq3IHuPeC9rRrVJtfT17KnoOFXscREakUYVfuMVE+Hh/RjT0FhTzwnhYXE5HQFHblDnBGUl3uOq8dHyzdytQlW7yOIyIScGFZ7gA3D2hF92b1+P27y9m6T4uLiUhoCdtyj/RF8PiIbhSVOH751lItLiYiISVsyx2gRXwtfndxR77K2snLczZ4HUdEJGDCutwBru7ZjEHtE/jrR6vI2nHA6zgiIgHhV7mb2WAzW21mWWZ2/w+MGWFmK81shZm9HtiYlcfM+NuwLsRG+7hn0mKOFmtxMREJfhWWu5n5gLHAhUAqMMrMUr83pi3wG6Cfc64TcFclZK00iXVieGhoZ5bm7OOZz7O8jiMictr8OXPvCWQ557Kdc4XARGDI98bcBIx1zu0BcM4F3eOfF3VuwtC0JJ6ZnsXizXu9jiMiclr8KfckYHO57ZyyfeW1A9qZ2Swzm2tmgwMVsCr96bJOJMbV4J43F3OosNjrOCIipyxQF1QjgbbAQGAU8LyZ1fv+IDMbbWYZZpaRl5cXoEMHTt2aUTx6ZVeydx7kkY+0uJiIBC9/yj0XSCm3nVy2r7wcYKpz7qhzbj2whtKy/w7n3HjnXLpzLj0hIeFUM1eqfm3i+Vm/Frw0ZyMz11a//wCJiPjDn3KfD7Q1s5ZmFg2MBKZ+b8y7lJ61Y2bxlE7TZAcwZ5X69eAOtEmszS/fWsq+gqNexxEROWkVlrtzrgi4HZgGZAKTnHMrzOxBM7usbNg0YJeZrQSmA790zu2qrNCVLSbKxxMjurHzwBF+/95yr+OIiJw082pVxPT0dJeRkeHJsf311GdrefyTNTw9Ko1Luzb1Oo6ICGa2wDmXXtG4sH9C9URuG9iabin1eODd5WzPP+x1HBERv6ncT6B0cbGuHCkq5peTl2rtdxEJGir3CrRKqM1vL+rIl2vyeHXeJq/jiIj4ReXuh2t6N+estvE8/EEm63ce9DqOiEiFVO5+MDP+Mbwr0ZER3P3mYoq0uJiIVHMqdz81rhvDXy4/g8Wb9/LsjHVexxEROSGV+0m4rGtTLu3alDGfrWVZzj6v44iI/CCV+0n6y5BONKwdzd2TFnP4qBYXE5HqSeV+kurFRvOP4V3J2nGAhz7I1O2RIlItqdxPwYB2CdzQvyWvzN3IH6euoFgv1xaRaibS6wDB6ncXdcQXYYz/MptdBwp5/Kqu1Ij0eR1LRARQuZ+yiAjjtxd1JL52NA9/uIo9BYWMu6YHcTFRXkcTEdG0zOkaPaA1j4/oytfrdzNy/Fzy9h/xOpKIiMo9EK7onszz16aTnXeQ4c/NZuMuPcUqIt5SuQfIoPaJvHZTL/YdOsqwZ+ewPFf3wYuId1TuAdS9WX0m39KHaJ8xcvxcZq/b6XUkEQlTKvcAa5MYx9u39aVJ3Riue2E+Hyzd6nUkEQlDKvdK0KRuTd66pQ+dk+ty+xsLeWXOBq8jiUiYUblXknqx0bx6Qy/OaZ/I799bweMfr9bTrCJSZVTulahmtI9x1/Tgyh7JPPV5Fr99Z7meZhWRKqGHmCpZpC+Cvw/vQkJcDf45Yx27Dx5hzMg0YqL0NKuIVB6duVcBM+NXgzvwh0tSmbZiOz994Wv2HTrqdSwRCWF+lbuZDTaz1WaWZWb3H+fz68wsz8wWl/26MfBRg9/1/VsyZmQ3Fm3aw1Xj5rAj/7DXkUQkRFVY7mbm
A8YCFwKpwCgzSz3O0Dedc93Kfk0IcM6QMaRbEv+69kw27S7gimdn652sIlIp/Dlz7wlkOeeynXOFwERgSOXGCm0D2iXwxk29KSgsZvizs1mas9frSCISYvwp9yRgc7ntnLJ93zfMzJaa2WQzSwlIuhDWNaUek2/pQ0yUj1Hj5zJzbZ7XkUQkhATqgur7QAvnXBfgE+Cl4w0ys9FmlmFmGXl5KrNWCbWZcltfUhrEcv2L85m6ZIvXkUQkRPhT7rlA+TPx5LJ9xzjndjnnvlnrdgLQ43g/yDk33jmX7pxLT0hIOJW8IadRnRjevLkPaSn1ueONRfx71nqvI4lICPCn3OcDbc2spZlFAyOBqeUHmFmTcpuXAZmBixj66taM4uUbenJBaiP+/P5K/jFtlZ5mFZHTUmG5O+eKgNuBaZSW9iTn3Aoze9DMLisbdoeZrTCzJcAdwHWVFThUxUT5ePYnPRjVsxljp6/j128vpai4xOtYIhKkzKszxPT0dJeRkeHJsasz5xxPfLKGpz7P4ryOiTw9qjs1o/U0q4iUMrMFzrn0isbpCdVqxsy454L2PDikE5+t2sE1/5rHvgI9zSoiJ0flXk39tE8LnhnVnaU5+7hy3Gy27dPTrCLiP5V7NXZxlya8+LMz2bL3MMOenU3WjgNeRxKRIKFyr+b6toln4ujeHCkq5srnZrNo0x6vI4lIEFC5B4Ezkuoy+Za+xMVEcfXz85ixeofXkUSkmlO5B4kW8bWYfGsfWsbX4saXMnhnUY7XkUSkGlO5B5HEuBjevLk3Z7ZowN1vLmHCzGyvI4lINaVyDzJxMVG8eP2ZXNS5Mf/3QSZ//TBTT7OKyP/Qa/aCUI1IH0+P6k7DWisY92U2uw4W8rdhXfBFmNfRRKSaULkHKV+E8eCQTjSoFc2Yz9Zy6GgxT17VjSif/jImIir3oGZm3H1+O2rV8PHwh6s4crSYZ67urpdvi4jm3EPB6AGt+cuQTnyauYObXs6goLDI60gi4jGVe4i4pk8LHr2yK7OydnLtC1+z/7DWoxEJZyr3EDK8RzJPjUpj0aa9/HjCPPYWFHodSUQ8onIPMZd0acpzP+nBqq37GTl+Lnn7j1T8m0Qk5KjcQ9B5qY144boz2birgKvGz2HrvkNeRxKRKqZyD1H928bz8g092ZF/hBHj5rB5d4HXkUSkCqncQ9iZLRrw2o29yD9UxJXPzWFdnpYMFgkXKvcQ1zWlHhNH96aopISrxs0hc2u+15FEpAqo3MNAxyZ1mDi6D5EREYwcP5elOXu9jiQilUzlHibaJNbmrVv6EBcTydXPz2P+ht1eRxKRSqRyDyMpDWJ565Y+JMbV4Kf/+pqv1u70OpKIVBK/yt3MBpvZajPLMrP7TzBumJk5M0sPXEQJpCZ1a/LmzX1o3jCW61+az2eZ272OJCKVoMJyNzMfMBa4EEgFRplZ6nHGxQF3AvMCHVICKyGuBhNH96ZD4zhufmUBHyzd6nUkEQkwf87cewJZzrls51whMBEYcpxxfwH+BhwOYD6pJPVio3n1xl6kNavHL95YyOQFem2fSCjxp9yTgM3ltnPK9h1jZt2BFOfcBwHMJpWsTkwUL13fk76t47nvrSW8Mnej15FEJEBO+4KqmUUAjwP3+jF2tJllmFlGXl7e6R5aAiA2OpIJ16ZzbodEfv/ucp7/Uu9lFQkF/pR7LpBSbju5bN834oAzgBlmtgHoDUw93kVV59x451y6cy49ISHh1FNLQMVE+Xjumh5c3LkJD32YyZhP1+q9rCJBzp83Mc0H2ppZS0pLfSRw9TcfOuf2AfHfbJvZDOA+51xGYKNKZYryRTBmZDdionw88ekaCo4Wcf/gDpjpvawiwajCcnfOFZnZ7cA0wAe84JxbYWYPAhnOuamVHVKqRqQvgn8M70LN6AjGfZHN4cJi/nhpJyL04m2RoOPXO1Sdcx8CH35v3x9+YOzA048lXomIMP4y5AxqRvl4fuZ6CgqLeWRY
F3wqeJGgohdky/8wM357UUdioyMZ89laDheV8PiIrkT59ECzSLBQuctxmRl3n9+O2Ggff/1oFYePFvPM1WnUiPR5HU1E/KBTMTmhm89uzYNDOvHJyu3c+FIGhwqLvY4kIn5QuUuFftqnBX8f3oVZWTu59t9fc+BIkdeRRKQCKnfxy4j0FMaMTGPhxj38eMI89hUc9TqSiJyAyl38dmnXpjz7kx5kbsln5PNz2XngiNeRROQHqNzlpJyf2ogJ16azfucBrho3h237tE6cSHWkcpeTNqBdAi9f34vt+UcYMW4Om3cXeB1JRL5H5S6npGfLBrx6Yy/2FhQyYtwcsvMOeB1JRMpRucsp65ZSj4mj+1BYVMKIcXNZslkv3hapLlTuclpSm9bhzZv7EO0zrnh2Nn/9KJPDR3UvvIjXVO5y2tok1uajuwZwZY9kxn2RzYVjZjIve5fXsUTCmspdAqJuzSgeGdaF127sRXGJ46rxc3ng3WXsP6z74UW8oHKXgOrXJp7/3nUWN/ZvyevzNnHBE18yfdUOr2OJhB2VuwRcbHQkD1ySytu39iUuJpKfvTifuyYuYvfBQq+jiYQNlbtUmrRm9fnPL87iznPb8sGyrZz/+Be8v2SLXuEnUgVU7lKpoiMjuPv8drz/i/4k16/JL95YxE0vL9CTrSKVTOUuVaJD4zpMua0fD1zcka+y8jj/8S944+tNOosXqSQqd6kyvgjjxrNaMe2uAZyRVJffTFnG1c/PY+Oug15HEwk5Knepcs0b1uL1m3rx1ys6szx3Hz968kue/zKb4hKdxYsEispdPGFmjOrZjE/uOZv+beJ56MNMrvjnLFZv2+91NJGQoHIXTzWuG8PzP03n6VFp5Ow5xCVPz+SJT9ZwpEhLGIicDr/K3cwGm9lqM8sys/uP8/ktZrbMzBab2Vdmlhr4qBKqzIxLuzblk3vO5pIuTRnz2VoufforFm3a43U0kaBVYbmbmQ8YC1wIpAKjjlPerzvnOjvnugF/Bx4PeFIJeQ1qRfPEVd144bp09h8u4opnZ/OX/6ykoFDvbBU5Wf6cufcEspxz2c65QmAiMKT8AOdcfrnNWoCujMkpO6dDIz6+ewA/7tWMf321nsFPzmR21k6vY4kEFX/KPQnYXG47p2zfd5jZz81sHaVn7ncEJp6Eq7iYKP7v8s68Obo3vgjj6gnzuP/tpew7pIXIRPwRsAuqzrmxzrnWwK+BB443xsxGm1mGmWXk5eUF6tASwnq1ashHd57FLWe35q0FOZz/+Bd8vGKb17FEqj1/yj0XSCm3nVy274dMBC4/3gfOufHOuXTnXHpCQoL/KSWsxUT5uP/CDrx7Wz8a1q7B6FcW8PPXF5K3/4jX0USqLX/KfT7Q1sxamlk0MBKYWn6AmbUtt3kxsDZwEUVKdU6uy9Tb+3HfBe34ZMV2zn/iC6YszNESBiLHUWG5O+eKgNuBaUAmMMk5t8LMHjSzy8qG3W5mK8xsMXAPcG2lJZawFuWL4PZz2vLhnf1pFV+LeyYt4Wcvzid37yGvo4lUK+bVWU96errLyMjw5NgSGopLHK/M2cDfp63GgPsv7MCPezUnIsK8jiZSacxsgXMuvaJxekJVgpYvwriuX0um3TWA7s3r8/v3VjBy/FzWbNcSBiIqdwl6KQ1iefn6nvxjeBdWb9/PRWNm8tAHKzlwRA8/SfhSuUtIMDOuTE9h+n0DGd4jmQlfreecR2fw3uJcXXCVsKRyl5DSoFY0jwzrwju39aNx3RjunLiYq8bPZdW2/Ip/s0gIUblLSOqWUo93buvHw0M7s2b7fi5+6iv+/P4K8g/rCVcJDyp3CVm+COPqXs2Yfu9ARp6ZwouzN3DOo1/w9gLdGy+hT+UuIa9+rWgeGtqZ937ej+T6Nbn3rSVc+dwcVm7RVI2ELpW7hI0uyfWYcmtf/j6sC9k7D3LJ0zP543vLtRiZhCSVu4SViAhjxJkpTL93ID/p3ZxX5m7knEdnMClj
MyV6h6uEEJW7hKW6sVE8OOQMpt7enxbxtfjV5KUMe242y3P3eR1NJCBU7hLWzkiqy1s39+HRK7uyeXcBlz7zFQ+8u4y9BYVeRxM5LSp3CXsREcbwHsl8du9Aru3TgtfnbWLQozOY+PUmTdVI0FK5i5SpWzOKP13WiQ/uOIs2ibW5f8oyhj47m6U5e72OJnLSVO4i39OxSR0m3dyHJ67qypa9hxgydha/mbKMPQc1VSPBQ+UuchxmxtC0ZD6/92yu79eSSRmbGfTYDF6du5FiTdVIEFC5i5xAXEwUv78klQ/vOIv2jeJ44N3lXD52Fos27fE6msgJqdxF/NC+cRwTR/fmqVFp7Nh/mKH/nM2vJi9h1wG9x1WqJ5W7iJ/MjMu6NuWzewdy84BWTFmYy6BHZ/DynA2aqpFqR+UucpJq14jkNxd15L93nUXn5Lr84b0VXPr0VyzYuNvraCLHqNxFTlGbxDhevaEXY6/uzp6CQoY9O4d7Jy0hb7+masR7KneR02BmXNylCZ/ecza3DmzN1CW5nPPYDF74aj2HjxZ7HU/CmHm1rnV6errLyMjw5NgilWVd3gH+NHUFM9fuJDGuBjed1YqrezWjVo1Ir6NJiDCzBc659IrG+XXmbmaDzWy1mWWZ2f3H+fweM1tpZkvN7DMza34qoUWCXeuE2rx8fU9eu7EXbRvV5qEPM+n3t8958tM1Wq9GqlSFZ+5m5gPWAOcDOcB8YJRzbmW5MYOAec65AjO7FRjonLvqRD9XZ+4SDhZt2sM/Z6zjk5XbqRXt48e9m3Nj/5Yk1onxOpoEqUCeufcEspxz2c65QmAiMKT8AOfcdOdcQdnmXCD5ZAOLhKK0ZvV5/qfpTLtrAOenNmLCzGz6/206v3tnGZt2FVT8A0ROkT/lngRsLredU7bvh9wAfHS8D8xstJllmFlGXl6e/ylFglz7xnE8OTKNGfcNYnh6Mm9l5DDosRncNXERq7ft9zqehKCA3i1jZj8B0oF/HO9z59x451y6cy49ISEhkIcWCQrNGsby8NDOzPz1IG7o35KPV27nR09+yU0vZ7B4s1aflMDx5xJ+LpBSbju5bN93mNl5wO+As51zutFX5AQa1Ynhtxd15NazW/PSnA38e9YGPlk5i35tGvLzgW3o07ohZuZ1TAli/lxQjaT0guq5lJb6fOBq59yKcmPSgMnAYOfcWn8OrAuqIt86cKSIN+Zt4vmZ2ezYf4RuKfX4+aA2nNshkYgIlbx8y98Lqn7d525mFwFPAj7gBefcQ2b2IJDhnJtqZp8CnYGtZb9lk3PushP9TJW7yP86fLSYtxfm8NwX69i8+xDtG8Vx68DWXNKlCZE+PXMoAS73yqByF/lhRcUl/GfpVv45I4s12w/QrEEsN5/dimHdk4mJ8nkdTzykchcJASUljk8ztzN2xjqWbN6rp15F5S4SSpxzzF63i7HTs5i9bhd1a0bxs34tuK5vC+rFRnsdT6qQyl0kRJV/6jU22sePezXjxrNa0UhPvYYFlbtIiFu9bT/Pzshi6pItREZEMDw9mVsGtKZZw1ivo0klUrmLhImNuw4y7stsJmfkUFRSwmVdm3LrwDa0bxzndTSpBCp3kTCzPf8wE2Zm89q8TRQUFnN+aiNuObsV3ZvV1wNRIUTlLhKm9hws5MXZG3hx9gb2HTpKq/haXJ6WxNC0JFIaaMom2KncRcLcgSNFfLB0C1MW5jJvfen7Xc9sUZ+haclc3LkJdWOjPE4op0LlLiLH5Owp4L3FW5iyMId1eQeJ9kVwbsdEhqYlMbB9ItGRevo1WKjcReR/OOdYlruPKQtzeX/JFnYdLKR+bBSXdGnK0O5JpKXU0/x8NadyF5ETOlpcwsy1eUxZmMsnK7dzpKiElvG1uLxb6fy8bqmsnlTuIuK3/MNH+e+ybUxZlMPc7NL5+fTm9RnaPYlLOjfV/Hw1onIXkVOSu/cQ7y7K5Z1FuWTtOEC0L4JzOiQytHsSgzQ/7zmV
u4icFuccy3PzmbIoh/eXbGHngULqxUZxSZcmDE1Lpnszzc97QeUuIgFTVFzCzLU7mbIol49XbONIUQktGsYeu3++ecNaXkcMGyp3EakU+w8f5aPl23hnYS5z1+/COejRvD5D05K4pEsTrVJZyVTuIlLptuw9xHuLt/DOohzWbC+dnx/UIYGhackM6pBAjUi9WCTQVO4iUmWcc6zYks+UhblMXbKFnQeOULdm6fz8Fd2TtL5NAKncRcQTRcUlzMzayTsLc/l45TYOHy2hcZ0Y+rWJp1+bhvRrE6+150+Dv+Wu93SJSEBF+iIY1D6RQe0TOXCkiP8u38b0VTv4bNV23l6YA0CbxNr0bxNP39YN6d26IXVidB99oOnMXUSqREmJY+XWfGZl7WTWul18vX4Xh4+WEGHQJbleadm3aUiP5vU1V38CmpYRkWrtSFExizbtLS37rJ0sydlHcYkjJiqCM1s0KJ3GaR1PatM6+CI0X/+NgJa7mQ0GxgA+YIJz7pHvfT4AeBLoAox0zk2u6Geq3EWkvP2HjzIvezez1pWW/ZrtBwCoFxtFn1YNy+bs42nRMDasL84GbM7dzHzAWOB8IAeYb2ZTnXMryw3bBFwH3HdqcUUk3MXFRHFeaiPOS20EwI78w8xet+vYmf1Hy7cBkFSvJn1bN6R/23j6tG5IYpwuzh6PPxdUewJZzrlsADObCAwBjpW7c25D2WcllZBRRMJQYp0YLk9L4vK0JJxzbNhVwFdZO5mdtZOPV27nrQWlF2fbN4qjb5uG9G8TT8+WDYjTxVnAv3JPAjaX284Bep3KwcxsNDAaoFmzZqfyI0QkDJkZLeNr0TK+Ftf0bk5xiWPllvzSsl+3k9fnbeLfszbgizC6pdSjX+vSaZy0ZvXDdqGzKr0V0jk3HhgPpXPuVXlsEQkdvgijc3JdOifX5daBrTl8tJiFm/aUTeHs4pnpWTz1eRY1o3yc2bIB/ds0pG/reDo2CZ+Ls/6Uey6QUm47uWyfiEi1EBPlo2/rePq2jueXP4J9h44yL3vXsdsuH/5wFQA1o3y0bxxHxyZ1SG0SR2rTOrRvXIfaNULvkR9//o3mA23NrCWlpT4SuLpSU4mInIa6NaO4oFNjLujUGIBt+w4zJ3snS3P2kbk1nw+XbeWNrzcdG9+8YSwdG9chtWkdOjapQ8cmcSTVqxnUd+X4eyvkRZTe6ugDXnDOPWRmDwIZzrmpZnYm8A5QHzgMbHPOdTrRz9StkCLiFeccW/YdJnNLPplb88nclk/m1v1s2HWQbyqxTkxkWdHXIbXsf9s2qk1MlLcPWOkhJhGRk3TwSBGrtu0vLfyt+azcms/qbfspKCwGSuf6WyfUOlb63xR/QlyNKsuotWVERE5SrRqR9Ghenx7N6x/bV1Li2Li74NvC35LP/PW7eW/xlmNj4mvXoGOTOFKbfDu10yq+FpE+7+7UUbmLiJxARMS3t2Fe1LnJsf17CwrJ3Lr/2Bl+5tZ8/j1rA4XFpY/7REdG0K5R7WNTOt/8qluzau7D17SMiEiAHC0uITvvICu37vu2+Lfks+tg4bExSfVq8qvB7RnSLemUjqFpGRGRKhbli6B94zjaN45jaFrpPuccefuPlJ3dlxZ+VczRq9xFRCqRmZFYJ4bEOjEMbJ9YZccNz+dyRURCnMpdRCQEqdxFREKQyl1EJASp3EVEQpDKXUQkBKncRURCkMpdRCQEebb8gJnlARtP8bfHAzsDGCfY6fv4Ln0f39J38V2h8H00d84lVDTIs3I/HWaW4c/aCuFC38d36fv4lr6L7wqn70PTMiIiIUhSoh6AAAACrElEQVTlLiISgoK13Md7HaCa0ffxXfo+vqXv4rvC5vsIyjl3ERE5sWA9cxcRkRMIunI3s8FmttrMsszsfq/zeMXMUsxsupmtNLMVZnan15mqAzPzmdkiM/uP11m8Zmb1zGyyma0ys0wz6+N1Jq+Y2d1lf06Wm9kbZhbjdabKFlTlbmY+
YCxwIZAKjDKzVG9TeaYIuNc5lwr0Bn4ext9FeXcCmV6HqCbGAP91znUAuhKm34uZJQF3AOnOuTMAHzDS21SVL6jKHegJZDnnsp1zhcBEYIjHmTzhnNvqnFtY9s/7Kf2De2ovZQwRZpYMXAxM8DqL18ysLjAA+BeAc67QObfX21SeigRqmlkkEAts8ThPpQu2ck8CNpfbziHMCw3AzFoAacA8b5N47kngV0CJ10GqgZZAHvDvsmmqCWZWy+tQXnDO5QKPApuArcA+59zH3qaqfMFW7vI9ZlYbeBu4yzmX73Uer5jZJcAO59wCr7NUE5FAd+BZ51wacBAIy2tUZlaf0r/htwSaArXM7Cfepqp8wVbuuUBKue3ksn1hycyiKC3215xzU7zO47F+wGVmtoHS6bpzzOxVbyN5KgfIcc5987e5yZSWfTg6D1jvnMtzzh0FpgB9Pc5U6YKt3OcDbc2spZlFU3pRZKrHmTxhZkbpfGqmc+5xr/N4zTn3G+dcsnOuBaX/v/jcORfyZ2c/xDm3DdhsZu3Ldp0LrPQwkpc2Ab3NLLbsz825hMHF5UivA5wM51yRmd0OTKP0ivcLzrkVHsfySj/gGmCZmS0u2/db59yHHmaS6uUXwGtlJ0LZwM88zuMJ59w8M5sMLKT0LrNFhMGTqnpCVUQkBAXbtIyIiPhB5S4iEoJU7iIiIUjlLiISglTuIiIhSOUuIhKCVO4iIiFI5S4iEoL+H8XSClRWZC7iAAAAAElFTkSuQmCC\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(tr_loss_hist, label = 'train')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "yhat = char_bi_rnn.predict(sess = sess, X_length = X_length, X_indices = X_indices)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "training acc: 100.00%\n" ] } ], "source": [ "print('training acc: {:.2%}'.format(np.mean(yhat == np.argmax(y, axis = -1))))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }