{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CS 20 : TensorFlow for Deep Learning Research\n", "## Lecture 11 : Recurrent Neural Networks\n", "A simple example of many-to-one classification (word sentiment classification) with a recurrent neural network.\n", "\n", "### Many to One Classification by RNN\n", "- Creating the **data pipeline** with `tf.data`\n", "- Preprocessing word sequences of variable length by padding them with `tf.keras.preprocessing.sequence.pad_sequences`\n", "- Using `tf.nn.embedding_lookup` to look up the vector of each token (e.g. word, character)\n", "- Creating the model as a **class**\n", "- References\n", "    - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n", "    - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.12.0\n" ] } ], "source": [ "from __future__ import absolute_import, division, print_function\n", "import os, sys\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", "from tensorflow import keras\n", "import string\n", "%matplotlib inline\n", "\n", "print(tf.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare example data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n", "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "['<pad>', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ', '*']\n" ] } ], "source": [
    "# Character quantization\n",
    "char_space = string.ascii_lowercase\n",
    "char_space = char_space + ' ' + '*'\n",
    "char_space = list(char_space)\n",
    "char_space.insert(0, '<pad>')\n",
    "print(char_space)" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'<pad>': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, ' ': 27, '*': 28}\n" ] } ], "source": [
    "char2idx = {char : idx for idx, char in enumerate(char_space)}\n",
    "print(char2idx)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Padding example data" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[4, 3, 7, 7, 9, 7]\n", "(6, 10)\n" ] } ], "source": [
    "# Encode every word as a list of character indices.\n",
    "# A new name is used instead of rebinding `words`, so this cell can be re-run safely.\n",
    "# char2idx[char] raises KeyError for a character outside char_space\n",
    "# (char2idx.get would silently insert None).\n",
    "word_indices = [[char2idx[char] for char in word] for word in words]\n",
    "\n",
    "max_length = 10\n",
    "# pad_sequences truncates to max_length, so the valid length is clipped accordingly\n",
    "X_length = [min(len(indices), max_length) for indices in word_indices]\n",
    "X_indices = pad_sequences(sequences=word_indices, maxlen=max_length, padding='post', truncating='post')\n",
    "\n",
    "print(X_length)\n",
    "print(np.shape(X_indices))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Define CharRNN class" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
    "class CharRNN:\n",
    "    \"\"\"Many-to-one character-level RNN classifier built as a TensorFlow 1.x graph.\n",
    "\n",
    "    The constructor wires the whole graph: a one-hot lookup of the token ids,\n",
    "    a SimpleRNNCell unrolled by tf.nn.dynamic_rnn, a Dense output layer on the\n",
    "    final state, a softmax cross-entropy loss (`ce_loss`) and an argmax\n",
    "    prediction op used by `predict`.\n",
    "    \"\"\"\n",
    "    def __init__(self, X_length, X_indices, y, n_of_classes, hidden_dim, dic):\n",
    "        \"\"\"Build the graph.\n",
    "\n",
    "        Args:\n",
    "            X_length: sequence lengths, passed to dynamic_rnn as `sequence_length`.\n",
    "            X_indices: integer token ids looked up in the one-hot table.\n",
    "            y: one-hot labels consumed by the loss.\n",
    "            n_of_classes: number of units of the output Dense layer.\n",
    "            hidden_dim: number of units of the SimpleRNNCell.\n",
    "            dic: token-to-index mapping; only len(dic) is used, as the one-hot size.\n",
    "        \"\"\"\n",
    "        # data pipeline\n",
    "        with tf.variable_scope('input_layer'):\n",
    "            self._X_length = X_length\n",
    "            self._X_indices = X_indices\n",
    "            self._y = y\n",
    "\n",
    "            one_hot = tf.eye(len(dic), dtype = tf.float32)\n",
    "            # fixed one-hot table: the embedding vectors are not trained\n",
    "            self._one_hot = tf.Variable(one_hot, name='one_hot_embedding',\n",
    "                                        trainable = False)\n",
    "            self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n",
    "\n",
    "        # RNN cell\n",
    "        with tf.variable_scope('rnn_cell'):\n",
    "            rnn_cell = keras.layers.SimpleRNNCell(units = hidden_dim, activation = tf.nn.tanh)\n",
    "            # only the final state is kept (many-to-one); per-step outputs are discarded\n",
    "            _, state = tf.nn.dynamic_rnn(cell = rnn_cell, inputs = self._X_batch,\n",
    "                                         sequence_length = self._X_length, dtype = tf.float32)\n",
    "\n",
    "        with tf.variable_scope('output_layer'):\n",
    "            self._score = tf.keras.layers.Dense(units=n_of_classes)(state)\n",
    "\n",
    "        with tf.variable_scope('loss'):\n",
    "            self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n",
    "\n",
    "        with tf.variable_scope('prediction'):\n",
    "            self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n",
    "\n",
    "    def predict(self, sess, X_length, X_indices):\n",
    "        \"\"\"Evaluate the prediction op in `sess`.\n",
    "\n",
    "        X_length and X_indices are fed in place of the tensors given to the\n",
    "        constructor; the result of sess.run on the int32 argmax op is returned.\n",
    "        \"\"\"\n",
    "        feed_prediction = {self._X_length : X_length, self._X_indices : X_indices}\n",
    "        return sess.run(self._prediction, feed_dict = feed_prediction)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Create a model of CharRNN" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3\n" ] } ], "source": [
    "# hyper-parameters\n",
    "lr = .003\n",
    "epochs = 10\n",
    "batch_size = 2\n",
    "# round up: Dataset.batch is called without drop_remainder, so a final partial batch still counts as a step\n",
    "total_step = int(np.ceil(np.shape(X_indices)[0] / batch_size))\n",
    "print(total_step)" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n" ] } ], "source": [
    "## create data pipeline with tf.data\n",
    "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n",
    "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n",
    "tr_dataset = tr_dataset.batch(batch_size = 
batch_size)\n",
    "tr_iterator = tr_dataset.make_initializable_iterator()\n",
    "print(tr_dataset)" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
    "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [
    "char_rnn = CharRNN(X_length = X_length_mb, X_indices = X_indices_mb, y = y_mb, n_of_classes = 2,\n",
    "                   hidden_dim = 16, dic = char2idx)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Create training op and train model" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [
    "## create training op\n",
    "opt = tf.train.AdamOptimizer(learning_rate = lr)\n",
    "training_op = opt.minimize(loss = char_rnn.ce_loss)" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
    "epoch :   1, tr_loss : 0.583\n",
    "epoch :   2, tr_loss : 0.494\n",
    "epoch :   3, tr_loss : 0.400\n",
    "epoch :   4, tr_loss : 0.335\n",
    "epoch :   5, tr_loss : 0.272\n",
    "epoch :   6, tr_loss : 0.224\n",
    "epoch :   7, tr_loss : 0.177\n",
    "epoch :   8, tr_loss : 0.142\n",
    "epoch :   9, tr_loss : 0.115\n",
    "epoch :  10, tr_loss : 0.092\n" ] } ], "source": [
    "# The session stays open on purpose: the prediction cell below reuses `sess`.\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "tr_loss_hist = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    avg_tr_loss = 0\n",
    "    tr_step = 0\n",
    "\n",
    "    # re-initialize the dataset iterator at the start of every epoch\n",
    "    sess.run(tr_iterator.initializer)\n",
    "    try:\n",
    "        while True:\n",
    "            _, tr_loss = sess.run(fetches = [training_op, char_rnn.ce_loss])\n",
    "            avg_tr_loss += tr_loss\n",
    "            tr_step += 1\n",
    "\n",
    "    except tf.errors.OutOfRangeError:\n",
    "        # the iterator is exhausted: normal end of the epoch\n",
    "        pass\n",
    "\n",
    "    # mean of the per-batch losses of this epoch\n",
    "    avg_tr_loss /= tr_step\n",
    "    tr_loss_hist.append(avg_tr_loss)\n",
    "\n",
    "    print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"[<matplotlib.lines.Line2D at 0x7efdf4492fd0>]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl0VeW9//H3NycTJGFMmAkECGgYFDkkyIyUCkVBrVpAFAdEq4jXqdW23l+rba9Va68DDojWCcVZcQCcAMUBCKPMkwIBAgEkjIEMz++PpBq80QQ4yT7D57WWa3Wf83DOZ51VPmvz7P3sx5xziIhIeInyOoCIiASeyl1EJAyp3EVEwpDKXUQkDKncRUTCkMpdRCQMqdxFRMKQyl1EJAyp3EVEwlB0VQaZ2WDgQcAHTHbO3VPBmIuBPwMOWOqcG/Vzn5mcnOxat259vHlFRCLawoULdznnUiobV2m5m5kPmAgMAnKABWY2zTm3styYdOAOoJdz7jsza1TZ57Zu3Zrs7OzKhomISDlmtqkq46oyLZMJrHfObXTOHQWmAsN/NOZqYKJz7jsA59zO4wkrIiKBVZVybw5sKXecU/Zaee2B9mb2uZl9VTaN83+Y2Tgzyzaz7Ly8vBNLLCIilQrUBdVoIB3oD4wEnjSzej8e5Jyb5JzzO+f8KSmVThmJiMgJqkq5bwValjtuUfZaeTnANOdcoXPuG2AtpWUvIiIeqEq5LwDSzSzNzGKBEcC0H415i9KzdswsmdJpmo0BzCkiIseh0nJ3zhUB44GZwCrgFefcCjO7y8yGlQ2bCew2s5XALOA259zu6gotIiI/z7zaicnv9zvdCikicnzMbKFzzl/ZuJBbofrtroP8Y8ZqSkq0PaCIyE8JuXL/YGUuj83ewN3vrUT7v4qIVKxKjx8IJlf3acOOfUd4au43JMXHcPOg9l5HEhEJOiFX7mbGn4aeyoGCIh76eB1JcdFc3beN17FERIJKyJU7lBb83y/ozIGjRfzt/VUkxkczMjPV61giIkEjJMsdwBdl/Ovi0zl4pIg/vPk1CXHRDDutmdexRESCQshdUC0vNjqKxy7pRvfWDbj55SV8snqH15FERIJCSJc7QK1YH0+N8ZPRrA6/fWERX27Q2ikRkZAvd4Ck+BievSKTVg1rM/bZBSzZstfrSCIingqLcgeonxDLC1dl0TAxjjFPz2d17j6vI4mIeCZsyh2gUZ14pozNolaMj9GT5/PtroNeRxIR8URYlTtAywa1eWFsJiXOccnkeWzbe9jrSCIiNS7syh2gXaMknrsyk32HCxn91Dx2HTjidSQRkRoVluUO0Kl5XZ6+ojvb9h7msqfmk3+40OtIIiI1JmzLHaB76wY8camfdTv3c+UzCzh0tMjrSCIiNSKsyx2gX/sUHhrRlcWbv+Oa5xdypKjY60giItUu7MsdYEjnptx74Wl8tm4XE15aTFFxideRRESqVUSUO8CF3Vrw53MzmLliB797fZk2+xCRsBayDw47EZf3SmN/QRH//HAtiXHR/GVYR8zM61giIgEXUeUOMP6sduw/UsSkTzeSFB/NbWef4nUkEZGAi7hyNzPuGHIK+wuKmDhrA0nxMVzbr63XsUREAiriyh1KC/6v53XiwJEi7pm+msS4aEb3aOV1LBGRgInIcofSzT4euPg0Dh0p4s63l5MUH83w05t7HUtEJCAi5m6ZisT4oph4yRn0SGvIza8s5cOV2uxDRMJDRJc7QHyMjyfH+OncvC7Xv7iIz9fv8jqSiMhJi/hyB0iMi+aZK7rTJjmBq5/LZuGm77yOJCJyUlTuZerVjuW
5qzJplBTHFf+ez8pt2uxDREKXyr2cRknxvDA2i4S4aC57eh4b8w54HUlE5ISo3H+kRf3avDA2C+dg9OR5bNVmHyISglTuFWibkshzV2Wy/0gRoyfPI2+/NvsQkdCicv8JHZvV5ZkrupObX8ClT80j/5A2+xCR0FGlcjezwWa2xszWm9ntFbx/uZnlmdmSsv/GBj5qzevWqgGTLuvGxryDXP7MfA4e0WYfIhIaKi13M/MBE4EhQAYw0swyKhj6snPu9LL/Jgc4p2f6pKfw8KiuLMvJZ9zz2RQUarMPEQl+VTlzzwTWO+c2OueOAlOB4dUbK7ic3bEJ913Yhc/X72b8i4sp1GYfIhLkqlLuzYEt5Y5zyl77sV+b2TIze83MWgYkXRC54IwW3D28Ix+t2sFtry7VZh8iEtQCdUH1HaC1c64L8CHwbEWDzGycmWWbWXZeXl6AvrrmXHpma247uwNvLdnGnW8vxzkVvIgEp6qU+1ag/Jl4i7LXvuec2+2c+8/9gpOBbhV9kHNuknPO75zzp6SknEhez10/oB3X9mvLlHmbuXfmGq/jiIhUqCqP/F0ApJtZGqWlPgIYVX6AmTV1zm0vOxwGrApoyiDz+8Ed2F9QyGOzN5Q+rqBXmteRRESOUWm5O+eKzGw8MBPwAU8751aY2V1AtnNuGjDBzIYBRcAe4PJqzOw5M+Ou4Z3YdeAId727kkZJ8Qzt0tTrWCIi3zOv5o39fr/Lzs725LsDpaCwmNGT57EsJ5/nrsqkR5uGXkcSkTBnZgudc/7KxmmF6kmIj/ExeYyflg1qcfVz2azJ3e91JBERQOV+0urVjuXZKzOpFePj8n/PZ3u+HjQmIt5TuQdAi/q1eeaKTPYXFHH50wvIP6zn0IiIt1TuAZLRrA5PXNqNjbsOMO65bI4U6TEFIuIdlXsA9WqXzP0Xnca8b/Zw8ytaxSoi3qnKfe5yHIaf3pwd+wr4+/uraVInnjvPqegZayIi1UvlXg2u7tOG7fkFPDX3G5rWjWdsnzZeRxKRCKNyrwZmxp1DM9i57wh/fW8VjerEM+y0Zl7HEpEIonKvJlFRxj8vPo28A0e45ZUlJCfE0rNdstexRCRC6IJqNYqP8fHkpX7SkhO45vmFrNq+z+tIIhIhVO7VrG7tGJ65IpOEuGgu//d8tu7VIicRqX4q9xrQrF4tnrmyO4eOFjPm6fnsPXTU60giEuZU7jXklCZ1mHSpn827DzH2We3FKiLVS+Veg85s25AHfnMaCzd/x39NXUKxFjmJSDVRudewc7o0486hGcxYkctf3lmhrfpEpFroVkgPXNk7jdx9BUz6dCNN69bit/3beh1JRMKMyt0jtw8+hdz8Av4xYzWN68RxwRktvI4kImFE5e6RqCjjvou6sOvAEX732jJSkuLokx6am4aLSPDRnLuH4qJ9PH5pN9o1SuTa5xeyfGu+15FEJEyo3D1WJz6GZ6/MpF7tWK54ZgFb9hzyOpKIhAGVexBoXCeeZ6/sztGiEsY8PZ89B7XISUROjso9SLRrlMTkMX5y9h5m7LMLOHxUi5xE5MSp3INI99YNeGjE6SzespcbXlpMUXGJ15FEJESp3IPM4E5N+fO5Hflo1Q7+e5oWOYnIidGtkEFoTM/W5O4r4LHZG2hWN57xZ6V7HUlEQozKPUj97uwO7Mgv4P4P1tKoTjwX+1t6HUlEQojKPUiZGff8ugt5B45wxxtfk5IUx4AOjbyOJSIhQnPuQSw2OorHRnfjlCZJXPfCIpZu2et1JBEJESr3IJcYF82/r+hOw8RYrnxmAZt2H/Q6koiEAJV7CGiUFM+zV2ZS7ByXPT2fXQeOeB1JRIKcyj1EtE1J5Kkx3cnNL+CqZxZw6GiR15FEJIhVqdzNbLCZrTGz9WZ2+8+M+7WZOTPzBy6i/Ee3VvV5eGRXvt6az/VTFmmRk4j8pErL3cx8wERgCJABjDSzjArGJQE3AvMCHVJ+8MuOTbj7vE7MWpPHH99
crkVOIlKhqpy5ZwLrnXMbnXNHganA8ArG3Q38AygIYD6pwCVZrbjhrHa8nL2F//1onddxRCQIVaXcmwNbyh3nlL32PTM7A2jpnHsvgNnkZ9w8qD0XdWvBgx+v46m53+gMXkSOcdKLmMwsCngAuLwKY8cB4wBSU1NP9qsjmpnx9ws6s/dwIXe/u5IV2/L523mdqRXr8zqaiASBqpy5bwXKr31vUfbafyQBnYDZZvYt0AOYVtFFVefcJOec3znnT0nRlnInK8YXxeOju3HjwHTeXLyV8x/9nG936T54EalauS8A0s0szcxigRHAtP+86ZzLd84lO+daO+daA18Bw5xz2dWSWI7hizJuGtSepy/vTu6+As59eC4frMj1OpaIeKzScnfOFQHjgZnAKuAV59wKM7vLzIZVd0CpmgEdGvHO+N60Tk5g3PMLuWf6at0qKRLBzKsLcX6/32Vn6+Q+0AoKi/nLOyt5af5merRpwMMjzyAlKc7rWCISIGa20DlX6VoirVANM/ExPv7ngs7cd2EXFm/ey9CHPiP72z1exxKRGqZyD1MX+Vvy5nW9qBXrY8Skr3hat0uKRBSVexjLaFaHaeN7079DI+56dyXjX1rMgSN6Jo1IJFC5h7m6tWKYdGk3fj/4FKZ/vZ3hj8xl3Y79XscSkWqmco8AUVHGb/u35YWxWeQfLmT4xM+ZtnSb17FEpBqp3CNIz7bJvHtDH05tWocJLy3mz9NWcLRIt0uKhCOVe4RpUjeeqeN6cGWvNJ754ltGPvkVufl61ptIuFG5R6AYXxT/fW4Gj4zqyqrt+xj60Gd8sX6X17FEJIBU7hHsnC7NmDa+F/UTYhn91Dwenb2ekhLdLikSDlTuEa5doyTevr4Xv+rclHtnrGHc8wvJP1zodSwROUkqdyEhLpqHR3blz+dmMHvNTs59eC4rtuV7HUtEToLKXYDS58Nf3iuNl6/pwdGiEi549Atezd5S+R8UkaCkcpdjdGvVgHcn9KZbq/rc9toy7nhjGQWFxV7HEpHjpHKX/yM5MY7nr8ri+gFteWn+Fi58/Au27DnkdSwROQ4qd6mQL8q47exTmHyZn027D3HOw3OZtXqn17FEpIpU7vKzfpHRmHdv6E3zerW44pkFPPDBGop1u6RI0FO5S6VaNUzgjet6clG3Fjz0yXou//d89hw86nUsEfkZKnepkvgYH/de2IV7LujMvG/2cM5Dn7Fky16vY4nIT1C5S5WZGSMyU3n92p5ERRkXPf4Fz3+1SZuAiAQhlbsct84t6vLuDb3p3S6ZO99azs2vLOXQUW0CIhJMVO5yQurVjuWpMd25ZVB73lqylfMnfsHGvANexxKRMip3OWFRUcYNA9N57spMdu4v4JyH5/Lm4hyvY4kIKncJgD7pKbx/Yx86Na/LTS8v5dZXNU0j4jWVuwRE07q1eHFsFhMGpvP6ohzOfXguq7bv8zqWSMRSuUvARPuiuHlQe6ZclcW+giLOm/g5U+bpbhoRL6jcJeB6tktm+o19yGrTkD++uZzxLy5mX4GeES9Sk1TuUi2SE+N45vLu3D7kFGasyGXoQ5+xVIueRGqMyl2qTVSUcW2/trxyzZmUlMCvH/uCJz/dqK38RGqAyl2qXbdW9Xl/Qh8GntqIv72/irHPZevZNCLVTOUuNaJu7RgeH92Nu4Z3ZO66Xfzqwc+Yt3G317FEwpbKXWqMmXHZma1547qe1Ir1MfLJr3jwo3V6hLBINVC5S43r1Lwu79zQm2GnNeNfH61l9OR57NxX4HUskbBSpXI3s8FmtsbM1pvZ7RW8f62ZfW1mS8xsrpllBD6qhJPEuGj+9ZvTuffCLizZspchD37GnLV5XscSCRuVlruZ+YCJwBAgAxhZQXm/6Jzr7Jw7HbgXeCDgSSXsmBkX+1vyzg29SE6MY8zT87ln+moKi0u8jiYS8qpy5p4JrHfObXTOHQWmAsPLD3DOlV9nngBoElWqrF2jJN4e34tRWak8Pmc
Dv3niS3K+04bcIiejKuXeHNhS7jin7LVjmNn1ZraB0jP3CRV9kJmNM7NsM8vOy9M/weUH8TE+/n5+Zx4Z1ZV1Ow7wqwc/Y8byXK9jiYSsgF1Qdc5NdM61BX4P/Oknxkxyzvmdc/6UlJRAfbWEkXO6NOO9CX1onZzAtS8s5P+9vZyCwmKvY4mEnKqU+1agZbnjFmWv/ZSpwHknE0oiW2rD2rx2bU/G9k7j2S83ccGj2ghE5HhVpdwXAOlmlmZmscAIYFr5AWaWXu5wKLAucBElEsVGR/GnczJ4aoyfbfmHtRGIyHGqtNydc0XAeGAmsAp4xTm3wszuMrNhZcPGm9kKM1sC3AyMqbbEElEGntqY6Tf2oVOz0o1AbtNGICJVYl49a9vv97vs7GxPvltCT1FxCQ9+vI5HZq2nbUoij4zqyilN6ngdS6TGmdlC55y/snFaoSohIdoXxS2/7MALV2WRf7iQ4Y9oIxCRn6Nyl5DSq10y70/oQ2Zag9KNQF7SRiAiFVG5S8hJSYrj2Ssy+d3gDsxYnss5D83VRiAiP6Jyl5AUFWVc178dr1zTg+ISx4WPf8HkzzZqmkakjMpdQlq3Vg14b0JvBnRoxF/fW8VVz2azQ0+YFFG5S+irVzuWJy7txp/PzWDuul0MuH82j3yyTitbJaKp3CUsmBmX90rjw5v70jc9hfs/WMvAf87h3WXbNFUjEUnlLmGlVcMEHr+0Gy9d3YM6tWIY/+JiLnr8S5bl6IKrRBaVu4SlM9s25N0benPPBZ35dvdBhj3yObe+ulTz8RIxVO4StnxRxojMVGbd2p9r+rVh2pJtDLh/NhNnrdd8vIQ9lbuEvaT4GO4Yciof3tyXPunJ3DdzjebjJeyp3CVitGqYwBOX+nnx6iyS4qMZ/+JiLn7iS77Oyfc6mkjAqdwl4vRsm8x7E/rwPxd0ZmPeQYZNnMttry5lp+bjJYyo3CUi+aKMkZmpzLqtP+P6tOGtJVs1Hy9hReUuEa1OfAx3/OpUPrypH73alc7H/+KBObz/9XbNx0tIU7mLAK2TE5h0mZ8Xx2aRGBfNdVMW8ZtJX7F8q+bjJTSp3EXK6dmudD7+b+d3Yv3OA5z7yFx+99pSdu7XfLyEFpW7yI/4ooxLslox+7b+jO2dxpuLtzLgPs3HS2hRuYv8hDrxMfxxaAYf3NSPnmXz8YP+NYfpmo+XEKByF6lEWnICT17mZ8rYLGrHRPNbzcdLCFC5i1RRr3bJvDehN38974f5+N+/tkzz8RKUVO4ixyHaF8XoHq2YdWvpfPwbi3M46/45PDZ7g+bjJaio3EVOQN1aP8zH92jTkH/MWM2gf81hxnLNx0twULmLnIS05AQmj/Hz/FWZ1I6J5toXFjFC8/ESBFTuIgHQJz2F9yb05u7zOrF2x37Nx4vnVO4iARLti+LSHq2YfdsAruyVxuuLchhw32wena3746XmqdxFAqxurRjuPCeDD27qy5ltk7l3Runzat5bpvl4qTkqd5Fq0iYlkcljSu+PT4yL5voXF+n58VJjVO4i1axX2fNq/n5+6fPjz31kLre8ov1cpXqp3EVqgC/KGJVV+vz4a/q14Z2l2+h/32we+ngdh49qPl4CT+UuUoPqlO3n+tHN/ejfIYUHPlzLwH/O5u0lWzUfLwFVpXI3s8FmtsbM1pvZ7RW8f7OZrTSzZWb2sZm1CnxUkfCR2rA2j43uxtRxPaifEMuNU5dwwWNfsGjzd15HkzBRabmbmQ+YCAwBMoCRZpbxo2GLAb9zrgvwGnBvoIOKhKMebRoybXxv7r2wCznfHeaCR7/gxqmL2bb3sNfRJMRV5cw9E1jvnNvonDsKTAWGlx/gnJvlnDtUdvgV0CKwMUXCly/KuNjfklm39mf8gHZMX57LWf+czQMfruXQ0SKv40mIqkq5Nwe2lDvOKXvtp1wFTK/oDTMbZ2bZZpadl5dX9ZQiESAxLppbz+7AJ7f04xenNuahj9cx4P7ZvL4wh5I
SzcfL8QnoBVUzGw34gfsqet85N8k553fO+VNSUgL51SJho0X92jwy6gxeu/ZMmtSJ55ZXl3Leo5+T/e0er6NJCKlKuW8FWpY7blH22jHM7BfAH4FhzrkjgYknErn8rRvw5nW9+NdvTmPnviNc+PiXXP/iIrbsOVT5H5aIV5VyXwCkm1mamcUCI4Bp5QeYWVfgCUqLfWfgY4pEpqgo4/yuLfjk1n7cODCdj1ftYOADc7h3xmoOHNF8vPy0SsvdOVcEjAdmAquAV5xzK8zsLjMbVjbsPiAReNXMlpjZtJ/4OBE5AbVjo7lpUHtm3dqfoZ2b8ujsDfS/bzavLNhCsebjpQLm1cIJv9/vsrOzPflukVC3ZMte7npnBYs276VjszrceU4GPdo09DqW1AAzW+ic81c2TitURULQ6S3r8fpve/LQyK7sPVTIiElfce3zC9m0+6DX0SRIqNxFQpSZMey0Znx8Sz9u/WV7Pl2Xx6AHPuV/3l/FvoJCr+OJx1TuIiEuPsbH+LPSmXVrf4af3oxJn21kwH2zmTJvk+bjI5jm3EXCzNc5+dz97krmf7uHtikJXHZma84/ozl14mO8jiYBUNU5d5W7SBhyzjF9eS6Pz9nAspx8asX4GHZaMy7pkUqXFvW8jicnQeUuIkDpmfyL8zfx1uJtHC4splPzOlyS1YphpzUjIS7a63hynFTuInKMfQWFvL14K1PmbWZ17n4S46I5v2tzRmWlcmrTOl7HkypSuYtIhZxzLNr8HVPmbebdZds5WlTCGan1uCSrFUO7NCU+xud1RPkZKncRqdTeQ0d5bWEOL87bzMZdB6lbK4Zfn9GCUVmptGuU6HU8qYDKXUSqzDnHVxv3MGXeJmauyKWw2NGjTQMuyWrF2R2bEButu6aDRVXLXVdTRAQz48y2DTmzbUPy9h/h1YVbeGn+Zm54aTENE2K5yN+SUZmppDas7XVUqSKduYtIhUpKHJ+t38WUrzbx8eqdFJc4+rZP4ZKsVAae0ohon87mvaBpGREJmNz8Al5esIWpCzazPb+AxnXi+E33VEZ0b0mzerW8jhdRVO4iEnBFxSXMWpPHlHmbmLM2DwPOOqUxl2Sl0rd9Cr4o8zpi2NOcu4gEXLQvikEZjRmU0Zgtew4xdcFmXl6Qw0erdtC8Xi1GZaVykb8FjZLivY4a8XTmLiIn5WhRCR+u3MGUeZv4YsNuoqOMszs2YVRWKme2aUiUzuYDSmfuIlIjYqOjGNqlKUO7NGVj3gFemr+ZVxfm8N7X20lLTmBUZioXnNGcholxXkeNKDpzF5GAKygsZvry7Uz5ajPZm77DF2VkpTVgSKcm/LJjExrX0bTNidIFVREJCmty9/PO0m1MX76dDXkHMYMzUuszpFMTBndqQov6unf+eKjcRSTorNuxn+nLc5m+PJdV2/cB0Ll5XQZ3asKQTk1ok6JHHlRG5S4iQW3T7oPfF/3SLXsB6NA4qbToOzehQ+MkzHQx9sdU7iISMrbtPczMFaVFv+DbPTgHackJDO7UhMEdm9ClRV0VfRmVu4iEpLz9R/hgZS4zlufyxYbdFJc4mterxdkdS8/ou6XWj+jbK1XuIhLy9h46yocrdzBjeS6frdvF0eISUpLiOLtjY4Z0akpWWoOIe8aNyl1Ewsr+gkI+Wb2TGctzmb0mj8OFxdSvHcOgjNKi79muIXHR4b/RiMpdRMLW4aPFzFm7k+nLc/lk1U72HykiKS6agac2YnCnpvRrn0Kt2PAseq1QFZGwVSvWx+BOTRncqSlHior5fP0uZizP5YOVO3hryTZqxfgYcEoKgzs1ZUCHFJLiY7yOXON05i4iYaOouIR53+xh+vLtzFyxg7z9R4iNjqJPu2QGd2rCwFMb0yAh1uuYJ0XTMiIS0YpLSjcCn/51LjOWb2dbfgFmpYum+qan0Ld9Cl1T6xETYhdkVe4iImWcc3y9NZ/Za/KYszaPJVv2UlziSIqL5sy2DenbPoV+7VNo2SD
4H4WgchcR+Qn5hwv5Yv0uPl2Xx6drd7F172GgdOFUv/Yp9G2fTI82DakdG3yXJQNa7mY2GHgQ8AGTnXP3/Oj9vsD/Al2AEc651yr7TJW7iAQD5xwb8g7y6do8Pl2Xx1cbd1NQWEKsLwp/6/r0bZ9C3/QUTm0aHI9DCFi5m5kPWAsMAnKABcBI59zKcmNaA3WAW4FpKncRCVUFhcVkf/td2Vl9Hqtz9wOQkhRHn/Rk+rVPoXe7ZM+eTx/IWyEzgfXOuY1lHzwVGA58X+7OuW/L3is5obQiIkEiPsZH7/Rkeqcn84dfnUpufsH3Rf/J6p28sWgrZtCpWV36tk+mX/tGQXlhtirl3hzYUu44B8iqnjgiIsGlSd14Lva35GJ/S4pLSi/Mfrq2tOwfn7ORibM2kBgXTc8guzBbo1cLzGwcMA4gNTW1Jr9aROSk+aKM01vW4/SW9ZgwMJ38w4V8uWEXc9bu4tO1eXywcgdQemG2b3oyfdun0KNNQxLiav7CbFW+cSvQstxxi7LXjptzbhIwCUrn3E/kM0REgkXdWjHfr5R1zrFx18Hvz+pfyc7h2S83EeMz/K0alF6YbZ9MRtM6NXJhtirlvgBIN7M0Skt9BDCqWlOJiIQYM6NtSiJtUxK5olcaR4pKL8zOKSv7f8xYzT9mQHJiHHeecyrDT29erXkqLXfnXJGZjQdmUnor5NPOuRVmdheQ7ZybZmbdgTeB+sC5ZvYX51zHak0uIhLE4qJ99GqXTK92pRdmd+wrKLvdcleNbBCuRUwiIiGkqrdCBte9OyIiEhAqdxGRMKRyFxEJQyp3EZEwpHIXEQlDKncRkTCkchcRCUMqdxGRMOTZIiYzywM2neAfTwZ2BTBOqNPvcSz9Hj/Qb3GscPg9WjnnUiob5Fm5nwwzy67KCq1Iod/jWPo9fqDf4liR9HtoWkZEJAyp3EVEwlColvskrwMEGf0ex9Lv8QP9FseKmN8jJOfcRUTk54XqmbuIiPyMkCt3MxtsZmvMbL2Z3e51Hq+YWUszm2VmK81shZnd6HWmYGBmPjNbbGbvep3Fa2ZWz8xeM7PVZrbKzM70OpNXzOymsr8ny83sJTOr/t0yPBZS5W5mPmAiMATIAEaaWYa3qTxTBNzinMsAegDXR/BvUd6NwCqvQwSJB4EZzrlTgNOI0N/FzJoDEwC/c64TpTvKjfA2VfULqXIHMoH1zrmNzrmjwFRguMeZPOGc2+6cW1T2v/dT+he3ejdlDHJm1gIYCkz2OovXzKwu0BcSh32GAAABuUlEQVR4CsA5d9Q5t9fbVJ6KBmqZWTRQG9jmcZ5qF2rl3hzYUu44hwgvNAAzaw10BeZ5m8Rz/wv8DijxOkgQSAPygH+XTVNNNrMEr0N5wTm3Fbgf2AxsB/Kdcx94m6r6hVq5y4+YWSLwOvBfzrl9XufxipmdA+x0zi30OkuQiAbOAB5zznUFDgIReY3KzOpT+i/8NKAZkGBmo71NVf1Crdy3Ai3LHbcoey0imVkMpcU+xTn3htd5PNYLGGZm31I6XXeWmb3gbSRP5QA5zrn//GvuNUrLPhL9AvjGOZfnnCsE3gB6epyp2oVauS8A0s0szcxiKb0oMs3jTJ4wM6N0PnWVc+4Br/N4zTl3h3OuhXOuNaX/v/jEORf2Z2c/xTmXC2wxsw5lLw0EVnoYyUubgR5mVrvs781AIuDicrTXAY6Hc67IzMYDMym94v20c26Fx7G80gu4FPjazJaUvfYH59z7HmaS4HIDMKXsRGgjcIXHeTzhnJtnZq8Biyi9y2wxEbBSVStURUTCUKhNy4iISBWo3EVEwpDKXUQkDKncRUTCkMpdRCQMqdxFRMKQyl1EJAyp3EVEwtD/ByDn55V2MMDbAAAAAElFTkSuQmCC\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": 
"light" }, "output_type": "display_data" } ], "source": [
    "# learning curve: mean training loss of each epoch\n",
    "plt.plot(tr_loss_hist, label='train')" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [
    "# predicted class index for every training word\n",
    "yhat = char_rnn.predict(sess=sess, X_length=X_length, X_indices=X_indices)" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "training acc: 100.00%\n" ] } ], "source": [
    "# labels are one-hot, so argmax recovers the class index\n",
    "y_true = np.argmax(y, axis=-1)\n",
    "tr_acc = np.mean(yhat == y_true)\n",
    "print('training acc: {:.2%}'.format(tr_acc))" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }