{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [ "# CS 20 : TensorFlow for Deep Learning Research\n", "## Lecture 11 : Recurrent Neural Networks\n", "Simple example of Many to One Classification (word sentiment classification) with Stacked Bi-directional Recurrent Neural Networks and Dropout.\n", "\n", "### Many to One Classification by Stacked Bi-directional RNN with Dropout\n", "- Creating the **data pipeline** with `tf.data`\n", "- Preprocessing word sequences of variable length with the **padding technique**, implemented in the user-defined function `pad_seq`\n", "- Using `tf.nn.embedding_lookup` to get the vector of each token (e.g. word, character)\n", "- Creating the model as a **Class**\n", "- Applying **Dropout** to the model with `tf.contrib.rnn.DropoutWrapper`\n", "- Applying **Stacking** and **dynamic rnn** to the model with `tf.contrib.rnn.stack_bidirectional_dynamic_rnn`\n", "- Reference\n", "  - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n", "  - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n", "  - https://pozalabs.github.io/blstm/" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] },
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.8.0\n" ] } ], "source": [ "import os, sys\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "import string\n", "%matplotlib inline\n", "\n", "slim = tf.contrib.slim\n", "print(tf.__version__)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare example data" ] },
 { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n", "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]" ] },
 { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'abcdefghijklmnopqrstuvwxyz *'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Character quantization\n", "char_space = string.ascii_lowercase \n", "char_space = char_space + ' ' + '*'\n", "char_space" ] },
 { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n" ] } ], "source": [ "char_dic = {char : idx for idx, char in enumerate(char_space)}\n", "print(char_dic)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Create pad_seq function" ] },
 { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
   "def pad_seq(sequences, max_len, dic, pad_token = '*'):\n",
   "    \"\"\"Map character sequences to fixed-length lists of indices.\n",
   "\n",
   "    Args:\n",
   "        sequences: strings (or other sliceable, sized sequences) to encode.\n",
   "        max_len: length of every output row; longer sequences are truncated to it.\n",
   "        dic: mapping from character to integer index; it must contain every\n",
   "            character that appears in `sequences` as well as `pad_token`.\n",
   "        pad_token: character whose index fills rows shorter than `max_len`.\n",
   "\n",
   "    Returns:\n",
   "        (seq_len, seq_indices): the length of each (possibly truncated) sequence,\n",
   "        and one list of exactly `max_len` indices per sequence.\n",
   "\n",
   "    Raises:\n",
   "        KeyError: if a character or `pad_token` is missing from `dic`.\n",
   "    \"\"\"\n",
   "    pad_idx = dic[pad_token]  # '*' has index 27 in char_dic: a meaningless padding token\n",
   "    seq_len, seq_indices = [], []\n",
   "    for seq in sequences:\n",
   "        # truncate: a sequence longer than max_len would otherwise yield a ragged row\n",
   "        seq = seq[:max_len]\n",
   "        seq_len.append(len(seq))\n",
   "        # index lookup fails loudly instead of emitting None for unknown characters\n",
   "        seq_idx = [dic[char] for char in seq]\n",
   "        seq_idx += (max_len - len(seq_idx)) * [pad_idx]\n",
   "        seq_indices.append(seq_idx)\n",
   "    return seq_len, seq_indices"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pad_seq function to data" ] },
 { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "max_length = 10\n", "X_length, X_indices = pad_seq(sequences = words, max_len = max_length, dic = char_dic)" ] },
 { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[4, 3, 7, 7, 9, 7]\n", "(6, 10)\n" ] } ], "source": [ "print(X_length)\n", "print(np.shape(X_indices))" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Define 
CharStackedBiRNN class" ] },
 { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
   "class CharStackedBiRNN:\n",
   "    \"\"\"Many-to-one character classifier built on a stacked bi-directional RNN.\n",
   "\n",
   "    Character indices are looked up in a frozen one-hot embedding table, run\n",
   "    through stacked forward/backward BasicRNNCell layers with output dropout,\n",
   "    and the top layer's final forward and backward states are concatenated and\n",
   "    projected to `n_of_classes` logits.\n",
   "\n",
   "    Args:\n",
   "        X_length: int tensor with the true length of each sequence.\n",
   "        X_indices: int tensor of padded character indices, presumably\n",
   "            (batch, max_len) -- TODO confirm against the data pipeline.\n",
   "        y: one-hot float labels with `n_of_classes` columns.\n",
   "        n_of_classes: number of output classes.\n",
   "        hidden_dims: hidden size of each stacked layer, bottom to top.\n",
   "        dic: character-to-index dictionary; its size sets the one-hot width.\n",
   "    \"\"\"\n",
   "    def __init__(self, X_length, X_indices, y, n_of_classes, hidden_dims, dic):\n",
   "\n",
   "        # data pipeline\n",
   "        with tf.variable_scope('input_layer'):\n",
   "            self._X_length = X_length\n",
   "            self._X_indices = X_indices\n",
   "            self._y = y\n",
   "\n",
   "            one_hot = tf.eye(len(dic), dtype = tf.float32)\n",
   "            # the embedding is a fixed one-hot table, so it is not trained\n",
   "            self._one_hot = tf.get_variable(name = 'one_hot_embedding', initializer = one_hot,\n",
   "                                            trainable = False)\n",
   "            self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n",
   "            # dropout keep probability; defaults to 1. (no dropout) when it is not fed\n",
   "            self._keep_prob = tf.placeholder_with_default(1., shape = [])\n",
   "\n",
   "        # Stacked Bi-directional RNN with Dropout\n",
   "        with tf.variable_scope('stacked_bi-directional_rnn'):\n",
   "            rnn_fw_cells = self._build_cells(hidden_dims, self._keep_prob)  # forward\n",
   "            rnn_bw_cells = self._build_cells(hidden_dims, self._keep_prob)  # backward\n",
   "\n",
   "            _, output_state_fw, output_state_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(\n",
   "                cells_fw = rnn_fw_cells, cells_bw = rnn_bw_cells, inputs = self._X_batch,\n",
   "                sequence_length = self._X_length, dtype = tf.float32)\n",
   "\n",
   "            # summarize each sequence by the top layer's final forward and backward states\n",
   "            final_state = tf.concat([output_state_fw[-1], output_state_bw[-1]], axis = 1)\n",
   "\n",
   "        with tf.variable_scope('output_layer'):\n",
   "            self._score = slim.fully_connected(inputs = final_state, num_outputs = n_of_classes,\n",
   "                                               activation_fn = None)\n",
   "\n",
   "        with tf.variable_scope('loss'):\n",
   "            self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n",
   "\n",
   "        with tf.variable_scope('prediction'):\n",
   "            self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n",
   "\n",
   "    @staticmethod\n",
   "    def _build_cells(hidden_dims, keep_prob):\n",
   "        \"\"\"Return one tanh BasicRNNCell per hidden size, each wrapped with output dropout.\"\"\"\n",
   "        cells = []\n",
   "        for hidden_dim in hidden_dims:\n",
   "            cell = tf.contrib.rnn.BasicRNNCell(num_units = hidden_dim, activation = tf.nn.tanh)\n",
   "            cells.append(tf.contrib.rnn.DropoutWrapper(cell = cell, output_keep_prob = keep_prob))\n",
   "        return cells\n",
   "\n",
   "    def predict(self, sess, X_length, X_indices, keep_prob = 1.):\n",
   "        \"\"\"Return the predicted class index (int32) of each example.\n",
   "\n",
   "        `X_length` and `X_indices` are fed in place of the tensors passed to\n",
   "        __init__, so any padded arrays (e.g. from pad_seq) can be scored.\n",
   "        \"\"\"\n",
   "        feed_prediction = {self._X_length : X_length, self._X_indices : X_indices, self._keep_prob : keep_prob}\n",
   "        return sess.run(self._prediction, feed_dict = feed_prediction)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a model of CharStackedBiRNN" ] },
 { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3\n" ] } ], "source": [ "# hyper-parameters\n", "lr = .003\n", "epochs = 10\n", "batch_size = 2\n", "# mini-batches per epoch: batch() keeps the last partial batch, so round up\n", "total_step = int(np.ceil(len(X_indices) / batch_size))\n", "print(total_step)" ] },
 { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n" ] } ], "source": [ "## create data pipeline with tf.data\n", "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n", "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n", "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n", "tr_iterator = tr_dataset.make_initializable_iterator()\n", "print(tr_dataset)" ] },
 { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()" ] },
 { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "char_stacked_bi_rnn = CharStackedBiRNN(X_length = X_length_mb, X_indices = X_indices_mb, \n", "                                       y = y_mb, n_of_classes = 2, hidden_dims = [16,16], dic = char_dic)" ] },
 { "cell_type": 
"markdown", "metadata": {}, "source": [ "### Create training op and train model" ] },
 { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "## create training op\n", "opt = tf.train.AdamOptimizer(learning_rate = lr)\n", "training_op = opt.minimize(loss = char_stacked_bi_rnn.ce_loss)" ] },
 { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch : 1, tr_loss : 0.869\n", "epoch : 2, tr_loss : 0.486\n", "epoch : 3, tr_loss : 0.390\n", "epoch : 4, tr_loss : 0.260\n", "epoch : 5, tr_loss : 0.311\n", "epoch : 6, tr_loss : 0.187\n", "epoch : 7, tr_loss : 0.112\n", "epoch : 8, tr_loss : 0.060\n", "epoch : 9, tr_loss : 0.043\n", "epoch : 10, tr_loss : 0.036\n" ] } ], "source": [
   "sess = tf.Session()\n",
   "sess.run(tf.global_variables_initializer())\n",
   "\n",
   "tr_loss_hist = []\n",
   "\n",
   "for epoch in range(epochs):\n",
   "    avg_tr_loss = 0\n",
   "    tr_step = 0\n",
   "\n",
   "    # re-initialize the iterator so that each epoch walks the whole dataset again\n",
   "    sess.run(tr_iterator.initializer)\n",
   "    try:\n",
   "        while True:\n",
   "            _, tr_loss = sess.run(fetches = [training_op, char_stacked_bi_rnn.ce_loss],\n",
   "                                  feed_dict = {char_stacked_bi_rnn._keep_prob : .5})\n",
   "            avg_tr_loss += tr_loss\n",
   "            tr_step += 1\n",
   "    except tf.errors.OutOfRangeError:\n",
   "        pass  # the iterator is exhausted: the epoch is complete\n",
   "\n",
   "    if tr_step == 0:\n",
   "        raise ValueError('the training dataset produced no mini-batches')\n",
   "    avg_tr_loss /= tr_step\n",
   "    tr_loss_hist.append(avg_tr_loss)\n",
   "\n",
   "    print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))"
 ] },
 { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[<matplotlib.lines.Line2D at 0x1171a2be0>]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl0VeW9//H395zkZCKQEIIhOSAzEhlDQAFnraJVUIOttlq1zq1WW+9t6/31evuz63evnfV2Wal1aB0qKiBgCw51qAOIBAggyCQI5AAhYQiQeXh+fyRqQJADJtln+LzWcpGzz5Ocj0f5nJ29n/1sc84hIiKxxed1ABERaX8qdxGRGKRyFxGJQSp3EZEYpHIXEYlBKncRkRikchcRiUEqdxGRGKRyFxGJQQlevXCPHj1c3759vXp5EZGotGTJkgrnXPbRxnlW7n379qW4uNirlxcRiUpmtjmccTosIyISg1TuIiIxSOUuIhKDVO4iIjFI5S4iEoNU7iIiMUjlLiISg6Ku3Jdt2cP989d4HUNEJKJFXbl/GKpk2r8+ZtW2Sq+jiIhErKgr90tG5hLw+5i5JOR1FBGRiBV15Z6RGuDcoT2ZUxKioanZ6zgiIhEp6sodoKggyK6qet5aW+51FBGRiBSV5X7mkGyy0gLMXFLqdRQRkYgUleWe6Pdx6eg8Xl9Txp6qeq/jiIhEnKgsd2g5NNPQ5HhpxTavo4iIRJyoLff83K4M7dWVGTo0IyLyBVFb7gBFBXmsKK1kfdl+r6OIiESUqC73KaPy8PuMGUu19y4i0lZUl3t2ehJnD8lm9rIQTc3O6zgiIhEjqssdWk6slu2r490NFV5HERGJGFFf7ucM7Um3lETNeRcRaSOscjezSWa21sw2mNlPD/N8HzN708yWmdkKM7uo/aMeXlKCn8kjc3ll1Q721TZ01suKiES0o5a7mfmBh4ALgXzgKjPLP2TYz4DnnXOjgSuBP7Z30C9TNCZIXWMz/1ixvTNfVkQkYoWz5z4O2OCc2+icqwemA1MOGeOArq1fdwM69cqikcFuDMhO06EZEZFW4ZR7HrC1zePS1m1t/Ry42sxKgXnAHe2SLkxmxtQxvSnevIdPKqo686VFRCJSe51QvQr4i3MuCFwEPGVmX/jZZnazmRWbWXF5efuu6HjZ6Dx8BrM0511EJKxyDwG92zwOtm5r6wbgeQDn3EIgGehx6A9yzj3inCt0zhVmZ2cfX+IjyOmWzMSBPZi5NESz5ryLSJwLp9wXA4PMrJ+ZBWg5YTr3kDFbgHMBzGwoLeXe6YutTx0TJLS3hvc37erslxYRiShHLXfnXCNwO/AK8BEts2JWmdl9Zja5ddjdwE1mthx4FrjOOdfpu8/n5+fQJSlBt+ATkbiXEM4g59w8Wk6Utt12b5uvVwMT2zfasUsJ+Ll4RC/mLt/GfVNOJi0prH89EZGYE/VXqB6qaEyQ6vomXv5wh9dRREQ8E3PlXnhiJidmpTJTs2ZEJI7FXLmbGZePDrLg412U7qn2Oo6IiCdirtwBLi9oucbqxaU6sSoi8Skmy71391RO7d+dWctCeDBpR0TEczFZ7tCyzvumiiqWbtnjdRQRkU4Xs+V+4fBepCT6maE57yISh2K23LskJXDhsBz+vmIbtQ1NXscREelUMVvu0DLnfX9tI6+uLvM6iohIp4rpch/fP4vcbsla511E4k5Ml7vPZ1xeEOSd9eWU7av1Oo6ISKeJ6XKHljnvzQ5mL9OJVRGJHzFf7v2zu1DQJ4OZS0s1511E4kbMlzu0nFhdV3aAlaFKr6OIiHSKuCj3i0fkEkjw6cSqiMSNuCj3bimJnJ9/AnOXb6O+sdnrOCIiHS4uyh1aDs3sqW7gjTU7vY4iItLh4qbcTx/Yg+z0JK3zLiJxIW7KPcHv47LReby5Zie7DtR5HUdEpEPFTblDy0qRjc2OOSXbvI4iItKh4qrch+SkMzyvmw7
NiEjMi6tyBygqyGPVtn2s2bHP6ygiIh0m7sp98qg8Ev2mOe8iEtPirty7pwU4e0hPXly2jcYmzXkXkdgUd+UOLXPeKw7U8fb6cq+jiIh0iLgs97OH9KR7WoCZugWfiMSouCz3QIKPySNzeW11GZXVDV7HERFpd3FZ7gBTxwSpb2rmpRWa8y4isSduy/3k3K4MOSFdc95FJCbFbbmbGUVj8li2ZS8flx/wOo6ISLuK23IHuHRUHj5Dc95FJObEdbn37JrMmYOzeXFZiKZm3YJPRGJHXJc7tMx5315Zy8KPd3kdRUSk3cR9uZ839AS6JifoxKqIxJS4L/fkRD8Xj8xl/ofb2V+rOe8iEhvivtyhZZ332oZm5q/c4XUUEZF2oXIHCvpk0L9HGjN0aEZEYoTKnU/nvAf5YNNutu6u9jqOiMhXFla5m9kkM1trZhvM7KdHGPMNM1ttZqvM7G/tG7PjXTY6DzN0YlVEYsJRy93M/MBDwIVAPnCVmeUfMmYQcA8w0Tl3MnBXB2TtULkZKUwYkMWspSGc05x3EYlu4ey5jwM2OOc2OufqgenAlEPG3AQ85JzbA+Cc29m+MTtHUUGQLburWfzJHq+jiIh8JeGUex6wtc3j0tZtbQ0GBpvZe2b2vplNaq+AnWnSsBzSAn5mLNl69MEiIhGsvU6oJgCDgLOAq4A/m1nGoYPM7GYzKzaz4vLyyLsLUmoggYuG92Leyh3U1Dd5HUdE5LiFU+4hoHebx8HWbW2VAnOdcw3OuU3AOlrK/iDOuUecc4XOucLs7OzjzdyhisYEOVDXyCurNOddRKJXOOW+GBhkZv3MLABcCcw9ZMxsWvbaMbMetBym2diOOTvNuL7dCWamaNaMiES1o5a7c64RuB14BfgIeN45t8rM7jOzya3DXgF2mdlq4E3g351zUbkSl89nXF4Q5N0NFWyvrPE6jojIcQnrmLtzbp5zbrBzboBz7v+1brvXOTe39WvnnPuRcy7fOTfcOTe9I0N3tKKCPJyDWUt1A20RiU66QvUwTsxKY1zf7sxcWqo57yISlVTuR1A0Jo+N5VWUbN3rdRQRkWOmcj+Ci4b3IjnRpxOrIhKVVO5HkJ6cyAUn5/DS8u3UNWrOu4hEF5X7lygqCFJZ08DrH0XlagoiEsdU7l9i4sAe5HRNZsYSHZoRkeiicv8Sfp9xWUEe/1pXTvn+Oq/jiIiETeV+FEUFQZqaHXNKNOddRKKHyv0oBvbswsjeGTo0IyJRReUehqkFeazZsZ9V2yq9jiIiEhaVexguGZlLwO9j5hIdmhGR6KByD0NGaoDz8nsypyREQ1Oz13FERI5K5R6mooIgu6rqeWtt5N1kRETkUCr3MJ0xOJseXQLM1IlVEYkCKvcwJfp9TBmVx+trythTVe91HBGRL6VyPwZFBUEamhwvrdjmdRQRkS+lcj8G+bldye/VVXPeRSTiqdyPUdGYICtKK1lftt/rKCIiR6RyP0ZTRuWS4DNmaJ13EYlgKvdj1KNLEmcNyWb2shBNzboFn4hEJpX7cSgqCFK2r453N1R4HUVE5LBU7sfhnKE9yUhN5On3N2vvXUQiksr9OCQl+PnOqSfy2uoyrnxkIZt3VXkdSUTkICr34/TDrw3mN1eMZM2O/Ux64B2eXPgJzdqLF5EIoXI/TmbG1DFBXv3hGYzt151756zi6scWsXV3tdfRRERU7l9Vr24p/PX6sfzP5cNZvnUvkx54m2c/2IJz2osXEe+o3NuBmXHVuD68fNcZjAhmcM+slVz7xGK2V9Z4HU1E4pTKvR317p7KMzeewn1TTmbxpt2c//u3mbGkVHvxItLpVO7tzOczvjO+L/PvPJ2TctL5txeWc9OTxezcV+t1NBGJIyr3DtK3RxrTbx7Pz74+lHfWV3D+A28zpySkvXgR6RQq9w7k9xk3nt6feXeeTt+sNO6cXsL3nllKxYE6r6OJSIxTuXeCAdldmHHreH4y6SRe/2gnF/z+beav3O5
1LBGJYSr3TpLg93HbWQN46Y7T6JWRzG3PLOUHzy7TXZ1EpEOo3DvZkJx0XvzeRH70tcHMW7md8x94m3+uLvM6lojEGJW7BxL9Pn5w7iDm3D6RrLQANz5ZzN3PL6eypsHraCISI1TuHjo5txtzbz+NO84ZyOySEBf8/m3eWrvT61giEgNU7h4LJPi4+/whzLptAunJCVz3xGLumbWC/bXaixeR46dyjxAje2fw0h2nccuZ/Xlu8VYmPfAOC3QzEBE5Tir3CJKc6OeeC4fywq0TCCT4+Naji7h3zodU1TV6HU1EokxY5W5mk8xsrZltMLOffsm4IjNzZlbYfhHjz5gTM5n3g9P57sR+PPX+Zi588B0+2LTb61giEkWOWu5m5gceAi4E8oGrzCz/MOPSgTuBRe0dMh6lBPzce0k+0286FYBvPrKQX/x9NbUNTR4nE5FoEM6e+zhgg3Nuo3OuHpgOTDnMuF8AvwS0QlY7OqV/FvPvPJ2rTzmRx97dxEUPvsPSLXu8jiUiES6ccs8DtrZ5XNq67TNmVgD0ds79ox2zSau0pAR+cekwnr7hFOoam5n68ALun79Ge/EickRf+YSqmfmA3wF3hzH2ZjMrNrPi8vLyr/rScee0QT14+a7T+UZhb6b962Mu+cO7rCyt9DqWiESgcMo9BPRu8zjYuu1T6cAw4C0z+wQ4FZh7uJOqzrlHnHOFzrnC7Ozs408dx9KTE7m/aARPXD+WfbUNXPrH93j4rY+1lLCIHCSccl8MDDKzfmYWAK4E5n76pHOu0jnXwznX1znXF3gfmOycK+6QxALA2UN68updZzJpWA6/fHkNP5+7iuZmFbyItEg42gDnXKOZ3Q68AviBx51zq8zsPqDYOTf3y3+CdJRuqYn84crR9OqazKPvbmJ3dQO/vWIkgQRdviAS745a7gDOuXnAvEO23XuEsWd99VgSLp/P+NnF+WSnJ/E/89ewp6qeadeMoUtSWP9pRSRGaRcvRtxy5gB+PXUECzfu4lt/fp9dutuTSFxTuceQKwp786erx7B2x36mTlvI1t3VXkcSEY+o3GPMefkn8MyNp7DrQB1FDy9gzY59XkcSEQ+o3GNQYd/uvHDrBMzgG9MWsvgTrUsjEm9U7jFqSE46M2+bQI8uSVz96CJe0638ROKKyj2GBTNTeeHW8ZyUk86tTy/h+cVbj/5NIhITVO4xLqtLEn+76VQmDMjixzNX8Me3NuhqVpE4oHKPA2lJCTx27Vgmj8zlVy+v5Rd//0hXs4rEOF3pEicCCT4e+OYouqcFePy9TeyuquNXU3U1q0isUrnHEZ/P+K9LWq5m/fUra9ld3cC0qwtIDeh/A5FYo922OGNmfP/sgdx/+XDeXV/Ot/68iD1V9V7HEpF2pnKPU1eO68PDV49h9fZ9TJ22gNDeGq8jiUg7UrnHsQtOzuHJ745j5746pj68gHVl+72OJCLtROUe507tn8Vzt4ynsdlxxbSFLNmsq1lFYoHKXcjP7cqs2yaQmZrItx9dxBtrdDWrSLRTuQsAvbunMuO2CQzs2YWbnlzCzCWlXkcKW1Oz470NFfz7C8s557dv8XyxrsQV0Rw4+UyPLkk8e9Op3PLUEu5+YTm7quq4+YwBXsc6LOccH4b2MackxEsrtlG2r44uSQnkZiTz4xkrKN1dzQ+/Nhgz8zqqiCdU7nKQ9OREnrh+LD96bjn/PW8Nuw7U89MLT4qYkty8q4o5JduYXRJiY3kViX7jrCE9mTIql/OGnoDfZ9wzayX/+8YGSvfUcH/RCF2oJXFJ5S5fkJTg53+vGk33tAB/ensjFQfqub9oOIl+b0qy4kAd/1ixndklIZZt2QvAuH7dufG0/lw0PIeM1MBB4389dQR9uqfyu9fWsb2ylmnXjKFbSqIX0UU8o3KXw/L7jPumnExWlwAP/HM9e6rreehbBaQE/J3y+lV1jby6egdzSrbxzvoKmpodJ+Wk85NJJzF5VC55GSlH/F4z4wf
nDiKYmcJPZq5g6sMLeOL6sQQzUzslu0gkMK9WCCwsLHTFxcWevLYcm6ff38x/zvmQ0b0zePy6sV/YU24vDU3NvLO+nNnLtvHa6jJqGprIy0hh8qhcpozK5aScrsf8MxdsqOCWp5eQnOjn8WvHMjzYrQOSi3QeM1vinCs86jiVu4Rj3srt3DW9hBOzUnnyhnH06nbkPedj4Zxj6ZY9zF62jX+s3M7uqnoyUhO5aHgvLh2VR+GJmfh8X+14/7qy/Vz/xGJ2V9Xz0LdHc85JJ7RLdhEvqNyl3S34uIKbn1xC1+QEnrzhFAb27HLcP2t92X5ml4SYU7KN0j01JCX4OC//BC4dlceZg7Pb/STozn21fPevi1m9bR//d8owrjn1xHb9+SKdReUuHeLDUCXXPfEBTc2Ox68by+g+mWF/747KWuYuDzF72TZWb9+Hz2DiwB5cOiqPC4bl0CWpY08BVdU1csezy3hjzU5uOaM/P5l00lf+rUCks6ncpcN8UlHFdx7/gPL9dUy7ZgxnDs4+4tjKmgbmr2yZ6bJo026cg5HBbkwZlcfFI3vRMz25E5NDY1MzP39pFU+/v4Wvj+jFb68YSXJi55wkFmkPKnfpUDv313Lt44tZX7af31wxkktH5332XG1DE2+u2cnskhBvrimnvqmZfj3SmDIqlymj8ujXI83D5C3H+f/09kbun7+GwhMz+fN3CslM65iTxCLtTeUuHW5fbQM3/bWYRZt287OvD2Vor67MKQkx/8Md7K9tpEeXJC4Z2XJidESwW8RcCPWpv6/Yxo+eX05eRgp/uX4sJ2Z5+6EjEg6Vu3SK2oYm7py+jFdWtSw2lhbwc8GwHC4dlceEAVkkeHThU7gWf7Kbm54sxmfGo9cWUnAM5xBEvKByl07T1Ox4vngr6ckJnDf0hKg7hr2x/ADXPbGYsn21PHjlKCYN6+V1JJEjCrfcI3u3SqKC32dcNa4PF4/IjbpiB+if3YUXvzeB/Nyu3PbMUh57d5PXkUS+MpW7CJDVuiLmBfk5/OLvq/n53FU0NXvzW61Ie1C5i7RKTvTz0LcLuOG0fvxlwSfc+vQSauqbvI4lclxU7iJt+H3Gf16cz39dks8/PyrjykcWUr6/zutYIsdM5S5yGNdP7Me0q8ewtmw/lz/8Hht2HvA6ksgxUbmLHMEFJ+cw/ebx1NQ3UfTwAhZt3OV1JJGwqdxFvsSo3hnMum0iWV0CXPPYB8wpCXkdSSQsKneRo+iTlcqs2yYwqk8Gd04v4Y9vbcCr60NEwqVyFwlDRmqAp24Yx+SRufzq5bX8x4sf0tjU7HUskSMKq9zNbJKZrTWzDWb208M8/yMzW21mK8zsdTPTYtkSc5IS/DzwzVF876wBPPvBFm74azEH6hq9jiVyWEctdzPzAw8BFwL5wFVmln/IsGVAoXNuBDAD+FV7BxWJBD6f8eNJJ/Hflw3n3Q0VfGPaQsr21XodS+QLwtlzHwdscM5tdM7VA9OBKW0HOOfedM5Vtz58Hwi2b0yRyPKtU/rw6LWFbN5VxaUPvceaHfu8jiRykHDKPQ/Y2uZxaeu2I7kBmP9VQolEg7OH9OS5W8bT1Oy44uGFvLu+wutIIp9p1xOqZnY1UAj8+gjP32xmxWZWXF5e3p4vLeKJYXndePH7E8nNSOG6Jz7gheKtR/8mkU4QTrmHgN5tHgdbtx3EzM4D/g8w2Tl32Ou1nXOPOOcKnXOF2dlHvjWbSDTJy0jhhdvGc2r/LP59xgp+/9o6TZUUz4VT7ouBQWbWz8wCwJXA3LYDzGw08Cdain1n+8cUiWxdkxN5/LqxTB0T5MHX13P388uprG7wOpbEsaPebt4512hmtwOvAH7gcefcKjO7Dyh2zs2l5TBMF+CF1lupbXHOTe7A3CIRJ5Dg49dTR9A7M5UHXl/Hax+VccsZ/bl+Yj/Sko76V02kXelOTCId4KPt+/jtq+v450dlZKUF+N7ZA/n2KX2i8mYmEll0mz2RCLB
0yx5+++pa3tuwi17dkvnBuYOYOiZIYoTfW1Yil26zJxIBCvpk8syNp/K3m06hV7dk7pm1kvN+9y/mlIRo1p2epAOp3EU6wYQBPZh52wQeu7aQ1EACd04v4cIH3+GVVTs0s0Y6hMpdpJOYGecOPYF/3HEaf7hqNA1Nzdzy1BIufeg93llfrpKXdqVyF+lkPp9xychcXv3hGfyqaAQVB+q55rEPuPKR9yn+ZLfX8SRG6ISqiMfqGpuY/sFW/vDGBioO1HH2kGzuPn8Iw/K6eR1NIpBmy4hEmer6Rv66YDPT/vUxlTUNfH14L374tcEM7NnF62gSQVTuIlGqsqaBx97ZyGPvbqKmoYnLC4Lcee4gendP9TqaRACVu0iU23Wgjoff+pgn39+Mc44rx/bhjnMG0rNrstfRxEMqd5EYsaOylj+8sZ7nFm8lwW9cO74vt545gMy0gNfRxAMqd5EYs3lXFQ/+cz0vloRICyRw4+n9uOG0fqQnJ3odTTqRyl0kRq0r28/vXl3Hy6t2kJmayG1nDeA74/tq3Zo4oXIXiXErSvfym1fX8fa6cnqmJ3HHOQP55tg+BBJ0+UosU7mLxIlFG3fxm1fXsviTPQQzU7jrvMFcNjoPv8+8jiYdQAuHicSJU/pn8fwt4/nL9WPJSE3k315YzgUPvM28ldu1OFkc0x0ERGKAmXHWkJ6cOTiblz/cwW9fW8f3nlnKSTnpXFHYm8kjc8lOT/I6pnQiHZYRiUFNzY7Zy0L8ZcEnrAxV4vcZZw3O5vKCIOcO7amTr1Es3MMy2nMXiUF+n1E0JkjRmCDryvYza2mIF5eV8vqanXRNTuDikbkUFeRR0CeT1ltjSozRnrtInGhqdiz4uIJZS0O8/OEOahqa6JuVyuUFQS4bnaflDaKEZsuIyBEdqGtk/srtzFoaYuHGXQCM69edqQVBLhyeowujIpjKXUTCUrqnmtnLQsxaGmJjRRVJCT4uODmHojFBThvYQ1MqI4zKXUSOiXOOkq17mbm0lJeWb6eypoGe6UlcOjqPooIgQ3LSvY4oqNxF5Cuoa2zizTU7mbEkxFtrd9LY7Dg5tyuXFwSZMiqXHl00rdIrKncRaRe7DtTx0vJtzFwa0rTKCKByF5F213ZaZdm+Ok2r9IDKXUQ6jKZVekflLiKdQtMqO5fKXUQ63afTKmcuDbGpoorkxJZplWcNySaYmUowM4We6cmaXvkVqNxFxDPOOZZt3cusNtMqP5XgM3plJJOXkUJeRip5mSkEM1LIy0whLyOFXhnJJCXoJO2RqNxFJCLUNzazeVcVpXtrCO2pIXTIn2X7a2lbQ2aQ3SXps7LPy0xp2etv8wGQlhS/y2Jp4TARiQiBBB+DTkhn0AmHvwiqvrGZHZW1lO6t/kL5rwxV8sqqHTQ0HbwTmpGa2Lrn/3nhBzM//00gMzUx7mfuqNxFxFOBBB99slLpk3X4GTbNzY6d++sI7a2m9JDy31RRxbsbKqiubzroe1IDfnI/K/zPPwAyUwOkBvykBPykBhI+/zrRT4I/tu5dpHIXkYjm8xk53ZLJ6ZbMmBO/+Lxzjr3VDYT21hxS/tWE9tawfOte9lQ3fPEbDxHw+0gJ+ElrU/4tf7b8k5KY8PnXn21vu631wyLx8+c+HZfowQeHyl1EopqZkZkWIDMtwLC8bocdU1XXyLa9NVTWNFBd30R1fRM1DY0tf9Y3UVXXRHVDIzWfPlffRHV9y/O7q+op3dM6rnVbfWPzMWVM9BspiX7SkloK/67zBjN5ZG57/OsfkcpdRGJeWlLCEY/5H4/GpmZqGpo++zBo+2FR3eaD4aDnW7dVNzSRmdrxc/9V7iIixyjB7yPd74voC7Ri6wyCiIgAKncRkZikchcRiUFhlbuZTTKztWa2wcx+epjnk8zsudbnF5lZ3/YOKiIi4TtquZuZH3gIuBDIB64ys/xDht0A7HHODQR+D/yyvYOKiEj4wtlzHwdscM5tdM7VA9OBKYeMmQL
8tfXrGcC5Fu/X/oqIeCiccs8DtrZ5XNq67bBjnHONQCWQ1R4BRUTk2HXqCVUzu9nMis2suLy8vDNfWkQkroRzEVMI6N3mcbB12+HGlJpZAtAN2HXoD3LOPQI8AmBm5Wa2+XhCAz2AiuP83lik9+Ngej8+p/fiYLHwfhxmhZ0vCqfcFwODzKwfLSV+JfCtQ8bMBa4FFgJTgTfcURaKd85lhxPwcMysOJz1jOOF3o+D6f34nN6Lg8XT+3HUcnfONZrZ7cArgB943Dm3yszuA4qdc3OBx4CnzGwDsJuWDwAREfFIWGvLOOfmAfMO2XZvm69rgSvaN5qIiByvaL1C9RGvA0QYvR8H0/vxOb0XB4ub98Oze6iKiEjHidY9dxER+RJRV+5HW+cmXphZbzN708xWm9kqM7vT60yRwMz8ZrbMzP7udRavmVmGmc0wszVm9pGZjfc6k1fM7Ietf08+NLNnzSzZ60wdLarKPcx1buJFI3C3cy4fOBX4fhy/F23dCXzkdYgI8SDwsnPuJGAkcfq+mFke8AOg0Dk3jJZZfzE/oy+qyp3w1rmJC8657c65pa1f76flL+6hy0LEFTMLAl8HHvU6i9fMrBtwBi3TlHHO1Tvn9nqbylMJQErrRZapwDaP83S4aCv3cNa5iTutSyyPBhZ5m8RzDwA/Bo7t7sWxqR9QDjzRepjqUTNL8zqUF5xzIeA3wBZgO1DpnHvV21QdL9rKXQ5hZl2AmcBdzrl9XufxipldDOx0zi3xOkuESAAKgIedc6OBKiAuz1GZWSYtv+H3A3KBNDO72ttUHS/ayj2cdW7ihpkl0lLszzjnZnmdx2MTgclm9gkth+vOMbOnvY3kqVKg1Dn36W9zM2gp+3h0HrDJOVfunGsAZgETPM7U4aKt3D9b58bMArScFJnrcSZPtK6X/xjwkXPud17n8Zpz7h7nXNA515eW/y/ecM7F/N7ZkTjndgBbzWxI66ZzgdUeRvLSFuBUM0tt/XtzLnFwcjms5QcixZHWufE4llcmAtcAK82spHXbf7QuFSECcAfwTOuO0Ebgeo/zeMI5t8jMZgBLaZlltow4uFJVV6iKiMSgaDssIyIiYVC5i4jEIJW7iEgMUrliuv6yAAAAIElEQVSLiMQglbuISAxSuYuIxCCVu4hIDFK5i4jEoP8PbI2wXWDjLIgAAAAASUVORK5CYII=\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(tr_loss_hist, label = 'train')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "yhat = char_stacked_bi_rnn.predict(sess = sess, X_length = X_length, X_indices = X_indices)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "training acc: 100.00%\n" ] } ], "source": [ "print('training acc: {:.2%}'.format(np.mean(yhat == np.argmax(y, axis = -1))))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, 
"file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }