{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CS 20 : TensorFlow for Deep Learning Research\n", "## Lecture 11 : Recurrent Neural Networks\n", "Simple example for Many to One Classification (word sentiment classification) by Stacked Bi-directional Long Short-Term Memory with Drop out.\n", "\n", "### Many to One Classification by Stacked Bi-directional LSTM with Drop out\n", "- Creating the **data pipeline** with `tf.data`\n", "- Preprocessing word sequences (variable input sequence length) using `padding technique` by `user function (pad_seq)`\n", "- Using `tf.nn.embedding_lookup` for getting vector of tokens (eg. word, character)\n", "- Creating the model as **Class**\n", "- Applying **Drop out** to model by `tf.contrib.rnn.DropoutWrapper`\n", "- Applying **Stacking** and **dynamic rnn** to model by `tf.contrib.rnn.stack_bidirectional_dynamic_rnn`\n", "- Reference\n", " - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n", " - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n", " - https://pozalabs.github.io/blstm/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.8.0\n" ] } ], "source": [ "import os, sys\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "import string\n", "%matplotlib inline\n", "\n", "slim = tf.contrib.slim\n", "print(tf.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare example data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n", "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'abcdefghijklmnopqrstuvwxyz *'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Character quantization\n", "char_space = string.ascii_lowercase \n", "char_space = char_space + ' ' + '*'\n", "char_space" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n" ] } ], "source": [ "char_dic = {char : idx for idx, char in enumerate(char_space)}\n", "print(char_dic)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create pad_seq function" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def pad_seq(sequences, max_len, dic):\n", " seq_len, seq_indices = [], []\n", " for seq in sequences:\n", " seq_len.append(len(seq))\n", " seq_idx = [dic.get(char) for char in seq]\n", " seq_idx += (max_len - len(seq_idx)) * [dic.get('*')] # 27 is idx of meaningless token \"*\"\n", " seq_indices.append(seq_idx)\n", " return seq_len, seq_indices" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pad_seq function to data" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "max_length = 10\n", "X_length, X_indices = pad_seq(sequences = words, max_len = max_length, dic = char_dic)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[4, 3, 7, 7, 9, 7]\n", "(6, 10)\n" ] } ], "source": [ "print(X_length)\n", "print(np.shape(X_indices))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define CharBiLSTM class" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class CharStackedBiLSTM:\n", " def __init__(self, X_length, X_indices, y, n_of_classes, hidden_dims, dic):\n", " \n", " # data pipeline\n", " with tf.variable_scope('input_layer'):\n", " self._X_length = X_length\n", " self._X_indices = X_indices\n", " self._y = y\n", " \n", " one_hot = tf.eye(len(dic), dtype = tf.float32)\n", " self._one_hot = tf.get_variable(name='one_hot_embedding', initializer = one_hot,\n", " trainable = False) # embedding vector training 안할 것이기 때문\n", " self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n", " self._keep_prob = tf.placeholder(dtype = tf.float32)\n", " \n", " # Stacked Bi-directional LSTM with Drop out\n", " with tf.variable_scope('stacked_bi-directional_lstm'):\n", " \n", " # forward \n", " lstm_fw_cells = []\n", " for hidden_dim in hidden_dims:\n", " lstm_fw_cell = tf.contrib.rnn.BasicRNNCell(num_units = hidden_dim, activation = tf.nn.tanh)\n", " lstm_fw_cell = tf.contrib.rnn.DropoutWrapper(cell = lstm_fw_cell,\n", " output_keep_prob = self._keep_prob)\n", " lstm_fw_cells.append(lstm_fw_cell)\n", " \n", " # backword\n", " lstm_bw_cells = []\n", " for hidden_dim in hidden_dims:\n", " lstm_bw_cell = tf.contrib.rnn.BasicRNNCell(num_units = hidden_dim, activation = tf.nn.tanh)\n", " lstm_bw_cell = tf.contrib.rnn.DropoutWrapper(cell = lstm_bw_cell,\n", " output_keep_prob = self._keep_prob)\n", " lstm_bw_cells.append(lstm_bw_cell)\n", " \n", " _, output_state_fw, output_state_bw = \\\n", " tf.contrib.rnn.stack_bidirectional_dynamic_rnn(cells_fw = lstm_fw_cells, cells_bw = lstm_bw_cells,\n", " inputs = self._X_batch,\n", " sequence_length = self._X_length,\n", " dtype = tf.float32)\n", "\n", " final_state = tf.concat([output_state_fw[-1], output_state_bw[-1]], axis = 1)\n", "\n", " with tf.variable_scope('output_layer'):\n", " self._score = slim.fully_connected(inputs = final_state, num_outputs = n_of_classes,\n", " activation_fn = None)\n", " \n", " with tf.variable_scope('loss'):\n", " self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n", " \n", " with tf.variable_scope('prediction'):\n", " self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n", " \n", " def predict(self, sess, X_length, X_indices, keep_prob = 1.):\n", " feed_prediction = {self._X_length : X_length, self._X_indices : X_indices, self._keep_prob : keep_prob}\n", " return sess.run(self._prediction, feed_dict = feed_prediction)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a model of CharStackedBiLSTM" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3\n" ] } ], "source": [ "# hyper-parameter#\n", "lr = .003\n", "epochs = 10\n", "batch_size = 2\n", "total_step = int(np.shape(X_indices)[0] / batch_size)\n", "print(total_step)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n" ] } ], "source": [ "## create data pipeline with tf.data\n", "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n", "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n", "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n", "tr_iterator = tr_dataset.make_initializable_iterator()\n", "print(tr_dataset)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "char_stacked_bi_lstm = CharStackedBiLSTM(X_length = X_length_mb, X_indices = X_indices_mb, \n", " y = y_mb, n_of_classes = 2, hidden_dims = [16,16], dic = char_dic)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Creat training op and train model" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "## create training op\n", "opt = tf.train.AdamOptimizer(learning_rate = lr)\n", "training_op = opt.minimize(loss = char_stacked_bi_lstm.ce_loss)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch : 1, tr_loss : 0.823\n", "epoch : 2, tr_loss : 0.698\n", "epoch : 3, tr_loss : 0.611\n", "epoch : 4, tr_loss : 0.325\n", "epoch : 5, tr_loss : 0.260\n", "epoch : 6, tr_loss : 0.218\n", "epoch : 7, tr_loss : 0.254\n", "epoch : 8, tr_loss : 0.115\n", "epoch : 9, tr_loss : 0.129\n", "epoch : 10, tr_loss : 0.081\n" ] } ], "source": [ "sess = tf.Session()\n", "sess.run(tf.global_variables_initializer())\n", "\n", "tr_loss_hist = []\n", "\n", "for epoch in range(epochs):\n", " avg_tr_loss = 0\n", " tr_step = 0\n", " \n", " sess.run(tr_iterator.initializer)\n", " try:\n", " while True:\n", " _, tr_loss = sess.run(fetches = [training_op, char_stacked_bi_lstm.ce_loss],\n", " feed_dict = {char_stacked_bi_lstm._keep_prob : .5})\n", " avg_tr_loss += tr_loss\n", " tr_step += 1\n", " \n", " except tf.errors.OutOfRangeError:\n", " pass\n", " \n", " avg_tr_loss /= tr_step\n", " tr_loss_hist.append(avg_tr_loss)\n", " \n", " print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[<matplotlib.lines.Line2D at 0x1146bac88>]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xt4VfWd7/H3NzcCBMIlgQAJSZSbXBRlE5GbtioC7UDPsSI4VEUUe7FVa6djT2fmTJ12TjudqdqpY0WkVWuLitaiorS1iDduAZSrYASEhGDCLVxz/54/EjUgmg3ssLL3/ryeJ49Za//Y+9P9lM/zY631W8vcHRERiS0JQQcQEZHIU7mLiMQglbuISAxSuYuIxCCVu4hIDFK5i4jEIJW7iEgMUrmLiMQglbuISAxKCuqDMzIyPC8vL6iPFxGJSqtWrdrj7pnNjQus3PPy8igsLAzq40VEopKZfRDOOB2WERGJQSp3EZEYpHIXEYlBKncRkRikchcRiUEqdxGRGKRyFxGJQVFX7tv3HOFnL79Lfb0eDygi8lmirtz/vHE3D776Pv/4zFrqVPAiIicV2ArV03XLmHM4UlXH/a+8R507P//qBSQmWNCxRERalagrdzPjziv7kZhg/OIvW6ird/7rmgtISoy6f4SIiLSYsBrRzMab2WYzKzKzu0/yem8zW2xma8xsrZlNjHzU433n8r78w1X9+dPbu7j9ybepqatv6Y8UEYkazc7czSwReAC4EigGVprZAnff2GTYPwFPufuDZjYQWAjktUDe43zrC31ITjT+fWHDCdZfTruQZM3gRUTCmrkXAEXuvtXdq4F5wOQTxjjQsfH3dGBX5CJ+vlljz+WfvzyQl9bv5ltPrKa6VjN4EZFwyr0XsLPJdnHjvqb+FZhuZsU0zNq/fbI3MrNZZlZoZoXl5eWnEffkZo7O50eTBvHnjR/yzSdWUVVbF7H3FhGJRpE6hjEN+K27ZwMTgcfN7FPv7e6z3T3k7qHMzGbvNX9KbhiZx4+/Mpi/birj1sdXUVmjgheR+BVOuZcAOU22sxv3NTUTeArA3ZcCqUBGJAKeiukjcvnp/x7Cki3l3PJYoQpeROJWOOW+EuhrZvlmlgJMBRacMGYHcDmAmZ1HQ7lH7rjLKZha0JufXX0+bxTtYeajKzlWrYIXkfjTbLm7ey1wG7AI2ETDVTEbzOweM5vUOOwu4BYzewf4A3Cjuwe2fHRKKIf/uuYClr6/lxm/XcGRqtqgooiIBMKC6uBQKOQt/QzVP71dwp1Pvk0otwtzZwwnrU3UrdkSETmOma1y91Bz42L6ovDJQ3vxy2kXsmrHfm6Yu4JDlTVBRxIROStiutwBvnx+T3417ULe2XmArz2ygoMqeBGJAzFf7gAThvTgf/7+IjbsqmD6nOVUHFXBi0hsi4tyBxg3KItfTx/Gu6WHuG7OMvYfqQ46kohIi4mbcge4/LzuPHT9MN4rO8x1c5azTwUvIjEqrsod4Av9uzHn+hBbyw9z3cPL2HO4KuhIIiIRF3flDjC2XyZzbxzO9r1HmDZ7GeWHVPAiElvistwBRvXJ4Dc3FlBy4BhTZy+l7GBl0JFERCImbssd4JJzu/LbGQXsrqjk2tnL2F2hgheR2BDX5Q5QkN+Fx2YWUH6oimtnL2XXgWNBRxIROWNxX+4Aw3K78PjMAvYdruba2Usp3n806EgiImdE5d7owt6deeKWi6k4WsO1Dy1jx14VvIhEL5V7E+dnd+L3t4zgSHUtU2cvZfueI0FHEhE5LSr3Ewzulc7vbx7BsZo6ps5extbyw0FHEhE5ZSr3kxjYsyN/mDWCmrp6ps5eRlGZCl5EoovK/TMMyOrIvFkjqHeYOnsZWz48FHQkEZGwqdw/R9/uHZg3awQJBtNmL+Pd3QeDjiQiEpawyt3MxpvZZjMrMrO7T/L6vWb2duPPFjM7EPmowejTLY0nb72E5MQEps1exsZdKngRaf2aLXczSwQeACYAA4FpZjaw6Rh3v9Pdh7r7UOC/gWdbImxQ8jPa8+StI2ibnMh1c5axvqQi6EgiIp8rnJl7AVDk7lvdvRqYB0z+nPHTaHhIdkzJ7dqeJ2+9hPYpSVz38DLe2Rkz/zgRkRgUTrn3AnY22S5u3PcpZpYL5AN/O/NorU9Ol3Y8eesI0tslM33Oclbv2B90JBGRk4r0CdWpwHx3rzvZi2Y2y8wKzaywvLw8wh99dmR3bseTsy6hS1oKU2cv4+HXtlJX70HHEhE5TjjlXgLkNNnObtx3MlP5nEMy7j7b3UPuHsrMzAw/ZSvTs1NbnvnGSC7tl8lPFm5i2sPL2LlPtysQkdYjnHJfCfQ1s3wzS6GhwBecOMjMBgCdgaWRjdg6ZaS1YfbXhvGf11zApl0HGX/fazy5cgfumsWLSPCaLXd3rwVuAxYBm4Cn3H2Dmd1jZpOaDJ0KzPM4ajcz46vDsnn5zrFckNOJf3xmHTMfLdSDP0QkcBZUF4dCIS8sLAzks1tCfb3z6NLt/PSld2mbkshPvjKEL53fI+hYIhJjzGyVu4eaG6cVqhGSkGDMGJXPi98ZQ26Xdnzr96u5fd4aKo7WBB1NROKQyj3C+nRL45lvjOS7V/bjxbWljLtvCUu2ROeVQSISvVTuLSApMYHvXN6X5741io6pydwwdwU//OM6jlTVBh1NROKEyr0FDe6VzvPfHs0tY/L5/YodTPzl6xRu3xd0LBGJAyr3FpaanMgPvzSQebeMoK7emfLQUn760rtU1Z50nZeISESo3M+Si8/pyst3jGVKKIdfL3mfyb96U3eYFJEWo3I/i9LaJPHTq89n7o0h9h6pZvIDb/DA4iJq6+qDjiYiMUblHoAvDujOn+8Yy7iBWfx80WamPLSUbXoYt4hEkMo9IJ3bp/Cr6y7k/qlDKSo7zMT7X+fxpdt1+wIRiQiVe4DMjMlDe/HnOy9leH4X/vlPG7h+7gpKK44FHU1EopzKvRXISk/l0RnD+fFXBlO4fT/j7n2N59aUaBYvIqdN5d5KmBnTR+Ty0u1j6Ne9A3c8+TbffGI1ew9XBR1NRKKQyr2Vyctoz1O3XsLdEwbwyqYyrrrvNf6y8cOgY4lIlFG5t0KJCcbXLz2XP902ioy0NtzyWCHfn/8Ohyp1EzIRCY/KvRU7r0dHFtw2mm994Vzmrypm/H2vs/T9vUHHEpEooHJv5VKSEviHqwbw9NdHkpxoTHt4Gfc8v5HKGt2+QEQ+m8o9SgzL7czC28dw/SW5zH1zG1/65eusLT4QdCwRaaVU7lGkXUoS90wezOMzCzhSVcf/+p+3uPcvW6jR7QtE5ARhlbuZjTezzWZWZGZ3f8aYKWa20cw2mNnvIxtTmhrTN5NFd4xl0gU9uf+V97j7mXVBRxKRVqbZcjezROABYAIwEJhmZgNPGNMX+AEwyt0HAXe0QFZpIr1dMvdeO5QZo/J47u0SSg5oVauIfCKcmXsBUOTuW929GpgHTD5hzC3AA+6+H8DdyyIbUz7LLWPOAeDRt7YHG0REWpVwyr0XsLPJdnHjvqb6Af3M7E0zW2Zm4yMVUD5fz05tmTikB39YvoPDeoyfiDSK1AnVJKAvcBkwDXjYzDqdOMjMZplZoZkVlpfrodGRMnN0Poeqanm6cGfzg0UkLoRT7iVATpPt7MZ9TRUDC9y9xt23AVtoKPvjuPtsdw+5eygzM/N0M8sJhuZ0IpTbmblvbqOuXjcbE5Hwyn0l0NfM8s0sBZgKLDhhzHM0zNoxswwaDtNsjWBOacbNY/LZue8Yf9m4O+goItIKNFvu7l4L3AYsAjYBT7n7BjO7x8wmNQ5bBOw1s43AYuAf3F3r5M+iKwdmkdOlLY+8sS3oKCLSCiSFM8jdFwILT9j3L01+d+C7jT8SgMQEY8bIfO55YSNv7zzA0JxPnfIQkTiiFaoxZMrwHDq0SdLsXURU7rEkrU0SUwtyWLiulF1a1CQS11TuMeaGkXmAFjWJxDuVe4zJ7tyOCYOz+P2KHRzRoiaRuKVyj0EzR+dzqFKLmkTimco9Bl3YuzPDcjsz983tWtQkEqdU7jHq5tH57Nh3lL9u0sO1ReKRyj1GjRuURXbntjzyui6LFIlHKvcYlZhgzBiVz4rt+/Q4PpE4pHKPYVNC2VrUJBKnVO4xrENqMtcOz+HFtVrUJBJvVO4x7sZRedS78+jS7UFHEZGzSOUe4xoWNTU8qUmLmkTih8o9Dswck8/BylrmryoOOoqInCUq9zhwUe/OXNS7k57UJBJHVO5xYuboc/hg71Fe0aImkbigco8TVw3qTq9ObZmjyyJF4oLKPU4kJSYwY1QeK7btY11xRdBxRKSFhVXuZjbezDabWZGZ3X2S1280s3Ize7vx5+bIR5Uzde3wHNLaJPHIG3p2uUisa7bczSwReACYAAwEppnZwJMMfdLdhzb+zIlwTomAjxY1vbC2lN0VlUHHEZEWFM7MvQAocvet7l4NzAMmt2wsaSk3jtSiJpF4EE659wKaPvWhuHHfia42s7VmNt/Mck72RmY2y8wKzaywvLz8NOLKmcrp0o7xg7N4YtkHWtQkEsMidUL1eSDP3c8H/gI8erJB7j7b3UPuHsrMzIzQR8upmjn6HA5W1vLMai1qEolV4ZR7CdB0Jp7duO9j7r7X3asaN+cAwyITT1rCsNzOXNi7E3Pf2Ea9FjWJxKRwyn0l0NfM8s0sBZgKLGg6wMx6NNmcBGyKXERpCTNH57N971Feebcs6Cgi0gKaLXd3rwVuAxbRUNpPufsGM7vHzCY1DvuOmW0ws3eA7wA3tlRgiYzxg7IaFjW9rssiRWJRUjiD3H0hsPCEff/S5PcfAD+IbDRpSUmJCdw4Mo+fLNzE+pIKBvdKDzqSiESQVqjGsWsLcmifkqgnNYnEIJV7HOuYmsy1w3vz/Du7tKhJJMao3OPcjMYnNT22dHvQUUQkglTucS6nSzuuGpTFE8t3cLRai5pEYoXKXbh5TD4Vx2p4ZnVJ84NFJCqo3IWLenfmghwtahKJJSp3wcy4eXQ+2/Yc4W9a1CQSE1TuAsCEwQ2LmnRZpEhsULkL0LCo6YaRuSzdupf1JXpSk0i0U7nLx64d3pv2KYnM1exdJOqp3OVj6W2TuSaUw/Nrd/HhQS1qEolmKnc5zk2j8qmt16ImkWincpfj9O7ajqsGNixqOlZdF3QcETlNKnf5lJlj8jlwtEZPahKJYip3+ZRQbmcuyE7XoiaRKKZyl08xM2aOOYete47w6hYtahKJRip3OakJg7PokZ7KnNd1WaRINAqr3M1svJltNrMiM7v7c8ZdbWZuZqHIRZQgJDc+qemt9/eyYZcWNYlEm2bL3cwSgQeACcBAYJqZDTzJuA7A7cDySIeUYEwt6E27lETmvrE96CgicorCmbkXAEXuvtXdq4F5wOSTjPs34GeAVr/EiPS2yUwJ5bDgnRLKtKhJJKqEU+69gJ1Ntosb933MzC4Cctz9xQhmk1Zgxqi8xkVNHwQdRUROwRmfUDWzBOAXwF1hjJ1lZoVmVlheXn6mHy1nQW7X9owb2J0nln+gRU0iUSScci8BcppsZzfu+0gHYDDwqpltB0YAC052UtXdZ7t7yN1DmZmZp59azqqZo89h/9Eanl2jRU0i0SKccl8J9DWzfDNLAaYCCz560d0r3D3D3fPcPQ9YBkxy98IWSSxn3fC8zpyfnc4jWtQkEjWaLXd3rwVuAxYBm4Cn3H2Dmd1jZpNaOqAEz8yYOTqfreVHWLJFh9NEooG5BzMTC4VCXlioyX20qKmrZ8zPFnNut/Y8cfOIoOOIxC0zW+Xuza4l0gpVCUtyYgI3jsrjzaK9bCo9GHQcEWmGyl3CNm14b9omJ+o5qyJRQOUuYUtvl8yUUDYL3t5F2SEtahJpzVTuckpmjMqnpr6e32lRk0irpnKXU5KX0Z4rzuvO48s+oLJGi5pEWiuVu5yym0fnNyxqWl3S/GARCYTKXU5ZQX4XhvRKZ+6bWtQk0lqp3OWUfbSoqajsMEve06ImkdZI5S6nZeKQHmR1TOURPalJpFVSuctpSUlK4IaRebxRtId3d2tRk0hro3KX03ZdQeOiJs3eRVodlbuctvR2yVwTyuZPWtQk0uqo3OWMfLyoadmOoKOISBMqdzkj+RntuXxAd36nRU0irYrKXc7YzWPy2Xekmh/+cb0epC3SSqjc5YxdnN+Fm0bl89zbJYz9+WL+feEm9h6uCjqWSFzTwzokYj7Ye4T7X3mP59aU0DY5kRtH5TFrzLmkt0sOOppIzAj3YR0qd4m4orLD3PfXLbywtpQOqUncPPocbhqdR4dUlbzImYrok5jMbLyZbTazIjO7+ySvf93M1pnZ22b2hpkNPJ3QEhv6dEvjV9ddxMt3jGHkuV25969bGPMfi3nw1fc5Wl0bdDyRuNDszN3MEoEtwJVAMbASmObuG5uM6ejuBxt/nwR8093Hf977auYeP9YVV/CLv2xm8eZyMtJS+Pql5zJ9RC6pyYlBRxOJOpGcuRcARe6+1d2rgXnA5KYDPir2Ru0B3SpQPjYkO53fzCjgmW+MpH9WB3784iYu/fliHl+6napaXT4p0hLCKfdewM4m28WN+45jZt8ys/eB/wC+E5l4EkuG5XbmiZtHMG/WCHK7tOef/7SBL/7nEuat2EFNXX3Q8URiSsQuhXT3B9z9XOAfgX862Rgzm2VmhWZWWF6uW8XGqxHndOXJW0fw2E0FZHRow93PruOKXyzh2dXF1On+8CIREU65lwA5TbazG/d9lnnAV072grvPdveQu4cyMzPDTykxx8wY2y+T5745kkduCNE+JYnvPvUO4+5dwvPv7NJDQETOUDjlvhLoa2b5ZpYCTAUWNB1gZn2bbH4JeC9yESWWmRmXn9edF749mgf//iISE4xv/2ENE3/5Oos27CaoS3VFol1ScwPcvdbMbgMWAYnAXHffYGb3AIXuvgC4zcyuAGqA/cANLRlaYk9CgjFhSA/GDcrihbW7uO+v73Hr46sY0iud747rx2X9MjGzoGOKRA0tYpJWqbaunj+uKeH+V96jeP8xLurdibvG9WfkuV1V8hLXtEJVYkJ1bT1Pr9rJr/5WRGlFJRfnd+Gucf0pyO8SdDSRQKjcJaZU1tTxhxU7eGDx++w5XMWYvhncNa4/Q3M6BR1N5KxSuUtMOlZdx+PLtvPgq++z/2gNV5zXjTuv7MegnulBRxM5K1TuEtMOV9Xy2ze3Mfu1rRysrGXC4CzuvLIf/bp3CDqaSItSuUtcqDhWwyOvb2Xum9s5Ul3L353fk2kFvSnI70Jigk68SuxRuUtc2X+kmode28pjS7dztLqOjLQUxg3K4ktDenBxfheSEuPvuTTujnvDZaYSO1TuEpeOVtfy6uZyFq4r5W/vlnG0uo7O7ZK5alAWE4b0YOS5XUmO4aI/UlXLm0V7eHVLOUs2l1NX7zx16yX07tou6GgSISp3iXvHqutYsqWcl9aX8sqmMg5X1ZLeNplxA7szcUgPRvXJICUpuove3SkqO8yrm8t5dUsZK7bto6bOaZ+SyMg+GazYto/O7ZKZ/42RZKS1CTquRIDKXaSJypo63nhvDwvXlfKXTR9yqLKWDqlJXHledyYM6cGYvhlRc3/5E2fnJQeOAdCvexqX9e/GZf0yCeV1ISUpgVUf7Ofv5yyjT7c0/nDLCD0NKwao3EU+Q1VtHW8V7WXhulL+vPFDKo7VkNYmicvP68aEwT24rH9mqyp6d+e9ssO8urmMVzeXs3L7J7PzUX0yuKx/Ny7tn0mvTm1P+ucXv1vGzY8VcnF+F34zYzhtklrP/zY5dSp3kTDU1NXz1vt7eWldKYs27Gb/0RrapSTyxQHdmDikoejbpTR7C6aIO/zR7HxzOa9t+WR23r97By7rn8ml/TMJ5XYJ+7DSs6uL+e5T7zBxSBb/Pe0iXUkUxVTuIqeotq6e5dv28eK6Uhat383eI9WkJifwhf4NRf/FAd1o36Zliv6j2fnidxtm54UfNMzO09okMapP14bZeb9Men7G7DwcD7+2lZ8s3MT0Eb35t8mDdY+eKKVyFzkDdfXOim37eGl9KS+t3035oSraJCVwab9MJg7pweXndTvj49dNZ+dLNpexq6IS+GR2fln/bgzL7RzRk77/b+EmHnptK3dc0Zc7rugXsfeVs0flLhIhdfXOqg/2s3BdKS+v383ug5WkJCYwtl8GEwb34IqB3Ulv23zRuztbPvzk2HlLzM7DyfC9p9fyzOpifvyVwUwfkdtinyUtQ+Uu0gLq6501O/ezcN1uXlpXyq6KSpITjVF9Mpg4uAfjBnWnU7uUj8cfrqrljff2sGRLGUs2l388Ox+Q1YFL+2dyWb/Iz86bU1NXz62Pr2Lx5jIeuO4iJg7pcdY+W86cyl2khbk77xRX8NK6Ul5cV0rx/mMkJRiXnNuVC3M6sXL7flZu30dtfcPsfHSfjI9PhvZIb7nZeTiOVdcx/ZHlrCuu4LczhjOyT0ageSR8KneRs8jdWV9ykIXrS1m4rpQP9h5lQFaHhuvO+2cyLLdzq1sZW3G0hmseeouS/cd48tZLGNxLd9aMBhEtdzMbD9xPw2P25rj7T094/bvAzUAtUA7c5O4ffN57qtwlVrk7R6rrSGuhK2siaXdFJVc/+BZVtXXM//pI8jLaBx1JmhFuuTc7lTCzROABYAIwEJhmZgNPGLYGCLn7+cB84D9OPbJIbDCzqCh2gKz0VB69qYC6eudrc5dTdrAy6EgSIeH8O7EAKHL3re5eDcwDJjcd4O6L3f1o4+YyIDuyMUWkpfTplsZvZhSw93A1N/xmJQcra4KOJBEQTrn3AnY22S5u3PdZZgIvnUkoETm7huZ04tfTh/Heh4e45dFCKmvqgo4kZyiiZ3jMbDoQAn7+Ga/PMrNCMyssLy+P5EeLyBka2y+T/5pyAcu37eP2eWuoqw/mYguJjHDKvQTIabKd3bjvOGZ2BfBDYJK7V53sjdx9truH3D2UmZl5OnlFpAVNHtqL//t3A1m04UP+6bl1BHU1nZy5cM76rAT6mlk+DaU+Fbiu6QAzuxB4CBjv7mURTykiZ82MUfnsOVzFA4vfp2v7Nnzvqv5BR5LT0Gy5u3utmd0GLKLhUsi57r7BzO4BCt19AQ2HYdKApxtvRrTD3Se1YG4RaUHfG9efvYer+dXiIjLSUrhxVH7QkeQUhXW9lrsvBBaesO9fmvx+RYRziUiAzIwff2Uw+45U86MXNtIlrQ2TLugZdCw5Ba1ryZyItBpJiQn8ctqFDM/twl1Pvc3r7+kiiGiicheRz5SanMjDN4Q4NzONWx9fxTs7DwQdScKkcheRz5XeNpnHbiqga1oKM367kvfLDwcdScKgcheRZnXrmMpjN12MAdc/soLdFbpNQWuncheRsORntOe3Mwo4cLSaG+auoOKoblPQmqncRSRsQ7LTmX19iG17jjDz0ZUcq9ZtClorlbuInJJRfTK499qhrNqxn9t+v5rauvqgI8lJqNxF5JR96fwe3DN5MK+8W8YPntVtClqj6LjptIi0Ol8bkcueQ1Xc/8p7dE1rw90TBgQdSZpQuYvIabvjir7sOVzFr5e8T0ZaCjePOSfoSNJI5S4ip83MuGfyYPYfrebHL26iS/sU/vdFelZPa6Bj7iJyRhITjHuvHcol53Tl+/PXsnizbgzbGqjcReSMtUlKZPb1w+if1YFv/m41q3fsDzpS3LOgznKHQiEvLCwM5LNFpGWUH6riq79+i4pjNcz/+iX06dbhrGeorKlj14FjlFZUfvzf0opj7K6opE+3NK4J5dCv+9nPFSlmtsrdQ82OU7mLSCTt2HuUq3/9FkkJxjPfGEnPTm0j9t7VtfV8ePCT0t5VcYzSAw3lvavxv/tPsnK2a/sUMju0oajsMLX1ztCcTlwTyubvLuhJx9TkiOU7G1TuIhKYjbsOcu1DS+mensrTt15C5/Ypzf6Zunqn7FDlxyVdeuCE8q6oZM/hKk6srPS2yfRIT2346dSWnump9EhvS49OqfRMb0tWeiqpyYkA7DlcxXNrSniqcCdbPjxManICEwb34Jph2Yw4pysJCdYSX0dEqdxFJFDLtu7l+rkrGNSzI7+beTFHq+uOm2Efd9jkwDE+PFT1qYdyt09JpEentvRIbyjqjwq7R6fGAk9PpX2bU7/oz91ZW1zBU4U7WfDOLg5V1pLduS3XDMvh6mG9yO7cLlJfQ8Sp3EUkcC+v3803n1iFw6dm3G2SEhpn3MeXdtPy7piaROOjO1tMZU0dizbs5qnCnbxZtBczGN0ng68Oy+aqQVkfz/pbi4iWu5mNB+6n4Rmqc9z9pye8Pha4DzgfmOru85t7T5W7SHxY/G4Zy7ft+/jQSc/GmXiX9iktXtynaue+ozyzupinC4spOXCMjqlJTB7aiymhHAb36tgq8kas3M0sEdgCXAkUAyuBae6+scmYPKAj8D1ggcpdRKJZfb2zdOteni7cyUvrd1NVW8+ArA5cE8rhK0N70jWtTWDZwi33cA5WFQBF7r618Y3nAZOBj8vd3bc3vqbbw4lI1EtIMEb1yWBUnwx+dKyG59/ZxdOrivm3Fzby05c2cfmA7kwZns3YvpkkJbbO5ULhlHsvYGeT7WLg4paJIyLSuqS3TWb6iFymj8hl8+5DPF24kz+uKeHlDbvp1qENVw/L5pph2ZyTmRZ01OOc1XvLmNksYBZA7969z+ZHi4icsf5ZHfinLw/k++MHsHhzGU8X7mT2a1t58NX3CeV2Zkooh4nn9yDtNK7gibRwEpQAOU22sxv3nTJ3nw3MhoZj7qfzHiIiQUtJSuCqQVlcNSiLsoOVPLumhKcLd/L9Z9byr89vYOKQHkwJ5TA8r3NgJ2HDKfeVQF8zy6eh1KcC17VoKhGRKNGtYypfv/Rcbh17Dqt3HODpwp28sLaU+auKyevajmtCOVx9UTZZ6alnNVe4l0JOpOFSx0Rgrrv/xMzuAQrdfYGZDQf+CHQGKoHd7j7o895TV8uISKw6Wl3LS+sarp1fvm0fCQZj+2UyJZTD5ed1o03S6V87r0VMIiKtwAd7jzB/VTHzVxVTWlFJp3bJ/GjI8110AAADVklEQVTSICYP7XVa7xfJSyFFROQ05XZtz13j+nPHFf14o2gPTxfuJLtz5G6m9llU7iIiZ0FignFpv0wu7Zd5Vj6vdV59LyIiZ0TlLiISg1TuIiIxSOUuIhKDVO4iIjFI5S4iEoNU7iIiMUjlLiISgwK7/YCZlQMfnOYfzwD2RDBOtNP3cTx9H5/Qd3G8WPg+ct292ZVQgZX7mTCzwnDurRAv9H0cT9/HJ/RdHC+evg8dlhERiUEqdxGRGBSt5T476ACtjL6P4+n7+IS+i+PFzfcRlcfcRUTk80XrzF1ERD5H1JW7mY03s81mVmRmdwedJyhmlmNmi81so5ltMLPbg87UGphZopmtMbMXgs4SNDPrZGbzzexdM9tkZpcEnSkoZnZn49+T9Wb2BzM7uw80DUBUlbuZJQIPABOAgcA0MxsYbKrA1AJ3uftAYATwrTj+Lpq6HdgUdIhW4n7gZXcfAFxAnH4vZtYL+A4QcvfBNDwLemqwqVpeVJU7UAAUuftWd68G5gGTA84UCHcvdffVjb8fouEv7uk9lDFGmFk28CVgTtBZgmZm6cBY4BEAd6929wPBpgpUEtDWzJKAdsCugPO0uGgr917AzibbxcR5oQGYWR5wIbA82CSBuw/4PlAfdJBWIB8oB37TeJhqjpm1DzpUENy9BPhPYAdQClS4+5+DTdXyoq3c5QRmlgY8A9zh7geDzhMUM/syUObuq4LO0kokARcBD7r7hcARIC7PUZlZZxr+hZ8P9ATam9n0YFO1vGgr9xIgp8l2duO+uGRmyTQU+xPu/mzQeQI2CphkZttpOFz3RTP7XbCRAlUMFLv7R/+am09D2cejK4Bt7l7u7jXAs8DIgDO1uGgr95VAXzPLN7MUGk6KLAg4UyDMzGg4nrrJ3X8RdJ6gufsP3D3b3fNo+P/F39w95mdnn8XddwM7zax/467LgY0BRgrSDmCEmbVr/HtzOXFwcjkp6ACnwt1rzew2YBENZ7znuvuGgGMFZRTwNWCdmb3duO//uPvCADNJ6/Jt4InGidBWYEbAeQLh7svNbD6wmoarzNYQBytVtUJVRCQGRdthGRERCYPKXUQkBqncRURikMpdRCQGqdxFRGKQyl1EJAap3EVEYpDKXUQkBv1/rkX0sBc1lZkAAAAASUVORK5CYII=\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(tr_loss_hist, label = 'train')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "yhat = char_stacked_bi_lstm.predict(sess = sess, X_length = X_length, X_indices = X_indices)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "training acc: 100.00%\n" ] } ], "source": [ "print('training acc: {:.2%}'.format(np.mean(yhat == np.argmax(y, axis = -1))))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }