{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CS 20 : TensorFlow for Deep Learning Research\n", "## Lecture 11 : Recurrent Neural Networks\n", "Simple example for Many to One Classification (word sentiment classification) by Stacked LSTM with Drop out.\n", "\n", "### Many to One Classification by Stacked LSTM with Drop out\n", "- Creating the **data pipeline** with `tf.data`\n", "- Preprocessing word sequences (variable input sequence length) using `padding technique` by `user function (pad_seq)`\n", "- Using `tf.nn.embedding_lookup` for getting vector of tokens (eg. word, character)\n", "- Creating the model as **Class** \n", "- Applying **Drop out** to model by `tf.contrib.rnn.DropoutWrapper`\n", "- Applying **Stacking** to model by `tf.contrib.rnn.MultiRNNCell`\n", "- Replacing **RNN Cell** with **LSTM Cell**\n", "- Reference\n", " - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n", " - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n", " - https://danijar.com/introduction-to-recurrent-networks-in-tensorflow/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.8.0\n" ] } ], "source": [ "import os, sys\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "import string\n", "%matplotlib inline\n", "\n", "slim = tf.contrib.slim\n", "print(tf.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare example data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n", "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'abcdefghijklmnopqrstuvwxyz *'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Character quantization\n", "char_space = string.ascii_lowercase \n", "char_space = char_space + ' ' + '*'\n", "char_space" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n" ] } ], "source": [ "char_dic = {char : idx for idx, char in enumerate(char_space)}\n", "print(char_dic)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create pad_seq function" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def pad_seq(sequences, max_len, dic):\n", " seq_len, seq_indices = [], []\n", " for seq in sequences:\n", " seq_len.append(len(seq))\n", " seq_idx = [dic.get(char) for char in seq]\n", " seq_idx += (max_len - len(seq_idx)) * [dic.get('*')] # 27 is idx of meaningless token \"*\"\n", " seq_indices.append(seq_idx)\n", " return seq_len, seq_indices" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pad_seq function to data" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "max_length = 10\n", "X_length, X_indices = pad_seq(sequences = words, max_len = max_length, dic = char_dic)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[4, 3, 7, 7, 9, 7]\n", "(6, 10)\n" ] } ], "source": [ "print(X_length)\n", "print(np.shape(X_indices))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define CharStackedLSTM class" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class CharStackedLSTM:\n", " def __init__(self, X_length, X_indices, y, n_of_classes, dic, hidden_dims = [32, 16]):\n", " \n", " # data pipeline\n", " with tf.variable_scope('input_layer'):\n", " self._X_length = X_length\n", " self._X_indices = X_indices\n", " self._y = y\n", " self._keep_prob = tf.placeholder(dtype = tf.float32)\n", " \n", " one_hot = tf.eye(len(dic), dtype = tf.float32)\n", " self._one_hot = tf.get_variable(name='one_hot_embedding', initializer = one_hot,\n", " trainable = False) # embedding vector training 안할 것이기 때문\n", " self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n", " \n", " # Stacked-LSTM\n", " with tf.variable_scope('stacked_lstm'):\n", " \n", " cells = []\n", " for hidden_dim in hidden_dims:\n", " cell = tf.contrib.rnn.BasicLSTMCell(num_units = hidden_dim, activation = tf.nn.tanh)\n", " cell = tf.contrib.rnn.DropoutWrapper(cell = cell, output_keep_prob = self._keep_prob)\n", " cells.append(cell)\n", " else:\n", " cells = tf.contrib.rnn.MultiRNNCell(cells = cells)\n", " \n", " _, states = tf.nn.dynamic_rnn(cell = cells, inputs = self._X_batch,\n", " sequence_length = self._X_length, dtype = tf.float32)\n", " \n", " with tf.variable_scope('output_layer'):\n", " self._score = slim.fully_connected(inputs = states[-1].h, num_outputs = n_of_classes,\n", " activation_fn = None)\n", " \n", " with tf.variable_scope('loss'):\n", " self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n", " \n", " with tf.variable_scope('prediction'):\n", " self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n", " \n", " def predict(self, sess, X_length, X_indices, keep_prob = 1.):\n", " feed_prediction = {self._X_length : X_length, self._X_indices : X_indices, self._keep_prob : keep_prob}\n", " return sess.run(self._prediction, feed_dict = feed_prediction)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a model of CharStackedLSTM" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3\n" ] } ], "source": [ "# hyper-parameter#\n", "lr = .003\n", "epochs = 10\n", "batch_size = 2\n", "total_step = int(np.shape(X_indices)[0] / batch_size)\n", "print(total_step)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n" ] } ], "source": [ "## create data pipeline with tf.data\n", "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n", "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n", "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n", "tr_iterator = tr_dataset.make_initializable_iterator()\n", "print(tr_dataset)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "char_stacked_lstm = CharStackedLSTM(X_length = X_length_mb, X_indices = X_indices_mb, y = y_mb,\n", " n_of_classes = 2, dic = char_dic, hidden_dims = [32,16])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Creat training op and train model" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "## create training op\n", "opt = tf.train.AdamOptimizer(learning_rate = lr)\n", "training_op = opt.minimize(loss = char_stacked_lstm.ce_loss)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch : 1, tr_loss : 0.692\n", "epoch : 2, tr_loss : 0.665\n", "epoch : 3, tr_loss : 0.638\n", "epoch : 4, tr_loss : 0.610\n", "epoch : 5, tr_loss : 0.572\n", "epoch : 6, tr_loss : 0.503\n", "epoch : 7, tr_loss : 0.439\n", "epoch : 8, tr_loss : 0.338\n", "epoch : 9, tr_loss : 0.282\n", "epoch : 10, tr_loss : 0.229\n" ] } ], "source": [ "sess = tf.Session()\n", "sess.run(tf.global_variables_initializer())\n", "\n", "tr_loss_hist = []\n", "\n", "for epoch in range(epochs):\n", " avg_tr_loss = 0\n", " tr_step = 0\n", " \n", " sess.run(tr_iterator.initializer)\n", " try:\n", " while True:\n", " _, tr_loss = sess.run(fetches = [training_op, char_stacked_lstm.ce_loss],\n", " feed_dict = {char_stacked_lstm._keep_prob : .5})\n", " avg_tr_loss += tr_loss\n", " tr_step += 1\n", " \n", " except tf.errors.OutOfRangeError:\n", " pass\n", " \n", " avg_tr_loss /= tr_step\n", " tr_loss_hist.append(avg_tr_loss)\n", " \n", " print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[<matplotlib.lines.Line2D at 0x117ea0048>]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAH+tJREFUeJzt3Xl0VdX9/vH3JwkBwpAwhCEkEGSSKUEIM6WKEw4BxKrQYh3rBIp+7a/a1n5rtV+r1lYsg0qdBUHrSKiCVlEEBUmQMA9hDAlDGBJmMu3fHwkaLZoANzl3eF5rsRb35iTnWWfBs3bOuXtvc84hIiLBJczrACIi4nsqdxGRIKRyFxEJQip3EZEgpHIXEQlCKncRkSCkchcRCUIqdxGRIKRyFxEJQhFenbhp06YuMTHRq9OLiASkjIyMPc652MqO86zcExMTSU9P9+r0IiIBycy2VuU43ZYREQlCVSp3MxtqZuvMLMvM7j/J1580s2Xlf9abWb7vo4qISFVVelvGzMKBycCFwHZgiZnNcs6tPnGMc+6eCsffCZxTDVlFRKSKqjJy7wNkOec2OecKgZnA8B85fjQwwxfhRETk9FSl3FsB2RVeby9/77+YWRugLfDJmUcTEZHT5esHqqOAN51zJSf7opndYmbpZpael5fn41OLiMgJVSn3HCChwuv48vdOZhQ/ckvGOTfVOZfinEuJja30Y5oiInKaqlLuS4AOZtbWzCIpK/BZ3z/IzM4GGgFf+jbid2Vm5zPpkw1s2XO4Ok8jIhLQKv20jHOu2MzGAXOBcOAF59wqM3sISHfOnSj6UcBMV82bsi7atJcnPlzPEx+uJzk+mtTkOC5LaknL6LrVeVoRkYBiXm2QnZKS4k53hmpu/lFmL88lLXMHK3IKMIPeiY1JTY7j0m4taFK/to/Tioj4BzPLcM6lVHpcIJZ7RZvyDjF7+Q5mZeaStfsQ4WHGwPZNSU1qyUVdWxBdt5YP0oqI+IeQKfcTnHOs3XmQtMxc0pbnkr3vKJHhYZzbKZbU5Dgu6NycupHhPjufiIgXQq7cK3LOsSw7n7TMHcxensvug8eJigzngs7NSU2OY3DHptSOUNGLSOAJ6XKvqKTU8dXmfaQtz+WDFTvYf6SIhnUiGNqtBanJcfQ/qwkR4Vo/TUQCg8r9JIpKSlmQtYe0zFw+XLWLQ8eLaVo/kku7t2RYchw9WzciLMxqNJOIyKlQuVfiWFEJn67bTVrmDv6zZhfHi0uJi67D5clxDEuOo2tcQ8xU9CLiX1Tup+DQ8WL+s3oXaZm5fLY+j+JSR9um9UhNasmwHnG0b9bA64giIoDK/bTlHylkzsqdpC3P5cuNeyl1cHaLBqSWj+gTGkd5HVFEQpjK3Qd2HzzG++WfoV+6rWz/kR4JMQwrnxXbvGEdjxOKSKhRuftY9r4j/HvFDmYty2X1jgOYQd+2J2bFtqRRvUivI4pICFC5V6Os3Ye+mSy1Ke8wEWHGoA5NGZYcx4VdmtOgjmbFikj1ULnXAOccq3ccYFZmLrMzd5CTf5TaEWEMObsZqclxDDm7GXVqabKUiPiOyr2GOedYui2ftMxcZi/fwZ5Dx6kXGc6FXZozrEccg9rHEhmhyVIicmZU7h4qKXUs2rSXtMxcPli5k4KjRcRE1eKSbi1ITYqj71lNCNdkKRE5DSp3P1FYXMrnG/LKZsWu3sWRwhJiG9Tmsu4tSU2Oo2frGE2WEpEqU7n7oaOFJXyydjdpmbl8sm43hcWltIqpS2pyHKnJLenSUrNiReTHqdz93IFjRXy0ahdpy3P5fMMeSkod7WLrlRd9HO1i63sdUUT8kMo9gOw7XMgHK3eQlpnL4s37cA66tGzIsB5xXJ7UkvhGmhUrImVU7gFqZ8GxsslSmblkZpfNiu3VphGpSS25NKklzRpoVqxIKFO5B4Fte4+QtjyXtMxc1u48SJhB/3ZNSE2KY2i3FsREaVasSKhRuQeZ9bvKtxDMzGXL3iPUCjcGdyjfQrBLc+rXjvA6oojUAJV7kHLOsSKn4JvJUjsKjlE7IozzOzcjNSmO8zQrViSoqdxDQGmpI33rfmYvz+X9FTvYc6iQepHhXNS1BanJLTUrViQIqdxDTHFJKYs27SufFbuDA8eKia5bi6Fdy/aK7XdWY+0VKxIEVO4hrLC4lAVZeaRl7uDDVTs5XFjyzV6xlyfFkdJGe8WKBCqVuwBle8XOW7ub2cu/3Su2RcM6XJ5UtvxBUny0ZsWKBBCVu/yXQ8eL+XjNt3vFFpU4WjeO+qboz27RQEUv4udU7vKjCo4UMXf1TtIyc/li415KSh3tm9UnNSmOy5NbavkDET+lcpcq23PoOB+sLCv6JVvKlj/oGteQy5PKlj/QpuAi/kPlLqflxPIHaZm5LCtf/uCc1jGkJmlTcBF/oHKXM5a97wizl5cV/YlNwfsklm0Kfkm3FjSpX9vriCIhR+UuPpW1+xCzy9e52Zh3mPAwY2D7pqQmteSiri2IrqtNwUVqgspdqoVzjrU7y9e5WZ5L9r6jRIaHMbhjLGP6teanHWP1iRuRaqRyl2rnnCNzewGzM3OZlZnL7oPHSYqP5s4hHbigczOVvEg1qGq5V2k+upkNNbN1ZpZlZvf/wDFXm9lqM1tlZq+damAJPGZGj4QYHri8CwvuG8KjI7uTf6SIX72SziVPfc6/l++gtNSbwYNIqKt05G5m4cB64EJgO7AEGO2cW13hmA7AG8AQ59x+M2vmnNv9Yz9XI/fgVFxSyqzMXCbNy2JT3mHaN6vP2PPakZoUp7VtRHzAlyP3PkCWc26Tc64QmAkM/94xvwImO+f2A1RW7BK8IsLDGNkzno/u+SkTR59DuBn3vJ7JBX//jDeWZFNYXOp1RJGQUJVybwVkV3i9vfy9ijoCHc1soZktMrOhvgoogSk8zEhNjuOD8T/h2Wt7Ub9OBL95aznnPfEpry7ayvHiEq8jigQ1X/2eHAF0AM4FRgP/NLOY7x9kZreYWbqZpefl5fno1OLPwsKMi7u2IG3cIF68vjfNGtbmD++uZPDj83hhwWaOFqrkRapDVco9B0io8Dq+/L2KtgOznHNFzrnNlN2j7/D9H+Scm+qcS3HOpcTGxp5uZglAZsZ5Zzfj7dsHMP3mviQ2qcdDs1fzk8c/4ZnPNnLoeLHXEUWCSlXKfQnQwczamlkkMAqY9b1j3qVs1I6ZNaXsNs0mH+aUIGFWNvnp9Vv788at/encsiGPfrCWQY99wsSPN3DgWJHXEUWCQqXl7pwrBsYBc4E1wBvOuVVm9pCZDSs/bC6w18xWA/OA/+ec21tdoSU49GnbmFdv6ss7dwygV+tG/O2j9Qx89BP+9uE69h8u9DqeSEDTJCbxGytzCpg8L4sPVu6kXmQ4Y/q34eZBZxHbQGvYiJygGaoSsNbvOsikT7KYvTyXyIgwRvdpza2D29EiWitSiqjcJeBtyjvElE838s7XOYSbcVVKPLef2474RlpfXkKXyl2CRva+I0z5dCNvZmTjHIzs2Yo7zm1PYtN6XkcTqXEqdwk6OwqO8uxnm5jx1TaKSkoZ3qMVY89rR/tmDbyOJlJjVO4StHYfPMZzn2/m1S+3cqy4hEu7tWTckPZ0btnQ62gi1U7lLkFv3+FCnl+wiZe/2Mqh48Vc2KU5dw5pT1L8f02OFgkaKncJGQVHinjpiy28sHAzBUeLGHteO+69sBNhYVpPXoKPT9dzF/Fn0VG1GH9BBxbcdx6jeicwed5Gxr++jGNFWrdGQleE1wFEfKVBnVr8ZWR32jSpx2Nz1rKz4CjPXptC43qRXkcTqXEauUtQMTNuP7cdk35+DpnbCxg5ZSGb9xz2OpZIjVO5S1C6PCmOGb/qS8HRIkZOWUj6ln1eRxKpUSp3CVq92jTmnTsGEhMVyc+fW0xaZq7XkURqjMpdglpi03q8ffsAesTHcOeMr5k8LwuvPiEmUpNU7hL0GtWL5NWb+zC8Rxx/nbuO+99aQVGJ9nKV4KZPy0hIqB0RzoRretC6cRQTP8kit+Aok3/Rk4Z1ankdTaRaaOQuIcPMuPeiTjz+syS+3LiXq57+kpz8o17HEqkWKncJOVenJPDyjX3IzT/KiMkLWbG9wOtIIj6ncpeQNLB9U966YwCR4WFc/eyX/Gf1Lq8jifiUyl1CVsfmDXhn7AA6NK/PLa+m89LCzV5HEvEZlbuEtGYN6jDzln6c37k5D6at5k9pqygp1UclJfCp3CXkRUVG8MyYXtw4sC0vLtzCbdMyOFJY7HUskTOichcBwsOM/03twp+GdeXjNbsYNXURuw8e8zqWyGlTuYtUcN2ARKZem8KGXYe4YvIXrN910OtIIqdF5S7yPRd0ac4bt/ansKSUK6d8wYINe7yOJHLKVO4iJ9E9Ppp3xw4kLqYu17/4FW8syfY6ksgpUbmL/IBWMXX51+396d+uCb95azlPzF2nRcckYKjcRX5Ewzq1eOH63ozqncCkeVmMn7mM48Xavk/8nxYOE6lErfAw/jKyO62bRPH4nHXsKDjK1GtTaKTt+8SPaeQuUgVmxh3ntmfi6PLt+57+gi3avk/8mMpd5BSkJsfx2s19yT9SyBVTFpKxVdv3iX9SuYucopTEb7fvG/1Pbd8n/knlLnIaTmzflxwfzZ0zvmbKp9q+T/yLyl3kNDWqF8mrN/VlWHIcj89Zx2/f1vZ94j/0aRmRM1Cn1rfb902al0VOvrbvE/9QpZG7mQ01s3VmlmVm95/k69ebWZ6ZLSv/c7Pvo4r4p7Aw49cXd+LxK7V9n/iPSsvdzMKBycAlQBdgtJl1OcmhrzvnepT/ec7HOUX83tW9E3jphm+371u6bb/XkSSEVWXk3gfIcs5tcs4VAjOB4dUbSyQwDepQtn1fnVphjHp2Ea8v2eZ1JAlRVSn3VkDFVZO2l7/3fVea2XIze9PMEk72g8zsFjNLN7P0vLy804gr4v86Nm9A2rhB9D2rMfe9tYI/vLuSwmI9aJWa5atPy6QBic65JOAj4OWTHeScm+qcS3HOpcTGxvro1CL+JyYqkhev780tg8/i1UVbGfPcYvYcOu51LAkhVSn3HKDiSDy+/L1vOOf2OudO/Mt9Dujlm3gigSsiPIzfXdqZp0b1YHlOPqkTF7B8e77XsSREVKXclwAdzKytmUUCo4BZFQ8ws5YVXg4D1vguokhgG96jFW/eNoAwM372zJe8lbHd60gSAiotd+dcMTAOmEtZab/hnFtlZg+Z2bDyw+4ys1VmlgncBVxfXYFFAlG3VtHMGjeQXq0bce+/MvlT2ipNeJJqZV5NmU5JSXHp6emenFvEK0UlpTzy/hpeXLiF/mc1YfIvetJYSwfLKTCzDOdcSmXHafkBkRpUKzyMP6Z25YmrksnYtp/UiQtYmVPgdSwJQip3EQ/8rFc8/7q1P6XO8bNnvuC9ZTmVf5PIKVC5i3gkOSGGWeMG0b1VNONnLuOR99dQrPvw4iMqdxEPxTaozfSb+3FtvzZMnb+JG15aQv6RQq9jSRBQuYt4LDIijIdHdOPRkd1ZvGkfwyYtZO3OA17HkgCnchfxE6P6tGbmrf04VlTCyClf8P6KHV5HkgCmchfxIz1bN2L2nYM4u0UD7pi+lL/OXUtJqXZ4klOnchfxM80a1mHGLf0Y1TuByfM2ctPLSyg4WuR1LAkwKncRP1Q7Ipy/jOzOn0d0Y8GGPYyYvJANuw56HUsCiMpdxE+ZGWP6tWHGLf04eKyYEZMXMnfVTq9jSYBQuYv4ud6JjUm7cyDtm9Xn1lcz+PtH6ynVfXiphMpdJAC0jK7L67f258qe8fzj4w3c8moGB4/pPrz8MJW7SICoUyucJ65K4sHULsxbt5sRkxeyMe+Q17HET6ncRQKImXH9wLZMu6kv+48UMWLSQj5Zu8vrWOKHVO4iAah/uybMGjeQ1k2iuOnldCZ+vEH34eU7VO4iASq+URRv3jaA4clx/O2j9dwxfSmHjhd7HUv8hMpdJIDVjQznyWt68MBlnflw9U5GTlnIlj2HvY4lfkDlLhLgzIybf3IWr9zYl90HjzNs0gI+W5/ndSzxmMpdJEgM6tCUWWMHERdTlxte/IqnP92IV9toivdU7iJBpHWTKN6+YwCXdG/JY3PWcs/ry/SgNURFeB1ARHwrKjKCSaPPoUOz+kz4zwbaNq3P+As6eB1LapjKXSQImRnjz+/Atr1HmPDxepISojmvUzOvY0kN0m0ZkSBlZvzfFd3p1LwBd89cRva+I15HkhqkchcJYnUjw3n22l6UOsdt0zI4VlTidSSpISp3kSDXpkk9JlzTg1W5B3jg3ZX6BE2IULmLhIDzOzfnriHteTNjOzO+yvY6jtQAlbtIiBh/QUd+2jGWB2etYll2vtdxpJqp3EVCRHiY8dSoHjRrWJs7pmWw99BxryNJNVK5i4SQmKhInhnTiz2HC7lr5teUaIJT0FK5i4SYbq2i+fOIbizM2ssTH67zOo5UE5W7SAi6OiWB0X1a8/SnG5mzUptuByOVu0iIenBYF5Ljo/n1vzLZpO36go7KXSRE1Y4IZ8qYXtQKN26blsFhbfQRVKpU7mY21MzWmVmWmd3/I8ddaWbOzFJ8F1FEqkurmLpMHN2TrN2HuO+t5ZrgFEQqLXczCwcmA5cAXYDRZtblJMc1AMYDi30dUkSqz6AOTfn1xZ2YvXwHLy7c4nUc8ZGqjNz7AFnOuU3OuUJgJjD8JMc9DDwGHPNhPhGpAbf/tB0XdWnOI++v4avN+7yOIz5QlXJvBVScr7y9/L1vmFlPIME5928fZhORGmJmPHF1MgmNoxj72lJ2H9AYLdCd8QNVMwsD/g7cW4VjbzGzdDNLz8vTHo8i/qRhnVo8M6YXh44VM/a1pRSVlHodSc5AVco9B0io8Dq+/L0TGgDdgE/NbAvQD5h1soeqzrmpzrkU51xKbGzs6acWkWrRqUUDHr2yO0u27OeR99d4HUfOQFXKfQnQwczamlkkMAqYdeKLzrkC51xT51yicy4RWAQMc86lV0tiEalWw3u04oaBiby4cAvvLcup/BvEL1Va7s65YmAcMBdYA7zhnFtlZg+Z2bDqDigiNe93l3YmpU0j7n9rBet2HvQ6jpwG8+pzrSkpKS49XYN7EX+1+8AxLpu4gPq1I3hv3EAa1qnldSQBzCzDOVfpXCLNUBWRk2rWsA6Tf96T7H1HuPeNTEq1gmRAUbmLyA/q07Yxv7u0Mx+t3sUz8zd6HUdOgcpdRH7UDQMTSU2O44m561iwYY/XcaSKVO4i8qPMjEdHdqd9s/rcNfNrcvKPeh1JqkDlLiKVqlc7gmfG9KKwuJQ7pmVwvLjE60hSCZW7iFTJWbH1eeKqZDK3F/DgrNVex5FKqNxFpMqGdmvB7ee2Y8ZX23hjSXbl3yCeUbmLyCm598KODGzfhAfeW8nKnAKv48gPULmLyCmJCA/jH6POoUm9SG6blsH+w4VeR5KTULmLyClrUr82T4/pxe4Dxxn/+jJKNMHJ76jcReS09EiI4cFhXZm/Po+nPt7gdRz5HpW7iJy20X0SuKpXPP/4eAMfr9nldRypQOUuIqfNzHh4RDe6xjXknteXsXXvYa8jSTmVu4ickTq1wnlmTC/MjNumLeVooSY4+QOVu4icsYTGUUwY1YO1Ow/w+3dW4NVS4vItlbuI+MR5nZpx9/kdefvrHKYt2up1nJCnchcRn7lzSHvO6xTLQ7NXk7F1v9dxQprKXUR8JizMmHDNObSMrssd0zPIO3jc60ghS+UuIj4VHVWLp8f0JP9IEXfOWEpxSanXkUKSyl1EfK5rXDSPXNGdRZv28de567yOE5IivA4gIsHpyl7xLMvO59n5m+jaKpphyXFeRwopGrmLSLX5w+Vd6Nk6hvEzv+Yv76+hsFi3aGqKyl1Eqk1kRBjTb+7H6D6teXb+Jq6YspCs3Ye8jhUSVO4iUq3qRobzyBXdmXptL3Lzj3L5xM+ZvnirJjpVM5W7iNSIi7q2YO7dg+md2Jjfv7OSX72Swd5D+qhkdVG5i0iNadawDi/f0IcHLuvM/PV5DH3qc+avz/M6VlBSuYtIjQoLM27+yVm8O3YgMXVr8csXvuLh2as5VqQFx3xJ5S4inugS15C0OwdxXf82PL9gMyMmL2T9roNexwoaKncR8UydWuH8aXg3Xrg+hT2HjpM6cQEvf7FFD1t9QOUuIp4bcnZzPhg/mP7tmvDHWau48aUlWpfmDKncRcQvxDaozYvX9+ZPw7qycONeLnlqPvPW7vY6VsBSuYuI3zAzrhuQSNq4QTStX5sbXlrCH99bqYetp0HlLiJ+p1OLBrw7diA3DmzLy19uZdikBazZccDrWAFF5S4ifqlOrXD+N7ULL9/Yh/1Hihg+aSHPL9hMaaketlZFlcrdzIaa2TozyzKz+0/y9dvMbIWZLTOzBWbWxfdRRSQU/bRjLHPG/4TBHZvy8OzVXPfiV+w+cMzrWH6v0nI3s3BgMnAJ0AUYfZLyfs0519051wN4HPi7z5OKSMhqUr82//xlCn8e0Y0lW/Yx9KnP+Wj1Lq9j+bWqjNz7AFnOuU3OuUJgJjC84gHOuYo3w+oB+r1JRHzKzBjTrw2z7xxEi4Z1+NUr6fz+nRUcLdTD1pOpSrm3ArIrvN5e/t53mNlYM9tI2cj9Lt/EExH5rvbNGvDO2AHcMvgspi/exmUTP2dlToHXsfyOzx6oOucmO+faAfcBD5zsGDO7xczSzSw9L0+LBYnI6akdEc7vLu3MtJv6cvh4MVdMWcjU+Rv1sLWCqpR7DpBQ4XV8+Xs/ZCYw4mRfcM5Ndc6lOOdSYmNjq55SROQkBnVoypzxgxlydjMeeX8tY55fzM4CPWyFqpX7EqCDmbU1s0hgFDCr4gFm1qHCy8uADb6LKCLywxrVi+SZMb14dGR3vt6Wz8UT5jNn5Q6vY3mu0nJ3zhUD44C5wBrgDefcKjN7yMyGlR82zsxWmdky4H+A66otsYjI95gZo/q05t93DaJ14yhum7aU+95czuHjxV5H84x5tfpaSkqKS09P9+TcIhK8CotLefI/63nms40kNqnHhGt6kJwQ43UsnzGzDOdcSmXHaYaqiASVyIgw7ht6Nq/d3I9jRSVc+fQXTJ6XRUmIPWxVuYtIUOrfrglzxg/m4q4t+OvcdYz+5yJy8o96HavGqNxFJGhFR9Vi0s/P4a8/S2JVTgEXPzmfGV9tC4nNQFTuIhLUzIyrUhKYc/dgurVqyG/fXsEvX/gq6EfxKncRCQkJjaN47eZ+PDyiGxlb93Pxk/N5bXHwjuJV7iISMsLCjGv7tWHu3YPp3iqa372zgmuf/4rt+494Hc3nVO4iEnISGkcx/ea+PDyiG0u37WfohM+DbhSvcheRkFRxFJ8UH3yjeJW7iIS0hMZRTLupL38e0Y2vt5Xdi5++eGvAj+JV7iIS8sLCytaKn3P3YHq0juH376xkzPOLA3oUr3IXESl3YhT/f1d0Y9m2fC5+cj7TFgXmKF7lLiJSgZnxi77fjuIfeHclv3huMdn7AmsUr3IXETmJiqP4zOx8hk4oG8UHyoYgKncRkR9wYhQ/957BnNO6EQ+8W3YvPhBG8Sp3EZFKxDeK4tWb+vCXkd1Zvr2AiyfM51U/H8Wr3EVEqsDMGN2nNXPvGUyvNo34g5/fi1e5i4icglYxdXnlxrJR/Iqc8lH8l1v8bhSvchcROUX/NYp/b5XfjeJV7iIip+nEKP5RPxzFq9xFRM7Aic25K47if/7cIrbt9XYUr3IXEfGBE6P4x67szqqcAwx9aj6vfOndKF7lLiLiI2bGNb3LRvEpiY353/dWMfqf3oziVe4iIj4WF1OXl2/ozWNXdmd17gEunjCfl7+o2VG8yl1EpBpUHMX3aduYP84qG8Vv3Xu4Rs6vchcRqUZxMXV56YbePH5lEqtzDzB0wuekZeZW+3lV7iIi1czMuLp3Ah/+z2AGtm9C26b1qv2cEdV+BhERAaBldF2eu653jZxLI3cRkSCkchcRCUIqdxGRIKRyFxEJQip3EZEgpHIXEQlCKncRkSCkchcRCULmnDfLUZpZHrD1NL+9KbDHh3ECna7Hd+l6fEvX4ruC4Xq0cc7FVnaQZ+V+Jsws3TmX4nUOf6Hr8V26Ht/StfiuULoeui0jIhKEVO4iIkEoUMt9qtcB/Iyux3fpenxL1+K7QuZ6BOQ9dxER+XGBOnIXEZEfEXDlbmZDzWydmWWZ2f1e5/GKmSWY2TwzW21mq8xsvNeZ/IGZhZvZ12Y22+ssXjOzGDN708zWmtkaM+vvdSavmNk95f9PVprZDDOr43Wm6hZQ5W5m4cBk4BKgCzDazLp4m8ozxcC9zrkuQD9gbAhfi4rGA2u8DuEnngLmOOfOBpIJ0etiZq2Au4AU51w3IBwY5W2q6hdQ5Q70AbKcc5ucc4XATGC4x5k84Zzb4ZxbWv73g5T9x23lbSpvmVk8cBnwnNdZvGZm0cBg4HkA51yhcy7f21SeigDqmlkEEAVU/yamHgu0cm8FZFd4vZ0QLzQAM0sEzgEWe5vEcxOA3wClXgfxA22BPODF8ttUz5lZ9W/c6YeccznAE8A2YAdQ4Jz70NtU1S/Qyl2+x8zqA28BdzvnDnidxytmdjmw2zmX4XUWPxEB9ASeds6dAxwGQvIZlZk1ouw3/LZAHFDPzMZ4m6r6BVq55wAJFV7Hl78XksysFmXFPt0597bXeTw2EBhmZlsou103xMymeRvJU9uB7c65E7/NvUlZ2YeiC4DNzrk851wR8DYwwONM1S7Qyn0J0MHM2ppZJGUPRWZ5nMkTZmaU3U9d45z7u9d5vOac+61zLt45l0jZv4tPnHNBPzr7Ic65nUC2mXUqf+t8YLWHkby0DehnZlHl/2/OJwQeLkd4HeBUOOeKzWwcMJeyJ94vOOdWeRzLKwOBa4EVZras/L3fOefe9zCT+Jc7genlA6FNwA0e5/GEc26xmb0JLKXsU2ZfEwIzVTVDVUQkCAXabRkREakClbuISBBSuYuIBCGVu4hIEFK5i4gEIZW7iEgQUrmLiAQhlbuISBD6/yzdG3f0PeNDAAAAAElFTkSuQmCC\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(tr_loss_hist, label = 'train')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "yhat = char_stacked_lstm.predict(sess = sess, X_length = X_length, X_indices = X_indices)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "training acc: 83.33%\n" ] } ], "source": [ "print('training acc: {:.2%}'.format(np.mean(yhat == np.argmax(y, axis = -1))))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }