{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [ "# CS 20 : TensorFlow for Deep Learning Research\n", "## Lecture 11 : Recurrent Neural Networks\n", "A simple example of many-to-one classification (word sentiment classification) with a Long Short-Term Memory (LSTM) network.\n", "\n", "### Many to One Classification by LSTM\n", "- Creating the **data pipeline** with `tf.data`\n", "- Preprocessing word sequences of variable length by padding them with a user-defined function (`pad_seq`)\n", "- Using `tf.nn.embedding_lookup` to get the vector for each token (e.g. word, character)\n", "- Creating the model as a **class**\n", "- References\n", "    - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n", "    - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.8.0\n" ] } ], "source": [ "import os, sys\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "import string\n", "%matplotlib inline\n", "\n", "slim = tf.contrib.slim\n", "print(tf.__version__)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare example data" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n", "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'abcdefghijklmnopqrstuvwxyz *'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Character quantization\n", "char_space = 
string.ascii_lowercase\n", "char_space = char_space + ' ' + '*'\n", "char_space" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n" ] } ], "source": [ "char_dic = {char : idx for idx, char in enumerate(char_space)}\n", "print(char_dic)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Create pad_seq function" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def pad_seq(sequences, max_len, dic):\n", "    seq_len, seq_indices = [], []\n", "    for seq in sequences:\n", "        seq_len.append(len(seq))\n", "        seq_idx = [dic.get(char) for char in seq]\n", "        # pad up to max_len with the index of the meaningless token \"*\" (27)\n", "        seq_idx += (max_len - len(seq_idx)) * [dic.get('*')]\n", "        seq_indices.append(seq_idx)\n", "    return seq_len, seq_indices" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pad_seq function to data" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "max_length = 10\n", "X_length, X_indices = pad_seq(sequences = words, max_len = max_length, dic = char_dic)" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[4, 3, 7, 7, 9, 7]\n", "(6, 10)\n" ] } ], "source": [ "print(X_length)\n", "print(np.shape(X_indices))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Define CharLSTM class" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class CharLSTM:\n", "    def __init__(self, X_length, X_indices, y, n_of_classes, hidden_dim, dic):\n", "\n", "        # data pipeline\n", "        with tf.variable_scope('input_layer'):\n", "            self._X_length = X_length\n", "            self._X_indices = X_indices\n", "            self._y = y\n", "\n", "            one_hot = tf.eye(len(dic), dtype = tf.float32)\n", "            self._one_hot = tf.get_variable(name='one_hot_embedding', initializer = one_hot,\n", "                                            trainable = False) # the embedding vectors are not trained\n", "            self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n", "\n", "        # LSTM cell\n", "        with tf.variable_scope('lstm_cell'):\n", "            lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units = hidden_dim, activation = tf.nn.tanh)\n", "            outputs, states = tf.nn.dynamic_rnn(cell = lstm_cell, inputs = self._X_batch,\n", "                                                sequence_length = self._X_length, dtype = tf.float32)\n", "\n", "        with tf.variable_scope('output_layer'):\n", "            self._score = slim.fully_connected(inputs = states.h, num_outputs = n_of_classes, activation_fn = None)\n", "\n", "        with tf.variable_scope('loss'):\n", "            self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n", "\n", "        with tf.variable_scope('prediction'):\n", "            self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n", "\n", "    def predict(self, sess, X_length, X_indices):\n", "        feed_prediction = {self._X_length : X_length, self._X_indices : X_indices}\n", "        return sess.run(self._prediction, feed_dict = feed_prediction)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Create a model of CharLSTM" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3\n" ] } ], "source": [ "# hyper-parameters\n", "lr = .003\n", "epochs = 10\n", "batch_size = 2\n", "total_step = int(np.shape(X_indices)[0] / batch_size)\n", "print(total_step)" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n" ] } ], "source": 
[ "## create data pipeline with tf.data\n", "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n", "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n", "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n", "tr_iterator = tr_dataset.make_initializable_iterator()\n", "print(tr_dataset)" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "char_lstm = CharLSTM(X_length = X_length_mb, X_indices = X_indices_mb, y = y_mb, n_of_classes = 2,\n", "                     hidden_dim = 16, dic = char_dic)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Create training op and train model" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "## create training op\n", "opt = tf.train.AdamOptimizer(learning_rate = lr)\n", "training_op = opt.minimize(loss = char_lstm.ce_loss)" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch :   1, tr_loss : 0.686\n", "epoch :   2, tr_loss : 0.663\n", "epoch :   3, tr_loss : 0.633\n", "epoch :   4, tr_loss : 0.613\n", "epoch :   5, tr_loss : 0.586\n", "epoch :   6, tr_loss : 0.547\n", "epoch :   7, tr_loss : 0.514\n", "epoch :   8, tr_loss : 0.473\n", "epoch :   9, tr_loss : 0.429\n", "epoch :  10, tr_loss : 0.379\n" ] } ], "source": [ "sess = tf.Session()\n", "sess.run(tf.global_variables_initializer())\n", "\n", "tr_loss_hist = []\n", "\n", "for epoch in range(epochs):\n", "    avg_tr_loss = 0\n", "    tr_step = 0\n", "\n", "    # re-initialize the iterator at the start of every epoch\n", "    sess.run(tr_iterator.initializer)\n", "    try:\n", "        while True:\n", "            _, tr_loss = sess.run(fetches = [training_op, char_lstm.ce_loss])\n", "            avg_tr_loss += tr_loss\n", "            tr_step += 1\n", "\n", "    except tf.errors.OutOfRangeError:\n", "        pass\n", "\n", "    avg_tr_loss /= tr_step\n", "    tr_loss_hist.append(avg_tr_loss)\n", "\n", "    print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[<matplotlib.lines.Line2D at 0x11a517a58>]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(tr_loss_hist, label = 'train')" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "yhat = char_lstm.predict(sess = sess, X_length = X_length, X_indices = X_indices)" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "training acc: 83.33%\n" ] } ], "source": [ "print('training acc: {:.2%}'.format(np.mean(yhat == np.argmax(y, axis = -1))))" ] }
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" },
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }