{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CS 20 : TensorFlow for Deep Learning Research\n", "## Lecture 11 : Recurrent Neural Networks\n", "Simple example for Many to One Classification (word sentiment classification) by Stacked GRU with Drop out.\n", "\n", "### Many to One Classification by Stacked GRU with Drop out\n", "- Creating the **data pipeline** with `tf.data`\n", "- Preprocessing word sequences (variable input sequence length) using `padding technique` by `user function (pad_seq)`\n", "- Using `tf.nn.embedding_lookup` for getting vector of tokens (eg. word, character)\n", "- Creating the model as **Class** \n", "- Applying **Drop out** to model by `tf.contrib.rnn.DropoutWrapper`\n", "- Applying **Stacking** to model by `tf.contrib.rnn.MultiRNNCell`\n", "- Replacing **RNN Cell** with **GRU Cell**\n", "- Reference\n", " - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n", " - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n", " - https://danijar.com/introduction-to-recurrent-networks-in-tensorflow/\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.8.0\n" ] } ], "source": [ "import os, sys\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "import string\n", "%matplotlib inline\n", "\n", "slim = tf.contrib.slim\n", "print(tf.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare example data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n", "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'abcdefghijklmnopqrstuvwxyz *'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Character quantization\n", "char_space = string.ascii_lowercase \n", "char_space = char_space + ' ' + '*'\n", "char_space" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n" ] } ], "source": [ "char_dic = {char : idx for idx, char in enumerate(char_space)}\n", "print(char_dic)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create pad_seq function" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def pad_seq(sequences, max_len, dic):\n", " seq_len, seq_indices = [], []\n", " for seq in sequences:\n", " seq_len.append(len(seq))\n", " seq_idx = [dic.get(char) for char in seq]\n", " seq_idx += (max_len - len(seq_idx)) * [dic.get('*')] # 27 is idx of meaningless token \"*\"\n", " seq_indices.append(seq_idx)\n", " return seq_len, seq_indices" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pad_seq function to data" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "max_length = 10\n", "X_length, X_indices = pad_seq(sequences = words, max_len = max_length, dic = char_dic)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[4, 3, 7, 7, 9, 7]\n", "(6, 10)\n" ] } ], "source": [ "print(X_length)\n", "print(np.shape(X_indices))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define CharStackedGRU class" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class CharStackedGRU:\n", " def __init__(self, X_length, X_indices, y, n_of_classes, dic, hidden_dims = [32, 16]):\n", " \n", " # data pipeline\n", " with tf.variable_scope('input_layer'):\n", " self._X_length = X_length\n", " self._X_indices = X_indices\n", " self._y = y\n", " self._keep_prob = tf.placeholder(dtype = tf.float32)\n", " \n", " one_hot = tf.eye(len(dic), dtype = tf.float32)\n", " self._one_hot = tf.get_variable(name='one_hot_embedding', initializer = one_hot,\n", " trainable = False) # embedding vector training 안할 것이기 때문\n", " self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n", " \n", " # Stacked-GRU\n", " with tf.variable_scope('stacked_gru'):\n", " \n", " cells = []\n", " for hidden_dim in hidden_dims:\n", " cell = tf.contrib.rnn.GRUCell(num_units = hidden_dim,\n", " kernel_initializer = tf.contrib.layers.xavier_initializer(),\n", " activation = tf.nn.tanh)\n", " cell = tf.contrib.rnn.DropoutWrapper(cell = cell, output_keep_prob = self._keep_prob)\n", " cells.append(cell)\n", " else:\n", " cells = tf.contrib.rnn.MultiRNNCell(cells = cells)\n", " \n", " _, state = tf.nn.dynamic_rnn(cell = cells, inputs = self._X_batch,\n", " sequence_length = self._X_length, dtype = tf.float32)\n", " \n", " with tf.variable_scope('output_layer'):\n", " self._score = slim.fully_connected(inputs = state[-1], num_outputs = n_of_classes,\n", " activation_fn = None)\n", " \n", " with tf.variable_scope('loss'):\n", " self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n", " \n", " with tf.variable_scope('prediction'):\n", " self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n", " \n", " def predict(self, sess, X_length, X_indices, keep_prob = 1.):\n", " feed_prediction = {self._X_length : X_length, self._X_indices : X_indices, self._keep_prob : keep_prob}\n", " return sess.run(self._prediction, feed_dict = feed_prediction)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a model of CharStackedGRU" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3\n" ] } ], "source": [ "# hyper-parameter#\n", "lr = .003\n", "epochs = 10\n", "batch_size = 2\n", "total_step = int(np.shape(X_indices)[0] / batch_size)\n", "print(total_step)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n" ] } ], "source": [ "## create data pipeline with tf.data\n", "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n", "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n", "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n", "tr_iterator = tr_dataset.make_initializable_iterator()\n", "print(tr_dataset)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "char_stacked_gru = CharStackedGRU(X_length = X_length_mb, X_indices = X_indices_mb, y = y_mb,\n", " n_of_classes = 2, dic = char_dic, hidden_dims = [32,16])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Creat training op and train model" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "## create training op\n", "opt = tf.train.AdamOptimizer(learning_rate = lr)\n", "training_op = opt.minimize(loss = char_stacked_gru.ce_loss)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch : 1, tr_loss : 0.691\n", "epoch : 2, tr_loss : 0.646\n", "epoch : 3, tr_loss : 0.613\n", "epoch : 4, tr_loss : 0.594\n", "epoch : 5, tr_loss : 0.548\n", "epoch : 6, tr_loss : 0.484\n", "epoch : 7, tr_loss : 0.443\n", "epoch : 8, tr_loss : 0.425\n", "epoch : 9, tr_loss : 0.351\n", "epoch : 10, tr_loss : 0.277\n" ] } ], "source": [ "sess = tf.Session()\n", "sess.run(tf.global_variables_initializer())\n", "\n", "tr_loss_hist = []\n", "\n", "for epoch in range(epochs):\n", " avg_tr_loss = 0\n", " tr_step = 0\n", " \n", " sess.run(tr_iterator.initializer)\n", " try:\n", " while True:\n", " _, tr_loss = sess.run(fetches = [training_op, char_stacked_gru.ce_loss],\n", " feed_dict = {char_stacked_gru._keep_prob : .5})\n", " avg_tr_loss += tr_loss\n", " tr_step += 1\n", " \n", " except tf.errors.OutOfRangeError:\n", " pass\n", " \n", " avg_tr_loss /= tr_step\n", " tr_loss_hist.append(avg_tr_loss)\n", " \n", " print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[<matplotlib.lines.Line2D at 0x10e9d1cc0>]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xd0VWW+xvHv75w0AiRACDWUEEGMFMFAaIJddFAsWGgqgthQB9v1TrteZ8aZay+DMIAwo4KMZXQYy4COCNICoUsTEpQuEDohpL33j0Ru9CJEPMk+5fmslbXY+7zJftZe5GGz22vOOUREJLz4vA4gIiKBp3IXEQlDKncRkTCkchcRCUMqdxGRMKRyFxEJQyp3EZEwpHIXEQlDKncRkTAU5dWG69ev71q2bOnV5kVEQtKSJUv2OOeSTzXOs3Jv2bIl2dnZXm1eRCQkmdnXlRlXqdMyZtbXzNab2UYze/QEnz9nZsvLv740s/0/NrCIiATOKY/czcwPjAEuAbYCi81sunNuzbdjnHOjK4y/F+hUBVlFRKSSKnPk3hXY6JzLdc4VAtOA/icZPxB4IxDhRETk9FSm3JsCWyosby1f9/+YWQsgFfj0p0cTEZHTFehbIW8C3nbOlZzoQzMbaWbZZpa9e/fuAG9aRES+VZly3wY0q7CcUr7uRG7iJKdknHPjnXMZzrmM5ORT3skjIiKnqTLlvhhobWapZhZDWYFP//4gM2sL1AUWBDaiiIj8WKcsd+dcMTAKmAGsBd50zq02s8fN7KoKQ28Cprkqnrdv467DPDNzPQVFJzzzIyIiVPIhJufch8CH31v3m+8tPxa4WD/s32u/4aVPN/LByh384dr2ZLZKqo7NioiElJB7t8wdfdJ4bXhXCktKuXH8Qn757ioOFhR5HUtEJKiEXLkDnNc6mZmjezOiVypvLNrMJc/OZubqnV7HEhEJGiFZ7gDxMVH8ql86797dk7rxMYx8bQl3T1nCrkMFXkcTEfFcyJb7tzo2q8M/7+3Fw5edySdrd3HxM7N5c/EWqvi6rohIUAv5cgeI9vu454Iz+Oj+82jbOIFH3lnJ4IlZfLXniNfRREQ8ERbl/q205FpMu70bT1zTnlVbD3DZ83P48+wciktKvY4mIlKtwqrcAXw+Y1Bmcz5+oA+92yTzh4/WcfXL8/hi2wGvo4mIVJuwK/dvNUqMY/zQcxk7uDPfHDxG/zHz+ONH6/Twk4hEhLAtdwAz4/L2jflkdB8GdE5h3Owc+j4/h/k5e7yOJiJSpcK63L+VGB/N/wzowNQRmThg0IQsHn1nJQfy9fCTiISniCj3b/U4oz4zft6bO/uk8daSrVz83Gw+WrXD61giIgEXUeUOEBft59HL2/KPe3rSoHYsd01Zyh2vZfPNQT38JCLhI+LK/Vvtmibyj3t68p+Xt+Wz9bu5+JnZTM3aTGmpHn4SkdAXseUOEOX3cUefNGb8vDftUxL5xburGDhhIbm7D3sdTUTkJ4nocv9Wy/o1mTIikyev68DaHQfp+8LnjJm1kSI9/CQiIUrlXs7MuKFLMz55sA+XnNWQp2as58qX5rJy636vo4mI/Ggq9+9pUDuOMYM7M37ouezLL+TqMfP4/QdryC8s9jqaiEilqdx/wKVnN+LjB/owsGtzJny+icuen8PnG3Z7HUtEpFJU7ieREBfN769pz5t3dCfa52PoK4t48M0V7DtS6HU0EZGTUrlXQtfUenx4/3mMuuAM/rF8G5c8N5t/rtiud8aLSNBSuVdSXLSfhy47k+mjetGkTg3ufWMZN09axKJNe1XyIhJ0VO4/UnqTBN69uye/7pfOmu0HueHPC7hu7Hw+XvONHoASkaBhXh11ZmRkuOzsbE+2HShHC0t4a8kWxs/JZeu+o7RuUIuRvVvR/5ymxETp300RCTwzW+KcyzjlOJX7T1dcUsoHq3Yw9rMc1u08ROPEOIb3SmVg1+bUjI3yOp6IhBGVuwecc8z+cjfjZuewMHcviTWiuaV7C27p0ZKkWrFexxORMKBy99iyzfsYNzuHmWu+ITbKxw0Zzbj9vFY0qxfvdTQRCWEq9yCxcddhxs/J4d1l2yh10K9DY+7onUZ6kwSvo4lICFK5B5mdBwp4ZW4uU7M2c6SwhD5tkrnr/DQyU+thZl7HE5EQoXIPUgfyi3g962smz9vEnsOFnNOsDnf2SePS9Ib4fCp5ETk5lXuQKygq4a0lW5kwJ5fNe/NplVyTO3un0b9TE2Kj/F7HE5EgpXIPEcUlpXz0xU7GfpbDmh0HaZgQy/BeqQzKbEEt3UYpIt+jcg8xzjk+37CHsZ/lsCA3j4S4KIZ2b8GtPVJJrq3bKEWkjMo9hK3Ysp9xs3P41+qdxPh9XJ+Rwsjz0miepNsoRSKdyj0M5Ow+zIQ5ufx96TaKS0u5on1j7uyTRrumiV5HExGPqNzDyDcHC5g0bxNTFm7m8LFizmtdn7v6pNE9LUm3UYpEGJV7GDpwtIgpWV8zae5X7Dl8jI4piWW3UZ7dCL9uoxSJCCr3MFZQVMI7S7cyfk4uX+flUzc+mm6tkuielkT3Vkmc0aCWjuhFwpTKPQKUlDpmrt7JJ2t3sTA3j237jwJQv1Ys3VrVO172qfVrquxFwkRly103Uocwv8+4vH1jLm/fGOccW/YeZUHuHhbk5DE/J4/3V+4AoFFC3PGi756WpJeXiUSASh25m1lf4AXAD0x0zv3xBGNuAB4DHLDCOTfoZD9TR+5VyzlH7p4jLMjJY0FuHlm5eew5XDaxd9M6Nb5T9k3q1PA4rYhUVsBOy5iZH/gSuATYCiwGBjrn1lQY0xp4E7jQObfPzBo453ad7Oeq3KuXc44Nuw6XlX1OHgs35bE/vwiAFknxx4u+e6skGiTEeZxWRH5IIE/LdAU2Oudyy3/wNKA/sKbCmNuBMc65fQCnKnapfmZGm4a1adOwNrf0aElpqWPdzkMsyC0r+w9W7WDa4i0ApCXXLC/6+nRrVU8TjYiEoMqUe1NgS4XlrUDm98a0ATCzeZSdunnMOfevgCSUKuHzGelNEkhvksDwXqmUlDpWbz9w/DTOu0u38frCzQCc2bA23dOS6NYqiW6t6lEnPsbj9CJyKoG6oBoFtAbOB1KAOWbW3jm3v+IgMxsJjARo3rx5gDYtgeD3GR1S6tAhpQ539EmjqKSUVdvKyn5hbh7TFm/mL/O/wgzOapRAj7Sy0zhdUuuREBftdXwR+Z7KlPs2oFmF5ZTydRVtBbKcc0XAJjP7krKyX1xxkHNuPDAeys65n25oqXrRfh+dm9elc/O63HPBGRQWl7Ji6/7j5+xfXfg1E+duwmfQvmki3dKSGNS1OS2SanodXUSo3AXVKMouqF5EWakvBgY551ZXGNOXsoust5hZfWAZcI5zLu+Hfq4uqIa2gqISlm7ex8Ly0zjLt+zH7zMeuvRMhvVM1ROzIlUkYBdUnXPFZjYKmEHZ+fRJzrnVZvY4kO2cm17+2aVmtgYoAR4+WbFL6IuL9tMjrT490uoDZdMI/vLdVfzug7V8uGoHTw7oyBkNanmcUiRy6QlVCRjnHO8t38Zj09dwtKiEBy5pw4heqUT5fV5HEwkblT1y12+dBIyZcU2nFD5+oDcXnJnMHz9ax3Vj57N+5yGvo4lEHJW7BFyD2nGMG3IuLw3sxJZ9R+n30ue89O8NFJWUeh1NJGKo3KVKmBlXdmzCx6N7c+nZjXjm4y+5esw81mw/6HU0kYigcpcqlVQrljGDOjNuSGe+OVjAVX+ay7Mff0lhsY7iRaqSyl2qRd92jfl4dB/6dWjMi//ewFV/msuqrQe8jiUStlTuUm3q1ozh+Zs6MfHmDPYeKeTql+fx5L/Wcay4xOtoImFH5S7V7uL0hnw8ug/XdGrKy5/l8LMX57Js8z6vY4mEFZW7eCIxPpqnr+/I5GFdOHKsmOvGzueJD9dSUKSjeJFAULmLpy44swEzRvfmxi7NGD8nlyte+Jzsr/Z6HUsk5KncxXMJcdH84doOvD48k2PFpVz/5wX89z9Xk19Y7HU0kZClcpeg0at1fWaM7s3Qbi2YPO8r+j7/OQty9IoikdOhcpegUis2isf7t2PayG6YwcAJC/n1e19w+JiO4kV+DJW7BKVurZL46P7zuK1nKq9nfc1lz81h7oY9XscSCRkqdwla8TFR/ObKdN66ozuxUT6GvJLFo++s5GBBkdfRRIKeyl2CXkbLenx4/3mM7N2KN7O3cNlzc5i1XnOwi5yMyl1CQly0n19ccRbv3NWDWrFRDJu8mIfeWsGBfB3Fi5yIyl1CSqfmdXn/vl7cc0Ea7y7bxiXPzeaTNd94HUsk6KjcJeTERvl5+LK2vHd3T+rVjGHEq9ncP20Z+44Ueh1NJGio3CVktU9JZPqoXtx/UWs+WLmDS56bzUerdngdSyQoqNwlpMVE+Rh9SRumj+pFw4Q47pqylN9/sIbSUm/mBhYJFip3CQvpTRJ4756e3Ny9BRM+38QDby7XhCAS0aK8DiASKNF+H/991dk0TIjjqRnryTtSyLgh51IzVn/NJfLoyF3CiplxzwVn8OR1HZifk8fACQvZc/iY17FEqp3KXcLSDV2a8ech57J+5yEGjJ3P5rx8ryOJVCuVu4Sti9MbMvX2TPblF3Ht2Pms3q45WyVyqNwlrJ3boh7v3NWdGL9x458XMn+jXj4mkUHlLmHvjAa1eefuHjSpE8etkxfz/srtXkcSqXIqd4kIjRNr8NYdPejYLJF731jGX+Zt8jqSSJVSuUvESIyP5rXhmVx8VkMe++canpqxDuf0sJOEJ5W7RJS4aD9jB3dmYNfmjJmVwyNvr6S4RA87SfjR0x0ScaL8Pp64ph3JtWN58d8b2HukkD8N6kyNGL/X0UQCRkfuEpHMjAcuacPvrm7Hp+t3MXjiQr1VUsKKyl0i2pBuLRg7uDNfbD/IgHHz2bb/qNeRRAJC5S4Rr2+7xrx6W1d2HTrGtS/PY/3OQ15HEvnJVO4iQLdWSbx1Z3cArh83n0Wb9nqcSOSnUbmLlGvbKIF37upB/dqxDHklixmrd3odSeS0qdxFKkipG8/bd/YgvXECd72+hKlZm72OJHJaVO4i31OvZgxTb8+kT5tkfvHuKp7/5Es97CQhR+UucgLxMVGMvzmDAeem8PwnG/jle19Qoqn7JIRUqtzNrK+ZrTezjWb26Ak+v9XMdpvZ8vKvEYGPKlK9ov0+nhrQgbvPT2Nq1mbunrKEgqISr2OJVMopy93M/MAY4HIgHRhoZuknGPo359w55V8TA5xTxBNmxiN92/JfV6Yzc8033PzKIg4cLfI6lsgpVebIvSuw0TmX65wrBKYB/as2lkhwGdYzlRdv6sSyLfu4YdwCdh4o8DqSyElVptybAlsqLG8tX/d915nZSjN728yanegHmdlIM8s2s+zdu3efRlwR71zZsQl/GdaVbfuPct3Y+WzcddjrSCI/KFAXVP8JtHTOdQA+Bv56okHOufHOuQznXEZycnKANi1SfXqeUZ9pI7txrLiUAePms3TzPq8jiZxQZcp9G1DxSDylfN1xzrk859y3U8xPBM4NTDyR4NOuaSJ/v6sHdWpEM2jCQj5d943XkUT+n8qU+2KgtZmlmlkMcBMwveIAM2tcYfEqYG3gIooEn+ZJ8bx9Vw9aN6jN7a8u4c3sLaf+JpFqdMpyd84VA6OAGZSV9pvOudVm9riZXVU+7D4zW21mK4D7gFurKrBIsKhfK5Y3RnajR1oSj7y9kpc/26iHnSRomFd/GTMyMlx2drYn2xYJpMLiUh5+ewX/WL6dW3u05Df90vH5zOtYEqbMbIlzLuNU4zQTk8hPFBPl47kbziG5ViwT525i9+FjPHtDR2KjNLOTeEflLhIAPp/xq37pNEiI5YkP17H3cCFP39CRpnVqeB1NIpTeLSMSQCN7p/HcjR1ZtmUfFz79Gc/OXE9+YbHXsSQCqdxFAuyaTil8+uD5XHZ2I178dCMXPj2b95Zt08VWqVYqd5Eq0KRODV4c2Im37+xOcu1Yfv635Vw7dj7Lt+z3OppECJW7SBXKaFmPf9zTk6cGdGDrvqNcPWYeD/xtud5NI1VO5S5SxXw+4/qMZsx66HzuPj+N91fu4IKnP+NPn27QK4SlyqjcRapJrdgoHunblk8e6EOfNsk8PfNLLnpmNh+s3KHz8RJwKneRatY8KZ5xQ8/ljdu7UTsuinumLuXG8Qv5YtsBr6NJGFG5i3ike1oSH9x3Hk9c056Nuw5z5Z/m8ug7K9l96Nipv1nkFFTuIh7y+4xBmc2Z9dD5DO+ZyttLtnLB05/x59k5HCvW+Xg5fSp3kSCQWCOaX/VLZ+bo3mSm1uMPH63j0ufmMHP1Tp2Pl9OichcJIq2Sa/HKrV34621difb7GPnaEoa+soj1Ow95HU1CjMpdJAj1aZPMR/efx2NXprNq2wEuf2EOv37vC/YeKfQ6moQIlbtIkIr2+7i1ZyqfPXQ+Q7u1YOqizZz/1Cwmzd1EUUmp1/EkyKncRYJc3Zox/Hf/dnx0/3l0bFaHx99fQ9/n5zBr/S6vo0kQU7mLhIg2DWvz6m1dmXhzBiWljmGTF3Pr5EVs3HXY62gShFTuIiHEzLg4vSEzR/fhl1ecxZKv9tH3+Tk8/s81HMgv8jqeBBGVu0gIionycXvvVsx6+Hyuz2jG5PmbOP/pWby+8GuKdT5eULmLhLT6tWL5w7Xtef/eXrRpWJtfvfcF/V6ay/yNe7yOJh5TuYuEgbObJDJtZDfGDu7M4WPFDJqYxchXs8nZrfPxkcq8evotIyPDZWdne7JtkXBWUFTCK3M3MWbWRvILS2iRFE/3Vkl0T0uie6skGiTEeR1RfgIzW+KcyzjlOJW7SHjadbCA6Su2szA3j6xNezlUUDaXa1pyzfKir0+3VvVIqhXrcVL5MVTuInJcSalj9fYDLMjJY0FuHos27SW/sOzFZGc2rF1W9mlJdEtNIjE+2uO0cjIqdxH5QUUlpazceoCFuXksyMkj++u9FBSVYgbpjROOn8bpmlqP2nEq+2CicheRSjtWXMKKLWVH9vNz9rBs834KS0rxGbRPqXO87Lu0rEt8TJTXcSOayl1ETltBUQlLv97HgvIj++Vb9lNc6ojyGR2blZV9j7QkOreoS1y03+u4EUXlLiIBk19YTPZX+5hffs5+1db9lDqI8fvo1LzO8Ttxzmleh9golX1VUrmLSJU5VFDE4q/2Hr9Au3r7QZyDuGgfGS3qlV2cbZVEh5REov16nCaQVO4iUm325xeStams7Bfm5rGufHKRmjF+MlrWY2DX5vRt18jjlOFB5S4insk7fOx42c/+cjdb9uXz7A0duaZTitfRQl5ly12XvUUk4JJqxXJF+8Zc0b4xBUUl3PaXxTz45griovxc3r6x1/Eigk6GiUiViov2M+HmDDo1r8t905Yxa50mGakOKncRqXI1Y6OYPKwLbRslcMfrS5int1ZWOZW7iFSLhLhoXr2tK6lJNRnx12yyv9rrdaSwpnIXkWpTt2YMr4/IpHFiHMMmL2bl1v1eRwpbKncRqVbJtWOZcnsmifHR3DxpEet2HvQ6UlhSuYtItWucWIOpI7oRF+VnyMQsTSpSBSpV7mbW18zWm9lGM3v0JOOuMzNnZqe8B1NEIlvzpHim3J4JwOAJWWzZm+9xovByynI3Mz8wBrgcSAcGmln6CcbVBu4HsgIdUkTCU1pyLV4bnklBcQkDJyxkx4GjXkcKG5U5cu8KbHTO5TrnCoFpQP8TjPst8D9AQQDziUiYO6txAq/e1pUD+UUMnpDF7kPHvI4UFipT7k2BLRWWt5avO87MOgPNnHMfBDCbiESIDil1mDysCzsOFDD0lSz2HSn0OlLI+8kXVM3MBzwLPFiJsSPNLNvMsnfv3v1TNy0iYSSjZT0m3pJB7p4j3DxpEQcLiryOFNIqU+7bgGYVllPK132rNtAO+MzMvgK6AdNPdFHVOTfeOZfhnMtITk4+/dQiEpZ6nlGfcUM6s27nQYZNXsyRY8VeRwpZlSn3xUBrM0s1sxjgJmD6tx865w445+o751o651oCC4GrnHN65aOI/GgXtm3ICzd1Ytnmfdz+ajYFRSVeRwpJpyx351wxMAqYAawF3nTOrTazx83sqqoOKCKR54r2jXnmho4syM3jrteXUFhc6nWkkKP3uYtI0JqatZlfvLuKy9s14qWBnYjSrE6Vfp+79pSIBK1Bmc35db90PvpiJw+/vZLSUm8ORkORJusQkaA2vFcqBUUlPDVjPXHRPp64pj1m5nWsoKdyF5Ggd88FZ5BfWMyYWTnERfv5Tb90FfwpqNxFJCQ8dOmZHC0sZdK8TcTH+Hn4srZeRwpqKncRCQlmxq/7ncXRohLGzMqhRrSfURe29jpW0FK5i0jIMDN+f3U7jhWV8PTML6kRE8XwXqlexwpKKncRCSk+n/HkgA4UFJfw2/fXEBftY3BmC69jBR2Vu4iEnCi/j+dv7ERB0RJ+9d4X1Ij2c23nFK9jBRXd5y4iISkmysfLgzvTIy2Jh95awYerdngdKaio3EUkZMVF+5lwcwadm9flvjeW8em6b7yOFDRU7iIS0uJjopg0rAvpTRK48/WlzN2wx+tIQUHlLiIhLyEumr8O60qr+jW5/dVsFn+11+tInlO5i0hYqFszhteGZ9K4ThzDJi9mxZb9XkfylMpdRMJGcu1YpozIpG7NaG6etIi1Ow56HckzKncRCSuNE2swdUQ34mP8DH0li427DnsdyRMqdxEJO83qxfP6iEwABk9cyOa8fI8TVT+Vu4iEpbTkWrw+IpNjxaUMmriQ7fuPeh2pWqncRSRstW2UwGu3ZXIgv4jBE7PYdajA60jVRuUuImGtfUoif7mtC98cLGDoxEXsO1LodaRqoXIXkbB3bot6TLg5g015R7h50iIOFhR5HanKqdxFJCL0PKM+44Z0Zu2Og9w2eTH5hcVeR6pSKncRiRgXtm3ICzd1YunmfYx8dQkFRSVeR6oyKncRiSg/69CYJwd0ZO7GPYyaupSiklKvI1UJlbuIRJwB56bw26vb8cnaXYz+23JKSp3XkQJOk3WISEQa2q0FRwuLeeLDdcRF+3nyug74fOZ1rIBRuYtIxBrZO40jx0p44d8bqBnj57GrzsYsPApe5S4iEe3nF7fmaFEJ4+fkUiMmiv/oe2ZYFLzKXUQimpnxn5e35cixYsbNzqFmjJ97L2rtdayfTOUuIhHPzPht/3YcLSzhmY+/pEaMnxHntfI61k+ichcRAXw+48kBHSgoLuF3H6wlPiaKQZnNvY512lTuIiLlovw+nr+xE0cLs/nle6uoEePjmk4pXsc6LbrPXUSkgpgoH2OHnEu31CQeemsl//pih9eRTovKXUTke+Ki/Uy8JYOOKYnc+8YyZq3f5XWkH03lLiJyAjVjo5g8rCttGtbmzteWsCAnz+tIP4rKXUTkByTWiOa14Zk0rxfP8L8uZunmfV5HqjSVu4jISdSrGcOUEZkk147l1kmLWL39gNeRKkXlLiJyCg0S4pgyIpNasVEMfWURG3cd8jrSKancRUQqIaVuPFNu74bPjEETsvg674jXkU5K5S4iUkmp9WsyZUQmhSWlDJqQxfb9R72O9IMqVe5m1tfM1pvZRjN79ASf32lmq8xsuZnNNbP0wEcVEfHemY1q89ptmRw8WsSQiVnsPnTM60gndMpyNzM/MAa4HEgHBp6gvKc659o7584BngSeDXhSEZEg0T4lkcnDurDjQAFDJmax70ih15H+n8ocuXcFNjrncp1zhcA0oH/FAc65gxUWawLhN62JiEgFGS3rMfGWDDblHeGWyYs4VFDkdaTvqEy5NwW2VFjeWr7uO8zsHjPLoezI/b7AxBMRCV49z6jP2MGdWbP9ILf9ZTH5hcVeRzouYBdUnXNjnHNpwH8AvzrRGDMbaWbZZpa9e/fuQG1aRMQzF53VkOdvOoclX+/jjteWUFBU4nUkoHLlvg1oVmE5pXzdD5kGXH2iD5xz451zGc65jOTk5MqnFBEJYv06NOF/ruvA5xv2MGrqMopKSr2OVKlyXwy0NrNUM4sBbgKmVxxgZhWnLfkZsCFwEUVEgt/1Gc34bf+z+WTtNzzw5gpKSr299HjK97k754rNbBQwA/ADk5xzq83scSDbOTcdGGVmFwNFwD7glqoMLSISjIZ2b0l+YQl/+GgdNaJ9/PHaDvh83szHWqnJOpxzHwIffm/dbyr8+f4A5xIRCUl39EnjSGEJL/57A/ExUfzXlemeTLitmZhERAJs9MWtyT9WzMS5m4iP8fNI37bVnkHlLiISYGbGL392FvlFJbz8WQ7xMX5GXdj61N8YQCp3EZEqYGb8rn87CgpLeHrml9SIiWJ4r9Rq277KXUSkivh8xpMDOpBfWMJv319DfIyfgV2bV8+2q2UrIiIRKsrv48WBnTj/zGR+8e4q3lt2sseEAkflLiJSxWKifIwbci6ZqfV48K0VzFi9s8q3qXIXEakGcdF+Jt7ShT5tkmmUEFfl29M5dxGRalIrNopJt3aplm3pyF1EJAyp3EVEwpDKXUQkDKncRUTCkMpdRCQMqdxFRMKQyl1EJAyp3EVEwpA5581UUGa2G/j6NL+9PrAngHFCnfbHd2l//B/ti+8Kh/3Rwjl3ykmoPSv3n8LMsp1zGV7nCBbaH9+l/fF/tC++K5L2h07LiIiEIZW7iEgYCtVyH+91gCCj/fFd2h//R/viuyJmf4TkOXcRETm5UD1yFxGRkwi5cjezvma23sw2mtmjXufxipk1M7NZZrbGzFab2f1eZwoGZuY3s2Vm9r7XWbxmZnXM7G0zW2dma82su9eZvGJmo8t/T74wszfMrOpny/BYSJW7mfmBMcDlQDow0MzSvU3lmWLgQedcOtANuCeC90VF9wNrvQ4RJF4A/uWcawt0JEL3i5k1Be4DMpxz7QA/cJO3qapeSJU70BXY6JzLdc4VAtOA/h5n8oRzbodzbmn5nw9R9ovb1NtU3jKzFOBnwESvs3jNzBKB3sArAM65Qufcfm9TeSoKqGFmUUA8sN3jPFUu1Mq9KbClwvJWIrzQAMwQvUBMAAABkElEQVSsJdAJyPI2ieeeBx4BSr0OEgRSgd3A5PLTVBPNrKbXobzgnNsGPA1sBnYAB5xzM71NVfVCrdzle8ysFvAO8HPn3EGv83jFzPoBu5xzS7zOEiSigM7AWOdcJ+AIEJHXqMysLmX/w08FmgA1zWyIt6mqXqiV+zagWYXllPJ1EcnMoikr9inOub97ncdjPYGrzOwryk7XXWhmr3sbyVNbga3OuW//N/c2ZWUfiS4GNjnndjvnioC/Az08zlTlQq3cFwOtzSzVzGIouygy3eNMnjAzo+x86lrn3LNe5/Gac+4/nXMpzrmWlP29+NQ5F/ZHZz/EObcT2GJmZ5avughY42EkL20GuplZfPnvzUVEwMXlKK8D/BjOuWIzGwXMoOyK9yTn3GqPY3mlJzAUWGVmy8vX/cI596GHmSS43AtMKT8QygWGeZzHE865LDN7G1hK2V1my4iAJ1X1hKqISBgKtdMyIiJSCSp3EZEwpHIXEQlDKncRkTCkchcRCUMqdxGRMKRyFxEJQyp3EZEw9L8rMJksf3lpFAAAAABJRU5ErkJggg==\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(tr_loss_hist, label = 'train')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "yhat = char_stacked_gru.predict(sess = sess, X_length = X_length, X_indices = X_indices)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "training acc: 83.33%\n" ] } ], "source": [ "print('training acc: {:.2%}'.format(np.mean(yhat == np.argmax(y, axis = -1))))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }