{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CS 20 : TensorFlow for Deep Learning Research\n", "## Lecture 11 : Recurrent Neural Networks\n", "Simple example for Many to One Classification (word sentiment classification) by Stacked RNN with Drop out.\n", "\n", "### Many to One Classification by Stacked RNN with Drop out\n", "- Creating the **data pipeline** with `tf.data`\n", "- Preprocessing word sequences (variable input sequence length) using `padding technique` by `user function (pad_seq)`\n", "- Using `tf.nn.embedding_lookup` for getting vector of tokens (eg. word, character)\n", "- Creating the model as **Class** \n", "- Applying **Drop out** to model by `tf.contrib.rnn.DropoutWrapper`\n", "- Applying **Stacking** to model by `tf.contrib.rnn.MultiRNNCell`\n", "- Reference\n", " - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n", " - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n", " - https://danijar.com/introduction-to-recurrent-networks-in-tensorflow/\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.8.0\n" ] } ], "source": [ "import os, sys\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "import string\n", "%matplotlib inline\n", "\n", "slim = tf.contrib.slim\n", "print(tf.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare example data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n", "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]" ] }, { "cell_type": "code", 
"execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'abcdefghijklmnopqrstuvwxyz *'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Character quantization\n", "char_space = string.ascii_lowercase \n", "char_space = char_space + ' ' + '*'\n", "char_space" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n" ] } ], "source": [ "char_dic = {char : idx for idx, char in enumerate(char_space)}\n", "print(char_dic)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create pad_seq function" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def pad_seq(sequences, max_len, dic):\n", " seq_len, seq_indices = [], []\n", " for seq in sequences:\n", " seq_len.append(len(seq))\n", " seq_idx = [dic.get(char) for char in seq]\n", " seq_idx += (max_len - len(seq_idx)) * [dic.get('*')] # 27 is idx of meaningless token \"*\"\n", " seq_indices.append(seq_idx)\n", " return seq_len, seq_indices" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pad_seq function to data" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "max_length = 10\n", "X_length, X_indices = pad_seq(sequences = words, max_len = max_length, dic = char_dic)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[4, 3, 7, 7, 9, 7]\n", "(6, 10)\n" ] } ], "source": [ "print(X_length)\n", "print(np.shape(X_indices))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define CharStackedRNN class" ] }, { 
"cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class CharStackedRNN:\n", " def __init__(self, X_length, X_indices, y, n_of_classes, dic, hidden_dims = [32, 16]):\n", " \n", " # data pipeline\n", " with tf.variable_scope('input_layer'):\n", " self._X_length = X_length\n", " self._X_indices = X_indices\n", " self._y = y\n", " self._keep_prob = tf.placeholder(dtype = tf.float32)\n", " \n", " one_hot = tf.eye(len(dic), dtype = tf.float32)\n", " self._one_hot = tf.get_variable(name='one_hot_embedding', initializer = one_hot,\n", " trainable = False) # embedding vector training 안할 것이기 때문\n", " self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n", " \n", " # Stacked-RNN\n", " with tf.variable_scope('stacked_rnn'):\n", " \n", " cells = []\n", " for hidden_dim in hidden_dims:\n", " cell = tf.contrib.rnn.BasicRNNCell(num_units = hidden_dim, activation = tf.nn.tanh)\n", " cell = tf.contrib.rnn.DropoutWrapper(cell = cell, output_keep_prob = self._keep_prob)\n", " cells.append(cell)\n", " else:\n", " cells = tf.contrib.rnn.MultiRNNCell(cells = cells)\n", " \n", " _, state = tf.nn.dynamic_rnn(cell = cells, inputs = self._X_batch,\n", " sequence_length = self._X_length, dtype = tf.float32)\n", " \n", " with tf.variable_scope('output_layer'):\n", " self._score = slim.fully_connected(inputs = state[-1], num_outputs = n_of_classes,\n", " activation_fn = None)\n", " \n", " with tf.variable_scope('loss'):\n", " self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n", " \n", " with tf.variable_scope('prediction'):\n", " self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n", " \n", " def predict(self, sess, X_length, X_indices, keep_prob = 1.):\n", " feed_prediction = {self._X_length : X_length, self._X_indices : X_indices, self._keep_prob : keep_prob}\n", " return sess.run(self._prediction, feed_dict = feed_prediction)" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "### Create a model of CharStackedRNN" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3\n" ] } ], "source": [ "# hyper-parameter#\n", "lr = .003\n", "epochs = 10\n", "batch_size = 2\n", "total_step = int(np.shape(X_indices)[0] / batch_size)\n", "print(total_step)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n" ] } ], "source": [ "## create data pipeline with tf.data\n", "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n", "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n", "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n", "tr_iterator = tr_dataset.make_initializable_iterator()\n", "print(tr_dataset)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "char_stacked_rnn = CharStackedRNN(X_length = X_length_mb, X_indices = X_indices_mb, y = y_mb,\n", " n_of_classes = 2, dic = char_dic, hidden_dims = [32,16])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Creat training op and train model" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "## create training op\n", "opt = tf.train.AdamOptimizer(learning_rate = lr)\n", "training_op = opt.minimize(loss = char_stacked_rnn.ce_loss)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch : 1, tr_loss : 0.611\n", "epoch : 2, tr_loss : 0.373\n", "epoch : 3, tr_loss : 0.330\n", "epoch : 4, tr_loss : 0.157\n", "epoch : 5, tr_loss : 
0.093\n", "epoch : 6, tr_loss : 0.094\n", "epoch : 7, tr_loss : 0.036\n", "epoch : 8, tr_loss : 0.060\n", "epoch : 9, tr_loss : 0.013\n", "epoch : 10, tr_loss : 0.025\n" ] } ], "source": [ "sess = tf.Session()\n", "sess.run(tf.global_variables_initializer())\n", "\n", "tr_loss_hist = []\n", "\n", "for epoch in range(epochs):\n", " avg_tr_loss = 0\n", " tr_step = 0\n", " \n", " sess.run(tr_iterator.initializer)\n", " try:\n", " while True:\n", " _, tr_loss = sess.run(fetches = [training_op, char_stacked_rnn.ce_loss],\n", " feed_dict = {char_stacked_rnn._keep_prob : .5})\n", " avg_tr_loss += tr_loss\n", " tr_step += 1\n", " \n", " except tf.errors.OutOfRangeError:\n", " pass\n", " \n", " avg_tr_loss /= tr_step\n", " tr_loss_hist.append(avg_tr_loss)\n", " \n", " print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[<matplotlib.lines.Line2D at 0x115c6e438>]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xt8VOW97/HPL3cSAiEQbknIJHINeAEGhKBIi1UQhVZti1aF7qqnu7XW1vZsW/e2LT3taXettd1le0rpFrxSa2mLiop4Kco9iHJHIwSScIuEBAi55zl/JGrAQAaYZGVmvu/Xi1ey1jzM+jLKl5VZz3rGnHOIiEh4ifI6gIiIBJ/KXUQkDKncRUTCkMpdRCQMqdxFRMKQyl1EJAyp3EVEwpDKXUQkDKncRUTCUIxXB+7Vq5fz+XxeHV5EJCRt2LDhQ+dcWlvjPCt3n89Hfn6+V4cXEQlJZrYnkHF6W0ZEJAyp3EVEwpDKXUQkDKncRUTCkMpdRCQMqdxFRMKQyl1EJAyFXLlvKi7nly/tQB8PKCJyegGVu5lNMbOdZlZgZvedZsyXzGybmW01s6eCG/MT7xaV88gbH7CxqLy9DiEiEvLaLHcziwbmAlOBXOAmM8s9Zcwg4AfABOfccOCedsgKwPWjMkiOj+HRlYXtdQgRkZAXyJn7WKDAObfLOVcLLAJmnDLmDmCuc+4IgHPuUHBjfiIpPoYvjcnkxc37OVBR3V6HEREJaYGUezpQ1GK7uHlfS4OBwWa20szWmNmUYAVszazxPhqc48m1AS2xICIScYJ1QTUGGARMAm4C/mhmKacOMrM7zSzfzPJLS0vP+WADeiYyeWgfnlq7l+q6hnN+HhGRcBVIuZcAmS22M5r3tVQMLHHO1TnndgPv0VT2J3HOzXPO+Z1z/rS0NlesPKOvTvBxuLKW5zftP6/nEREJR4GU+3pgkJllm1kcMBNYcsqYv9N01o6Z9aLpbZpdQcz5KXkX9GRQ7648unK3pkWKiJyizXJ3ztUDdwEvA9uBZ5xzW81sjplNbx72MnDYzLYBrwPfd84dbq/QAGbG7Ak+tu47yoY9R9rzUCIiIce8Ouv1+/3ufD+s40RtPeN+/iqXD0pj7ldGBSmZiEjnZWYbnHP+tsaF3B2qLSXGxTBz7ABe2nqAfeVVXscREek0QrrcAW4dl4VzjifWaFqkiMhHQr7cM1MT+VxuH55ep2mRIiIfCflyB5idl82RE3UseWef11FERDqFsCj3cTmpDO2bzKOrCjUtUkSEMCl3M2N2no/t+4+ydneZ13FERDwXFuUOMOOSdFISY1mg1SJFRMKn3LvERTNzzACWbTtA8ZETXscREfFU2JQ7wK3jszAzHte0SBGJcGFV7ukpXbh6eB8WrSuiqlbTIkUkcoVVuUPTtMiKqjr+/s6pC1eKiESOsCv3Mb4e5PbrxoKVmhYpIpEr7Mr9o9Uidx48xuoP2nVhShGRTivsyh1g+sX9SU2K49FVhV5HERHxRFiWe0JsNDePHcDy7QcpKtO0SBGJPGFZ7gC3jMsiyozHVhd6HUVEpMOFbbn37Z7A1BF9WbS+iMqaeq/jiIh0qLAtd2j6EO1j1fX8baOmRYpIZAnrch81oAcXpndngVaLFJEIE9bl/tFqkQWHjvNWwYdexxER6TBhXe4A117cj15d47RapIhElLAv9/iYaG6+NIvXdh6i8MNKr+OIiHSIsC93gFsuHUC0GY+t1mqRIhIZIqLce3dLYNpF/fhLfhHHNS1SRCJAQOVuZlPMbKeZFZjZfa08PtvMSs3sneZftwc/6vmZnefjWE09i98u9jqKiEi7a7PczSwamAtMBXKBm8wst5Whf3bOXdL8a36Qc563kQN6cHFmCgtWFtLYqGmRIhLeAjlzHwsUOOd2OedqgUXAjPaN1T7+ZYKPXR9WsuL9Uq+jiIi0q0DKPR0oarFd3LzvVDeY2SYze9bMMoOSLsi
mjuhHWnI8C7RapIiEuWBdUH0O8DnnLgJeARa2NsjM7jSzfDPLLy3t+LPnuJgobrk0izd2lrKr9HiHH19EpKMEUu4lQMsz8YzmfR9zzh12ztU0b84HRrf2RM65ec45v3POn5aWdi55z9vNlw4gNlrTIkUkvAVS7uuBQWaWbWZxwExgScsBZtavxeZ0YHvwIgZXWnI8113Un7/kF3Gsus7rOCIi7aLNcnfO1QN3AS/TVNrPOOe2mtkcM5vePOxuM9tqZu8CdwOz2ytwMMye4KOytoFnN2hapIiEJ/NqtUS/3+/y8/M9OTbADY+s4vDxGl67dxJRUeZZDhGRs2FmG5xz/rbGRcQdqq2Zneej8PAJ3njvkNdRRESCLmLLfcqIvvTpFs+jWi1SRMJQxJZ7bHQUt47L4s33P6Tg0DGv44iIBFXEljvATWMHEBcTxcJVmhYpIuElosu9Z9d4pl/cn7++XUxFlaZFikj4iOhyh6YLqydqG/hLflHbg0VEQkTEl/uI9O6M9aWycHUhDVotUkTCRMSXOzTd1FRUVsVrOzQtUkTCg8oduCq3D/26J7Bg1W6vo4iIBIXKHYiJjuLW8VmsLDjMewc1LVJEQp/KvdnMMQOIj4nSWu8iEhZU7s1Sk+L4/CXpLH67mIoTmhYpIqFN5d7C7Ak+qusaWbR+r9dRRETOi8q9hWH9ujEuJ5XHVu+hvqHR6zgiIudM5X6K2XnZlJRXsXy7pkWKSOhSuZ/iymG9SU/pommRIhLSVO6niImO4rbxWazZVcb2/Ue9jiMick5U7q348phMEmKjWKhpkSISolTurUhJjOP6URn8bWMJZZW1XscRETlrKvfTmJ3no6Ze0yJFJDSp3E9jcJ9kJgzsyeOaFikiIUjlfgaz87LZX1HNsm0HvY4iInJWVO5n8NmhvclM7cICfYi2iIQYlfsZREcZs8b7WFdYxpaSCq/jiIgELKByN7MpZrbTzArM7L4zjLvBzJyZ+YMX0Vtf9GeSGBetaZEiElLaLHcziwbmAlOBXOAmM8ttZVwy8G1gbbBDeql7l1huGJXBP97dx+HjNV7HEREJSCBn7mOBAufcLudcLbAImNHKuJ8CvwSqg5ivU5iVl0VtfSNPr9O0SBEJDYGUezpQ1GK7uHnfx8xsFJDpnHshiNk6jYG9k7l8UC8eX7OHOk2LFJEQcN4XVM0sCngIuDeAsXeaWb6Z5ZeWlp7voTvUVyf4OHi0hpe2HPA6iohImwIp9xIgs8V2RvO+jyQDI4A3zKwQGAcsae2iqnNunnPO75zzp6WlnXtqD0wa3Btfz0R9DJ+IhIRAyn09MMjMss0sDpgJLPnoQedchXOul3PO55zzAWuA6c65/HZJ7JGoKGNWno8Ne46wqbjc6zgiImfUZrk75+qBu4CXge3AM865rWY2x8ymt3fAzuTG0RkkxUXr7F1EOr2YQAY555YCS0/Z98Bpxk46/1idU3JCLF/0Z/LU2r38YOow0pLjvY4kItIq3aF6lm4bn0VtQyNPrdW0SBHpvFTuZyknrSuThqTxxNo91NZrWqSIdE4q93MwO89H6bEaXtyy3+soIiKtUrmfg4mD0shJS+KXL+7gqbV7qapt8DqSiMhJVO7nICrK+OUNF9E9MY4f/m0z4/7vq/x86XaKyk54HU1EBABzznlyYL/f7/LzQ3sqvHOO9YVHWLiqkJe2HqDROSYP7cPsPB8TBvbEzLyOKCJhxsw2OOfaXHk3oKmQ0jozY2x2KmOzU9lfUcWTa/by9Lq9LN9+kIG9uzJrfBZfGJVB13i9zCLSsXTmHmTVdQ28sGk/C1cXsqm4guT4GG4YncGsPB/ZvZK8jiciIS7QM3eVeztxzrGxqJyFqwpZunk/dQ2OSUPSmJXn44pBaURF6S0bETl7KvdO5NCxap5au5cn1+6l9FgN2b2SuHVcFjf6M+iWEOt1PBEJISr3Tqi2vpEXt+xn4apC3t5bTmJcNDeMymBWXhYDeyd7HU9
EQoDKvZPbXFzBglWFPLdpH7X1jUwY2JNZ431MHtaHaL1lIyKnoXIPEYeP17BofRFPrNnD/opqMnp04bbxWXzJn0lKYpzX8USkk1G5h5j6hkaWbTvIglWFrNtdRkJsFF8Ymc5t430M69fN63gi0kmo3EPYtn1HeWx1IX9/p4TqukbGZqcyO8/HVbl9iInWTcUikUzlHgbKT9Ty5/VFPL5mD8VHqujXPYFbxmUxc0wmPbtqLXmRSKRyDyMNjY5Xtx9k4epCVhYcJi4miukX92d2no8R6d29jiciHUjLD4SR6CjjquF9uWp4X94/eIyFqwtZ/HYJz24oZnRWD34wdSh+X6rXMUWkE9GZe4iqqKrj2Q3FzH9zF3UNjte+d4VuiBKJAIGeuevqXIjq3iWWr12WzR9uHc3hyhp+t/x9ryOJSCeicg9xF2WkMHPMAB5dVch7B495HUdEOgmVexj4/tVD6Bofw4+XbMWrt9lEpHNRuYeB1KQ4vnfVYFZ9cJilmw94HUdEOgGVe5i4+dIscvt142cvbONEbb3XcUTEYwGVu5lNMbOdZlZgZve18vjXzWyzmb1jZm+ZWW7wo8qZREcZc2YMZ19FNf/9+gdexxERj7VZ7mYWDcwFpgK5wE2tlPdTzrkLnXOXAP8JPBT0pNImvy+VL4xMZ96KXRR+WOl1HBHxUCBn7mOBAufcLudcLbAImNFygHPuaIvNJEBX9Tzyg6lDiY02fvr8Nq+jiIiHAin3dKCoxXZx876TmNk3zewDms7c7w5OPDlbvbslcM+Vg3l1xyFe3X7Q6zgi4pGgXVB1zs11zl0A/Bvw762NMbM7zSzfzPJLS0uDdWg5xaw8HxekJTHn+W1U1zV4HUdEPBBIuZcAmS22M5r3nc4i4POtPeCcm+ec8zvn/GlpaYGnlLMSFxPFj6cPZ8/hE8x/c5fXcUTEA4GU+3pgkJllm1kcMBNY0nKAmQ1qsTkN0L3wHrt8UBpTR/Tl968XUFJe5XUcEelgbZa7c64euAt4GdgOPOOc22pmc8xsevOwu8xsq5m9A3wXmNVuiSVg908bBsDPX9jucRIR6WgBLfnrnFsKLD1l3wMtvv92kHNJEGT0SOQbkwby0CvvcXPBh0wY2MvrSCLSQXSHapi7c2IOA1IT+dGSrdQ1NHodR0Q6iMo9zCXERvPAtbkUHDrOwlWFXscRkQ6ico8Ak4f15jND0nh4+fscOlbtdRwR6QAq9whgZjxw3XBq6xv5xYs7vI4jIh1A5R4hsnslcfvl2Sx+u4QNe8q8jiMi7UzlHkHu+uxA+nVP4IF/bKWhUcv/iIQzlXsESYyL4f5pw9i67yhPr9vrdRwRaUcq9wgz7cJ+jM/pyYPLdnKkstbrOCLSTlTuEcbM+PH04RyrrufBZTu9jiMi7UTlHoGG9E1m1ngfT63by5aSCq/jiEg7ULlHqHs+N4ieSXE88I8tNOriqkjYUblHqG4JsfzblKG8vbecxRvPtIKziIQilXsEu2FUBiMHpPCLF3dwtLrO6zgiEkQq9wgWFWXMmT6Cw5U1/Ha5luAXCScq9wh3YUZ3Zo4ZwIJVhbx38JjXcUQkSFTuwvevHkLX+Bh+9I+tOKeLqyLhQOUupCbF8b2rh7B612GWbj7gdRwRCQKVuwBw89gB5Pbrxv95YRsnauu9jiMi50nlLgBERxlzZgxnf0U1c18v8DqOiJwnlbt8zO9L5fqR6fxxxW4KP6z0Oo6InAeVu5zkvqlDiYuJYs7z27yOIiLnQeUuJ+ndLYF7rhzEazsO8er2g17HEZFzpHKXT5mV52Ng76785LltVNc1eB1HRM6Byl0+JTY6ih9fN5y9ZSeY/+Yur+OIyDkIqNzNbIqZ7TSzAjO7r5XHv2tm28xsk5m9amZZwY8qHemyQb245sK+/P71AkrKq7yOIyJnqc1yN7NoYC4wFcgFbjKz3FOGbQT8zrmLgGeB/wx2UOl4909r+s/8sxd0cVUk1ARy5j4WKHD
O7XLO1QKLgBktBzjnXnfOnWjeXANkBDemeCE9pQvfnDSQpZsPsLLgQ6/jiMhZCKTc04GiFtvFzftO52vAi+cTSjqPOybmMCA1kR8t2UpdQ6PXcUQkQEG9oGpmtwB+4FenefxOM8s3s/zS0tJgHlraSUJsND+6LpeCQ8dZuKrQ6zgiEqBAyr0EyGyxndG87yRmdiVwPzDdOVfT2hM55+Y55/zOOX9aWtq55BUPTB7Wh88MSePh5e9z6Gi113FEJACBlPt6YJCZZZtZHDATWNJygJmNBP5AU7EfCn5M8doD1w2ntr6RX7y4w+soIhKANsvdOVcP3AW8DGwHnnHObTWzOWY2vXnYr4CuwF/M7B0zW3Kap5MQld0riTsmZrN4Ywn5hWVexxGRNphXH87g9/tdfn6+J8eWc3Oitp7Jv/4nPRLjeO5blxEdZV5HEok4ZrbBOedva5zuUJWAJcbFcP+0YWzbf5Sn1u31Oo6InIHKXc7KtAv7MT6nJw++vJOyylqv44jIaajc5ayYGT+ZMZzjNfU8uGyn13FE5DRU7nLWBvdJZtZ4H0+v28vm4gqv44hIK1Tuck7u+dwgeibF8cCSLTQ2enNRXkROT+Uu56RbQiz3TR3Gxr3lLN74qXvaRMRjKnc5Z9ePTGfUgBR+8eJ2jlbXeR1HRFpQucs5i4oy5swYweHKWh5+5X2v44hICyp3OS8j0rtz09gBLFxdyKbicq/jiEgzlbuct+9fNYQeiXHc+MhqHl7+HjX1+txVEa+p3OW89UiK48VvX86UEX15ePn7XPPbN1m3W+vPiHhJ5S5BkZYcz+9uGsmCr46hpr6RL/1hNff9dRMVJ3ShVcQLKncJqklDerPsOxO5c2IOf9lQzOSH/smSd/fh1QJ1IpFK5S5BlxgXww+vGcaSuybQPyWBu5/eyFcXrKeo7ETbv1lEgkLlLu1meP/u/O0bE3jg2lzW7S7jqt+sYN6KD6jXZ7GKtDuVu7Sr6CjjXy7L5pXvXsGEgT35+dIdzJi7UtMmRdqZyl06RHpKF/54m59HvjKK0mM1fH7uSuY8t43Kmnqvo4mEJZW7dBgzY+qF/Vh+7xV85dIsHl21m8899E+WbzvodTSRsKNylw7XLSGWn35+BM9+PY/khFhufyyff31iAwePVnsdTSRsqNzFM6OzevDcty7j+1cP4dUdh7jy1//k8TV7tISwSBCo3MVTcTFRfPMzA1l2z0QuyuzOf/x9C1/8w2p2HjjmdTSRkKZyl07B1yuJJ752Kb/+4sXsKj3OtN+9ya9e3kF1ndapETkXKnfpNMyMG0Zn8Oq9k5h+SX/mvv4BUx5ewaqCD72OJhJyVO7S6aQmxfHQly7hydsvBeDm+Wu595l3Kaus9TiZSOhQuUunNWFgL166ZyLf/MwF/OOdEib/+g0Wv12sdWpEAhBQuZvZFDPbaWYFZnZfK49PNLO3zazezG4MfkyJVAmx0Xz/6qG8cPflZPdK4rvPvMstf1pL4YeVXkcT6dTaLHcziwbmAlOBXOAmM8s9ZdheYDbwVLADigAM6ZvMs1/P46efH8GmogqufngFc18voE7r1Ii0KpAz97FAgXNul3OuFlgEzGg5wDlX6JzbBOhvmrSbqCjj1nFZLL/3Cj47tDe/enkn1/7uLTbsOeJ1NJFOJ5ByTweKWmwXN+87a2Z2p5nlm1l+aWnpuTyFCH26JfDILaP5421+jlbXceP/W8V//H0LR6v1wSAiH4npyIM55+YB8wD8fr+uisl5+VxuH8Zf0JNfL9vJglWFLNt2gH+flsuwft2ob2ykvsHR0Oiob/zoa+Mn2w0n7z95bGOLx5q/Nnx6f31DK+Oaf39iXAx3TsxhcJ9kr18miVCBlHsJkNliO6N5n4jnusbH8KPrhvP5S9K5b/FmvvX0xnY9Xmy0ER1lxERFNX+1j7/GREd9vH2gopq/bSzhK5cO4DtXDqZHUly75hI5VSDlvh4YZGbZNJX6TODmdk0lcpY
uzkzhubsm8NqOQ1TXN562dD/52lzO0a3vb63Eo6Is4DxHKmv5zfL3eGLNHv7xzj6+c+UgvjIui9hozT6WjmGBzBk2s2uAh4Fo4H+ccz8zszlAvnNuiZmNAf4G9ACqgQPOueFnek6/3+/y8/PP+w8g0pntPHCMOc9vZWXBYQb17sp/XJvLxMFpXseSEGZmG5xz/jbHeXVDiMpdIoVzjle2HeRnS7ez5/AJJg/tzf3ThpGT1tXraBKCAi13/Ywo0s7MjKuG92XZdyZy39ShrN1dxtUPr+BnL2zTDB9pNyp3kQ4SHxPN16+4gNe+dwXXj8xg/lu7+cyv3uDpdXtp0Br2EmQqd5EO1js5gV/eeBHP3XUZOWlJ/GDxZq77r7dYu+uw19EkjKjcRTwyIr07z/yv8fzXTSOpqKrjy/PW8I0nN1BUdsLraBIGVO4iHjIzrru4P6/eewXf/dxgXttxiMkP/ZMHX95JZU291/EkhKncRTqBhNho7p48iNe/N4lrRvTl968X8NnmJY71mbJyLlTuIp1Iv+5deHjmSP76r3n07ZbAd595l+sfWcXGvaG7ONqx6jqqavVxiR1N89xFOqnGRsfijSX850s7OHSshi+MTOffpgylb/cEr6Od0YGKatYXljX/OsKOA0dJioth5phMvnpZNukpXbyOGNJ0E5NImKisqee/3yjgj2/uJtqMb0y6gDsm5pAQG+11NBobHQWlx1lfWEZ+4RHWF5ZRfKQKgMS4aEYN6MHorB4UHq7k+U37Abj2on7ccXkOI9K7exk9ZKncRcJMUdkJfr50Oy9uOUB6Shd+eM0wrrmwL2aBr3lzvmrqG9hSUsH6wiOs311G/p4jVFQ13YjVq2s8Y3w98PtSGetLZVi/ZGJarKVTUl7Fo2/t5ul1e6msbSDvgp7cMTGHSYPTOvTPEOpU7iJhavUHh/nJc1vZceAYY7NT+dF1uQzv3z5nwRVVdby958jHZ+bvFJdTW9/0mTw5aUmMyUrF7+vBGF8qWT0TAyrpiqo6Fq3by6MrCzlwtJrBfbpy++U5zLikP/Ex3v800tmp3EXCWEOj48/ri3hw2U6OnKhl5phM7r1qCL26xp/X8+4rrzrpLZadB4/hHMREGcPTuzMmq+nM3O/rcd7Hqq1v5IXN+5i3Yjfb9x8lLTme2Xk+vnLpAFIStUTy6ajcRSJARVUd//Xq+yxYVUiX2Gi+NXkgs/OyiYtpeyJcY6Pj/UPHWVdYRn5zoZeUN71fnhQXzaispjNyv68Hl2SmkBjXPp/t45xjZcFh5r25ixXvlZIYF82X/Jl87bJsMlMT2+WYoUzlLhJBPig9zs9e2M5rOw6R3SuJ+68ZxuRhvU96m6S6roHNJRUfn5nnF5ZxtLrpRqneyfEfF/kYXypD+578fnlH2b7/KPPf3M2Sd0toaHRMHdGPOybmcElmSodn6axU7iIR6I2dh/jp89v4oLSSywf14stjMtlScpT8wjI2FVdQ29D0fvnA3l2bLn5mpTLGl0pmapdOdVHzQEU1C1YV8uTaPRyrrmesL5U7JuYweWjvs/rQlHCkcheJUHUNjTyxZg+/eeU9jlbXExttjEjvzhhfU5GPzupBaoh87N/xmnqeWV/En97aTUl5FTm9krj98hyuH5XeKaaCekHlLhLhyk/U8kFpJbn9utElLrSLsL6hkRe3HGDeil1sLqmgZ1Ict433cev4rJD5hypYVO4iEnacc6zdXcYfV+zi1R2HiI+J4ov+DL52WQ7ZvZK8jtchAi339rn8LSLSDsyMcTk9GZfTk4JDx5j/5m6eWV/Mk2v3clVuH+6cmMPorFSvY1Lf0MihYzWUlFexr7zqk69HqthXXs3dkwcx7aJ+7ZpB5S4iIWlg72R+ccNF3HvVEB5fXchja/bw8taDjByQwp2X53DV8L5Et9PF1+M19a2UdlNxl5RXceBo9ac+XatHYiz9U7owoGciyQntX716W0ZEwsKJ2nr+uqGY+W/tZs/
hEwxITeT2y7O5cXTGWc3Rb2x0lB6vofjjwv7k7LukvJp95VUfL7nwkZgoo2/3BNJTupCe0oX+KV1I79H8NSWB/ildgnafgN5zF5GI1NDoeGXbAf6wYhcb95aTkhjLreOyuG28j7TkeKpqGz4+4/6ktD/5/kBFNXUNJ/dit4SY5qL+pLQ/3k7pQlpyfLv9lHAqlbuIRLwNe8qYt2IXy7YdJDYqiq4JMZRV1p40Jsqa1tHv33yG3bK0m7YTSE6I9ehP8Gm6oCoiEW90Vip/uDWV3R9W8uSaPVTVNXzqDLxPcrwnd+O2t4DK3cymAL8FooH5zrlfnPJ4PPAYMBo4DHzZOVcY3KgiIucmu1cS/35trtcxOlSb/1yZWTQwF5gK5AI3mdmpr9LXgCPOuYHAb4BfBjuoiIgELpCfRcYCBc65Xc65WmARMOOUMTOAhc3fPwtMts60UIWISIQJpNzTgaIW28XN+1od45yrByqAnqc+kZndaWb5ZpZfWlp6bolFRKRNHXoVwTk3zznnd87509LSOvLQIiIRJZByLwEyW2xnNO9rdYyZxQDdabqwKiIiHgik3NcDg8ws28zigJnAklPGLAFmNX9/I/Ca82oCvYiItD0V0jlXb2Z3AS/TNBXyf5xzW81sDpDvnFsC/Al43MwKgDKa/gEQERGPBDTP3Tm3FFh6yr4HWnxfDXwxuNFERORcebb8gJmVAnvO8bf3Aj4MYpxQp9fjZHo9PqHX4mTh8HpkOefanJHiWbmfDzPLD2RthUih1+Nkej0+odfiZJH0eoTfggoiIqJyFxEJR6Fa7vO8DtDJ6PU4mV6PT+i1OFnEvB4h+Z67iIicWaieuYuIyBmEXLmb2RQz22lmBWZ2n9d5vGJmmWb2upltM7OtZvZtrzN1BmYWbWYbzex5r7N4zcxSzOxZM9thZtvNbLzXmbxiZt9p/nuyxcyeNrMErzO1t5Aq9wDXlo8U9cC9zrlcYBzwzQh+LVr6NrDd6xCdxG+Bl5xzQ4GLidDXxczSgbsBv3NuBE132of9XfQhVe4EtrZ8RHDO7XfOvd38/TGa/uKeuhRzRDGzDGAaMN/rLF4zs+7ARJqWBsE5V+ucK/c2ladigC7NCxt3FyCAAAABg0lEQVQmAvs8ztPuQq3cA1lbPuKYmQ8YCaz1NonnHgb+N9DodZBOIBsoBR5tfptqvpkleR3KC865EuBBYC+wH6hwzi3zNlX7C7Vyl1OYWVfgr8A9zrmjXufxipldCxxyzm3wOksnEQOMAh5xzo0EKoGIvEZlZj1o+gk/G+gPJJnZLd6man+hVu6BrC0fMcwslqZif9I5t9jrPB6bAEw3s0Ka3q77rJk94W0kTxUDxc65j36ae5amso9EVwK7nXOlzrk6YDGQ53Gmdhdq5R7I2vIRofkzav8EbHfOPeR1Hq85537gnMtwzvlo+v/iNedc2J+dnY5z7gBQZGZDmndNBrZ5GMlLe4FxZpbY/PdmMhFwcTmgJX87i9OtLe9xLK9MAG4FNpvZO837fti8PLMIwLeAJ5tPhHYBX/U4jyecc2vN7FngbZpmmW0kAu5U1R2qIiJhKNTelhERkQCo3EVEwpDKXUQkDKncRUTCkMpdRCQMqdxFRMKQyl1EJAyp3EVEwtD/B+VRAJGJ3SY8AAAAAElFTkSuQmCC\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(tr_loss_hist, label = 'train')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "yhat = char_stacked_rnn.predict(sess = sess, X_length = 
X_length, X_indices = X_indices)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "training acc: 100.00%\n" ] } ], "source": [ "# Accuracy = mean of the element-wise matches between the predicted class ids (yhat)\n", "# and the arg-max of the label rows in y.\n", "print('training acc: {:.2%}'.format(np.mean(yhat == np.argmax(y, axis = -1))))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }