{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CS 20 : TensorFlow for Deep Learning Research\n", "## Lecture 11 : Recurrent Neural Networks\n", "A simple example of many-to-one classification (word sentiment classification) with a recurrent neural network.\n", "\n", "### Many to One Classification by RNN\n", "- Creating the **data pipeline** with `tf.data`\n", "- Preprocessing variable-length input sequences by padding them to a fixed length with `tf.keras.preprocessing.sequence.pad_sequences`\n", "- Using an embedding lookup (`keras.layers.Embedding`, the Keras counterpart of `tf.nn.embedding_lookup`) to get the vector of each token (e.g. word, character)\n", "- Defining the model as a **class** (subclassing `keras.Model`)\n", "- References\n", "    - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n", "    - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.12.0\n" ] } ], "source": [ "from __future__ import absolute_import, division, print_function\n", "import os, sys\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", "from tensorflow import keras\n", "import string\n", "%matplotlib inline\n", "\n", "print(tf.__version__)\n", "\n", "tf.enable_eager_execution()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare example data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n", "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ 
{ "name": "stdout", "output_type": "stream", "text": [ "['<pad>', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ', '*']\n" ] } ], "source": [ "# Character quantization: '<pad>' is placed first so that it gets index 0,\n", "# the index the embedding layer masks (mask_zero=True).\n", "char_space = ['<pad>'] + list(string.ascii_lowercase + ' ' + '*')\n", "print(char_space)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'<pad>': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, ' ': 27, '*': 28}\n" ] } ], "source": [ "# Map every character of the vocabulary to its integer index.\n", "char2idx = {char: idx for idx, char in enumerate(char_space)}\n", "print(char2idx)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### padding example data" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[4, 3, 7, 7, 9, 7]\n", "(6, 10)\n" ] } ], "source": [ "# Encode each word as a list of character indices. The raw `words` list is left\n", "# untouched so this cell can be re-run safely, and indexing (instead of .get)\n", "# fails fast on a character that is missing from the vocabulary.\n", "word_indices = [[char2idx[char] for char in word] for word in words]\n", "\n", "max_length = 10\n", "X_length = [len(indices) for indices in word_indices]\n", "# Pad / truncate at the end of every sequence so that all rows have max_length entries.\n", "X_indices = pad_sequences(sequences=word_indices, maxlen=max_length, padding='post', truncating='post')\n", "\n", "print(X_length)\n", "print(np.shape(X_indices))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define CharRNN class" ] }, 
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "class CharRNN(keras.Model):\n", "    \"\"\"Character-level many-to-one RNN classifier.\n", "\n", "    Tokens are looked up in a frozen identity (one-hot) embedding and fed to a\n", "    SimpleRNN; the final hidden state is projected to one score per class.\n", "\n", "    Args:\n", "        num_classes: number of units of the output Dense layer.\n", "        hidden_dim: number of units of the SimpleRNN layer.\n", "        max_length: passed to the Embedding layer as input_length.\n", "        dic: token-to-index mapping; only its size (len) is used.\n", "    \"\"\"\n", "\n", "    def __init__(self, num_classes, hidden_dim, max_length, dic):\n", "        super(CharRNN, self).__init__()\n", "\n", "        vocab_size = len(dic)\n", "        # Identity matrix as fixed weights: index i maps to the i-th one-hot vector.\n", "        self.look_up = keras.layers.Embedding(input_dim=vocab_size, output_dim=vocab_size,\n", "                                              trainable=False, mask_zero=True,\n", "                                              input_length=max_length,\n", "                                              embeddings_initializer=keras.initializers.Constant(np.eye(vocab_size)))\n", "        self.rnn_cell = keras.layers.SimpleRNN(units=hidden_dim, return_sequences=True,\n", "                                               return_state=True)\n", "        self.dense = keras.layers.Dense(units=num_classes)\n", "\n", "    def call(self, inputs):\n", "        \"\"\"Return the class scores (used as logits by loss_fn) for a batch of index sequences.\"\"\"\n", "        token_representation = self.look_up(inputs)\n", "        # Keep only the final hidden state; the per-step outputs are discarded.\n", "        _, final_h = self.rnn_cell(token_representation)\n", "        score = self.dense(final_h)\n", "        return score" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a model of CharRNN" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3\n" ] } ], "source": [ "# hyper-parameters\n", "lr = .003\n", "epochs = 10\n", "batch_size = 2\n", "# Number of full mini-batches per epoch.\n", "total_step = len(X_indices) // batch_size\n", "print(total_step)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<BatchDataset shapes: ((?, 10), (?, 2)), types: (tf.int32, tf.float32)>\n" ] } ], "source": [ "## create data pipeline with tf.data: slice into examples, shuffle, then batch\n", "tr_dataset = (tf.data.Dataset.from_tensor_slices((X_indices, y))\n", "              .shuffle(buffer_size=20)\n", "              .batch(batch_size=batch_size))\n", "print(tr_dataset)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "char_rnn = CharRNN(num_classes=2, hidden_dim=16, dic=char2idx, max_length=max_length)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train model" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def loss_fn(model, x, y):\n", "    \"\"\"Softmax cross-entropy between the one-hot labels `y` and the logits `model(x)`.\"\"\"\n", "    return tf.losses.softmax_cross_entropy(onehot_labels=y, logits=model(x))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "opt = tf.train.AdamOptimizer(learning_rate=lr)" ] }, { "cell_type": "code", "execution_count": 12, 
"metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch : 1, tr_loss : 0.739\n", "epoch : 2, tr_loss : 0.649\n", "epoch : 3, tr_loss : 0.546\n", "epoch : 4, tr_loss : 0.444\n", "epoch : 5, tr_loss : 0.367\n", "epoch : 6, tr_loss : 0.272\n", "epoch : 7, tr_loss : 0.215\n", "epoch : 8, tr_loss : 0.156\n", "epoch : 9, tr_loss : 0.113\n", "epoch : 10, tr_loss : 0.088\n" ] } ], "source": [ "# Mean training loss of every epoch, as plain numbers ready for plotting.\n", "tr_loss_hist = []\n", "\n", "for epoch in range(epochs):\n", "    epoch_loss_sum = 0.\n", "    tr_step = 0\n", "\n", "    for x_mb, y_mb in tr_dataset:\n", "        with tf.GradientTape() as tape:\n", "            tr_loss = loss_fn(char_rnn, x=x_mb, y=y_mb)\n", "        # Optimize the trainable weights only; the embedding inside CharRNN is\n", "        # created with trainable=False and must stay fixed.\n", "        grads = tape.gradient(target=tr_loss, sources=char_rnn.trainable_variables)\n", "        opt.apply_gradients(grads_and_vars=zip(grads, char_rnn.trainable_variables))\n", "        # .numpy() converts the eager scalar tensor into a plain number.\n", "        epoch_loss_sum += tr_loss.numpy()\n", "        tr_step += 1\n", "\n", "    # Average of the mini-batch losses seen in this epoch.\n", "    avg_tr_loss = epoch_loss_sum / tr_step\n", "    tr_loss_hist.append(avg_tr_loss)\n", "\n", "    print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[<matplotlib.lines.Line2D at 0x7f3de0704b38>]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xd4lGW+//H3N5OEkBAIJdQkhI4gfYgURV11Qd0FC9WKDURYWbfqlnN+xz17zuqeZW2RpmBbiYgNd+2LqGCAhKJ0CC0JRUKHUJLA/fsj0Q0smhEmeTKTz+u6uC6fZ25mPo7mw8NT7tucc4iISHiJ8DqAiIgEn8pdRCQMqdxFRMKQyl1EJAyp3EVEwpDKXUQkDKncRUTCkMpdRCQMqdxFRMJQpFcf3KhRI5eamurVx4uIhKSlS5fucc4lVjTOs3JPTU0lOzvbq48XEQlJZrYtkHE6LSMiEoZU7iIiYUjlLiIShlTuIiJhSOUuIhKGVO4iImFI5S4iEoZCrty37ink0ffWUXLylNdRRESqrZAr9/dX7+Lp+Zu4feYS9hcWeR1HRKRaCrlyH3tpGx4d2pWsLfsZnL6AdbsOeR1JRKTaCblyBxjuTyZjbB9OFJ/ihqc/592VO72OJCJSrYRkuQP0TKnP2z+5mPZN4hn3t2VM+mA9p045r2OJiFQLIVvuAE3qxpAxpg/DeiXxxLwcxryYzeHjxV7HEhHxXEiXO0BMlI9Hh3blvwZ35uP1BVz/9OdsLjjidSwREU+FfLkDmBm390vlxbvS2HvkBEPSFzJ//W6vY4mIeCYsyv1r/do0Yu6Ei2mRUJs7nstiyiebcE7n4UWk5gmo3M1skJmtN7McM3vwLK//1cxWlP3aYGYHgh81MMkNYnn9vn5c06UZf3p3HRMzVnCs6KRXcUREPFHhSkxm5gPSgauAfCDLzOY659Z8PcY590C58T8BelRC1oDFRkfy1KgedGpWl//7YD2bCo4w9dZeJNWP9TKWiEiVCeTIPQ3Icc5tds4VARnAkO8YPwqYFYxw58PMGH95W5693U/u3qMMfmohizbv9TqWiEiVCKTcWwB55bbzy/b9GzNrCbQC5p1/tOD4QccmvDmhPwmxUdzyzGJezNyq8/AiEvaCfUF1JDDHOXfWk9xmNsbMss0su6CgIMgf/e3aJNbhzfH9GdA+kd+/tZqHXl/JiRKdhxeR8BVIuW8HksttJ5XtO5uRfMcpGefcNOec3znnT0xMDDxlENSNiWL6bX7GX96GjKw8bpq+mN2HjldpBhGRqhJIuWcB7cyslZlFU1rgc88cZGYdgfpAZnAjBo8vwvjlwI6k39STNTsO8eOnFrAiz7Mbe0REKk2F5e6cKwEmAO8Da4HZzrnVZvawmQ0uN3QkkOFC4IT2tV2b8dq4fkT5Ihg+NZPXluZ7HUlEJKjMqy72+/0uOzvbk8/+2r7CIsb/bRmZm/dyZ/9W/OaajkT6wuq5LhEJM2a21Dnnr2hcjW6yBnHRvHBXGqP7pTJj4RYtACIiYaNGlztAlC+C/ze4sxYAEZGwUuPL/WvD/cm8ogVARCRMqNzL6VG2AEiHploARERCm8r9DF8vADLc//UCIEu1AIiIhByV+1nUivTxyI1fLwCyWwuAiEjIUbl/i68XAHnprovYV1ikBUBEJKSo3CvQt01D3hrfn+T6sVoARERChso9AMkNYnltXD+u1QIgIhIiVO4Bqh3t48lRPfjVoA68/eUOhk75nO0HjnkdS0TkrFTu34OZcd9lbZlxe29y9x1lyFMLWbNDDzyJSPWjcj8Hl3dszBv39SPaZ4yYlsmSLfu8jiQichqV+zlq2zieV8f1IzG+Frc+u5h5677yOpKIyDdU7uehRUJtXh3blw5N47nnhaW8sVxTB4tI9aByP08N69Ti5Xv6cFGrBjzwyhfMXLjF60giIir3YKhTK5IZo3szsHMT/uvtNUz6YL3uhRcRT6n
cgyQmykf6TT0Z4U/miXk5/MdbqzXpmIh4JtLrAOEk0hfBn27sQkJcFFM/2cyBY8X8ZVg3oiP1Z6iIVC2Ve5CZGQ9dfQH1Y6P507vrOHismCm39CQ2Wl+1iFQdHVJWknsvbcMjN3ZhwcYCbnlmMQeOavk+Eak6KvdKNKJ3Ck/f3ItV2w8xYuoivjp03OtIIlJDqNwr2aALm/LcHb3J33+UGyd/ztY9hV5HEpEaIKByN7NBZrbezHLM7MFvGTPczNaY2Wozezm4MUNbv7aNmDWmD0eLTjJ0Siardxz0OpKIhLkKy93MfEA6cDXQCRhlZp3OGNMOeAjo75zrDPy0ErKGtK5JCcwe25donzFy6iLNRyMilSqQI/c0IMc5t9k5VwRkAEPOGHMPkO6c2w/gnNOSRWfRtnGd0vlo6pbOR/PPtZqPRkQqRyDl3gLIK7edX7avvPZAezNbaGaLzGzQ2d7IzMaYWbaZZRcUFJxb4hBXfj6aMS8u5fVlmo9GRIIvWBdUI4F2wGXAKGC6mSWcOcg5N80553fO+RMTE4P00aGn/Hw0P5v9BTMWaD4aEQmuQMp9O5BcbjupbF95+cBc51yxc24LsIHSspdvUX4+mof/rvloRCS4Ain3LKCdmbUys2hgJDD3jDFvUnrUjpk1ovQ0zeYg5gxLZ85H8/u3VnFS89GISBBU+Ey8c67EzCYA7wM+YIZzbrWZPQxkO+fmlr32QzNbA5wEfumc21uZwcPFv81Hc7SYScO7az4aETkv5tWpAL/f77Kzsz357Opq6ieb+N931zGgfaLmoxGRszKzpc45f0XjdHhYjYy9tA2P3tiVBRsLuFnz0YjIeVC5VzPDeyfz9M29WK35aETkPKjcqyHNRyMi50vlXk2dPh/N55qPRkS+F5V7Nfav+WgiGDl1EYs36wYkEQmMyr2aa9u4DnPG9aNx3VrcNmMJH63RfDQiUjGVewhonlCbV+/tR4em8Yx9SfPRiEjFVO4hokFctOajEZGAqdxDyNfz0Qzq3JSH/76Gv2g+GhH5Fir3EBMT5SP95tL5aJ6cl8Pv3tR8NCLy7/R8ewjyRdjp89EcK+avmo9GRMpRuYcoM+Ohqy+gQWw0//vuOg4dK2bKLb2Iq6X/pCKi0zIhb+ylbXh0aFcW5uzh5mcWs79Q89GIiMo9LAz3JzP5ll6s2XmIYVMz2XnwmNeRRMRjKvcwMbBzU56/I41dB48zdHImmwqOeB1JRDykcg8jfds0JGNMH06UnGTYlEy+zD/gdSQR8YjKPcxc2KIer97bj9hoH6OmLWJhzh6vI4mIB1TuYahVozheG9ePpPqx3DEzi3dX7vQ6kohUMZV7mGpSN4bZY/vSJake419exqwluV5HEpEqpHIPY/Vio3jprosY0D6Rh15fSfrHOZquQKSGULmHudrRPqbf5ue67s358/vr+e9/rOWUpisQCXsBlbuZDTKz9WaWY2YPnuX10WZWYGYryn7dHfyocq6ifBFMGt6d0f1SeXbBFn7x6hcUnzzldSwRqUQVPqtuZj4gHbgKyAeyzGyuc27NGUNfcc5NqISMEgQREcZ//rgTDeOi+cuHGzh4rJinbupJ7Wif19FEpBIEcuSeBuQ45zY754qADGBI5caSymBm/OSKdvzhuguZt343t81YzMFjxV7HEpFKEEi5twDyym3nl+07041m9qWZzTGz5KCkk0pxa5+WPDmqByvyDjBiaia7Dx33OpKIBFmwLqi+DaQ657oCHwLPn22QmY0xs2wzyy4oKAjSR8u5+FHX5swY3ZvcfUcZOiWTbXsLvY4kIkEUSLlvB8ofiSeV7fuGc26vc+5E2eYzQK+zvZFzbppzzu+c8ycmJp5LXgmiS9ol8vI9fTh0vJgbJ2eyZschryOJSJAEUu5ZQDsza2Vm0cBIYG75AWbWrNzmYGBt8CJKZeqenMCce/sS5TNGTMtkyZZ9XkcSkSCosNydcyXABOB9Skt7tnNutZk9bGaDy4b
db2arzewL4H5gdGUFluBr2zieOeP6kRhfi1ufXcxHa77yOpKInCfz6olFv9/vsrOzPflsObu9R05wx3NZrN5xiEdu7MrQXkleRxKRM5jZUuecv6JxekJVvtGwTi1evqcPfVo34BevfsEzn232OpKInCOVu5ymTq1IZozuzTVdmvLf/1jLI++t03w0IiFIqynLv6kV6ePJUT1JiF3F5Pmb2F9YxB+v74IvwryOJiIBUrnLWfkijD9edyENYqN56uMcDhwt5rGR3YmJ0nQFIqFAp2XkW5kZvxjYgd//qBPvrd7FHTOzOHxc0xWIhAKVu1TorotbMWl4N5Zs3cdN0xez58iJin+TiHhK5S4BuaFnEtNv68WGrw4zfEom+fuPeh1JRL6Dyl0C9oOOTXjp7ovYc+QEQydnsuGrw15HEpFvoXKX76V3agNeGduXk84xbEomy3L3ex1JRM5C5S7f2wXN6vLavf1IiI3i5umL+WSDZvgUqW5U7nJOUhrG8uq9fUltFMfdz2cx94sdXkcSkXJU7nLOGsfH8MrYPvRIqc/EjOU88c+NlGhtVpFqQeUu56VuTBQv3JnG4G7NmfThBoZOyWRzwRGvY4nUeCp3OW8xUT4eH9mDJ0f1YMueQq554jNeyNyqOWlEPKRyl6D5cbfmfPDAAPq0bsh/vLWa22YsYefBY17HEqmRVO4SVE3qxjBzdG/+eP2FZG/dz8C/fspbK7brKF6kiqncJejMjJsvasm7Ey+hXZN4JmasYMLLy9lfWOR1NJEaQ+UulSa1URyzx/bl14M68sGaXfzwsU+Zt05L+IlUBZW7VCpfhDHusja8Nf5iGsZFc+dz2Tz0+pccOVHidTSRsKZylyrRqXld3prQn3svbUNGVh5XP/4pS7bs8zqWSNhSuUuVqRXp48GrOzJ7bF8MY8S0TP73nbWcKDnpdTSRsKNylyrXO7UB7068hFFpKUz9dDODn1zI6h0HvY4lElYCKnczG2Rm680sx8we/I5xN5qZMzN/8CJKOIqrFcn/XN+FmaN7s+9oEdelLyT94xxNXyASJBWWu5n5gHTgaqATMMrMOp1lXDwwEVgc7JASvi7v2JgPfjqAH3Zqyp/fX8/wqZls3VPodSyRkBfIkXsakOOc2+ycKwIygCFnGfcH4BHgeBDzSQ1QPy6ap27qweMju5Oz+whXP/4ZLy3apgefRM5DIOXeAsgrt51ftu8bZtYTSHbO/eO73sjMxphZtpllFxRoDnD5FzNjSPcWfPDApfhT6/O7N1dx+8wsdh3UsYLIuTjvC6pmFgFMAn5e0Vjn3DTnnN85509MTDzfj5Yw1LReDC/cmcYfhnRmyZa9DHzsU80VL3IOAin37UByue2ksn1fiwcuBOab2VagDzBXF1XlXJkZt/ZN5d2JA2idGMf9s5Yz4eVlmr5A5HsIpNyzgHZm1srMooGRwNyvX3TOHXTONXLOpTrnUoFFwGDnXHalJJYao1WjOF4d25dfDuzAe6t2MfCxT/l4/W6vY4mEhArL3TlXAkwA3gfWArOdc6vN7GEzG1zZAaVmi/RFMP7ytrw5vj8JsVHcMTOL37yxkkJNXyDyncyrOxL8fr/LztbBvQTuePFJJn24gemfbSa5fiyThnfDn9rA61giVcrMljrnKjztrSdUJWTERPn4zTUXkHFPH045x/CpmTzy3jpNXyByFip3CTkXtW7Iez8dwHB/MpPnb2LIUwtZu/OQ17FEqhWVu4SkOrUi+dONXXn2dj97jhQx+KkFTJ6/iZOn9OCTCKjcJcRdcUETPnhgAFde0IRH3lvHiKmZ7DigdVtFVO4S8hrERfP0zT3564hurNt1mOvSF7IyX7NMSs2mcpewYGZc3yOJ18b1I8oXwfCpmby/epfXsUQ8o3KXsNKhaTxvju9Ph6bx3PvSUqZ9ukkTkEmNpHKXsJMYX4uMMX245sJm/M876/jNGysp1jzxUsNEeh1ApDLERPl4clQPUhvFkv7xJvL
2HSP95p7Uqx3ldTSRKqEjdwlbERHGLwd25M9Du7J4y15unPw5efuOeh1LpEqo3CXsDfMn88KdF1Fw+ATXpS9k6bZ9XkcSqXQqd6kR+rZpyBv39SM+JpJR0xdrjngJeyp3qTFaJ9bh9fv60z0pgftnLeeJf27UnTQStlTuUqM0iIvmxbvTuKFHCyZ9uIGfz/5CE49JWNLdMlLj1Ir08Zfh3WjVKI6/fLiB/P3HmHJrLxrERXsdTSRodOQuNZKZ8ZMr2vHEqB6syD/A9U8vZFPBEa9jiQSNyl1qtMHdmjPrnj4cOV7CDU9/TuamvV5HEgkKlbvUeL1a1ufN8f1JjK/FbTMWMzs7z+tIIudN5S4CJDeI5bVx/bioVUN+NedLHn1vHac0N7yEMJW7SJl6taOYeUdvRqWl8PT8TUyYtYzjxbqTRkKTyl2knChfBP9z/YX89poLeHfVLkZMW0TB4RNexxL53gIqdzMbZGbrzSzHzB48y+v3mtlKM1thZgvMrFPwo4pUDTPjngGtmXJLLzaULf6xftdhr2OJfC8VlruZ+YB04GqgEzDqLOX9snOui3OuO/AoMCnoSUWq2MDOTZk9ti/FJ09x4+TP+WRDgdeRRAIWyJF7GpDjnNvsnCsCMoAh5Qc458ovPR8H6EqUhIUuSfV4a0J/khvEcudzWby4aJvXkUQCEki5twDK3xuWX7bvNGY23sw2UXrkfn9w4ol4r1m92rx6b18ubZ/I799cxcNvr+Gk7qSRai5oF1Sdc+nOuTbAr4HfnW2MmY0xs2wzyy4o0F9xJXTUqRXJ9Nv83NE/lRkLtzD2xWwKT5R4HUvkWwVS7tuB5HLbSWX7vk0GcN3ZXnDOTXPO+Z1z/sTExMBTilQDvgjjP3/cmYeHdGbeut0Mm5LJzoPHvI4lclaBlHsW0M7MWplZNDASmFt+gJm1K7d5LbAxeBFFqpfb+qYyY3Rvcvcd5br0hazaftDrSCL/psJyd86VABOA94G1wGzn3Goze9jMBpcNm2Bmq81sBfAz4PZKSyxSDVzWoTFzxvUlMiKCYVMy+WD1Lq8jiZzGvFqswO/3u+zsbE8+WyRYdh8+zj3PZ/Pl9oP89poLuOviVpiZ17EkjJnZUuecv6JxekJV5Dw0jo8hY0xfBnVuyn//Yy2/fXMVxSdPeR1LROUucr5qR/tIv6kn4y5rw8uLc7nzuSwOHS/2OpbUcCp3kSCIiDB+Pagjj97YlcxNe7k+fSGfbCjQGq3iGZW7SBAN753MC3elcbz4FLfPWMKwKZl8nrPH61hSA6ncRYKsX5tGzPvFpfzhugvJ33+Mm55ZzMhpmSzZss/raFKD6G4ZkUp0vPgks5bk8vT8TRQcPsHFbRvxwFXt6dWyvtfRJEQFereMyl2kChwrOsnfFm9j8vxN7C0s4rIOiTxwZXu6JSd4HU1CjMpdpBoqPFHCC5nbmPrpJg4cLebKC5rwwFXt6Ny8ntfRJESo3EWqscPHi3lu4Vamf7aZQ8dLuPrCpvz0yvZ0aBrvdTSp5lTuIiHg4LFiZizYwowFWzhSVMK1XZrx0yvb07ZxHa+jSTWlchcJIQeOFjH9s83MXLiV48UnGdK9BROvaEdqozivo0k1o3IXCUF7j5xg2qebeT5zK8UnHTf0aMH9V7QjuUGs19GkmlC5i4Sw3YePM2X+Zl5avI1TpxzD/MlM+EFbWiTU9jqaeEzlLhIGvjp0nPSPc8hYUrrS5ci0ZO67rC1N68V4nEy8onIXCSPbDxwj/eMcZmflERFh3HxRCuMua0PjeJV8TaNyFwlDefuO8uS8jby2bDtRPuO2vqmMHdCahnVqeR1NqojKXSSMbd1TyBPzNvLm8u3ERPkY3S+Vey5pTf24aK+jSSVTuYvUADm7j/D4Pzfy9y93EBcdyZ0Xt+Kui1tRr3aU19GkkqjcRWqQ9bsO8/g/N/DOyl3UjYnknktaM7p/KvExKvlwo3IXqYHW7DjEXz/awIdrviIhNoq
xA9owul8qtaN9XkeTIFG5i9RgK/MPMunD9Xy8voDG8bW4/4p2jOidTJRPSziEuqAukG1mg8xsvZnlmNmDZ3n9Z2a2xsy+NLN/mlnLcwktIsHRJakeM+9IY869fWnZMJbfvbmKKyd9wlsrtnPqlJb+qwkqLHcz8wHpwNVAJ2CUmXU6Y9hywO+c6wrMAR4NdlAR+f78qQ2YPbYvM0f3pnaUj4kZK7j2yQV8vH631ncNc4EcuacBOc65zc65IiADGFJ+gHPuY+fc0bLNRUBScGOKyLkyMy7v2Jh37r+Ex0d2p/BECXfMzGLE1EVkb9XSf+EqkHJvAeSV284v2/dt7gLePZ9QIhJ8ERHGkO4t+Ohnpeu7btlbyNApmdz9fBbrdh3yOp4EWVCvrpjZLYAf+PO3vD7GzLLNLLugoCCYHy0iAYqOjODWPi355JeX8cuBHVi8ZR9XP/4ZD7yygty9Ryt+AwkJgZT7diC53HZS2b7TmNmVwG+Bwc65E2d7I+fcNOec3znnT0xMPJe8IhIksdGRjL+8LZ/96nLGDmjDOyt3csWk+fznW6soOHzWH2EJIRXeCmlmkcAG4ApKSz0LuMk5t7rcmB6UXkgd5JzbGMgH61ZIkepl18HjPDFvI69k5RHti+Cui1sx5tLW1NWDUNVKUO9zN7NrgMcAHzDDOfdHM3sYyHbOzTWzj4AuwM6y35LrnBv8Xe+pchepnrbsKWTShxt4+4sdJMRGcd9lbbitbyoxUXoQqjrQQ0wicl5WbT/In99fzycbCmhaN4aJV7ZjWK8kIvUglKeC+hCTiNQ8F7aox/N3ppExpg/NE2J46PWV/PCvn/KPL3fqQagQoHIXke/Up3VDXhvXj+m3+Yn0GeNfXsbg9AV8uqFAD0JVYyp3EamQmXFVpya8O3EAfxnWjf2Fxdw2Ywk3TV/M8tz9XseTs9A5dxH53k6UnGTW4lyenJfD3sIiftipCb8Y2IH2TeK9jhb2dEFVRCrdkRMlzFiwhWmfbuZoUQnX90jigavakVQ/1utoYUvlLiJVZl9hEZPn5/B85jZwcHOfFMZf3pZGWts16FTuIlLldhw4xuMfbeTVpXnUjvJx9yWtufuSVloRKohU7iLimZzdR5j04XreWbmL+rFR3H1Ja4b7k0mM15H8+VK5i4jnvsg7wP99sJ7PNu4hMsK48oImjEhLZkC7RHwR5nW8kKRyF5FqI2f3EV7JyuW1ZdvZV1hEi4TaDPMnMcyfTIuE2l7HCykqdxGpdk6UnOSjNbvJyMrls417MINL2ycysncKV1zQWGu8BkDlLiLVWt6+o7ySlcerS/P46tAJEuNrMbRXEiN7J9OyYZzX8aotlbuIhISSk6eYv76AjKxc5q3bzSkHfVs3ZGRaMgM7N9VslGdQuYtIyNl18DivZufxSnYe+fuPkRAbxQ09khiZlqynX8uo3EUkZJ065Vi4aQ8ZS/L4YM0uik86eqYkMDIthR91bUZsdKTXET2jcheRsLDnyAneWLadWVm5bC4oJL5WJIO7N2dk7xS6JNXzOl6VU7mLSFhxzpG1dT8ZS3L5x8qdnCg5RefmdRmZlsKQ7s1rzHKAKncRCVsHjxbz1hfbmbUkj7U7D1E7yse1XZsxKi2Znin1MQvfB6RU7iIS9pxzfJl/kIysXOau2EFh0UnaNa7DiN7J3NAziQZx0V5HDDqVu4jUKIUnSvj7lzuYtSSPFXkHiPZF8MPOTRiVlkLf1g2JCJPpDlTuIlJjrdt1iIwleby+LJ9Dx0to2TCW4f7S++ZbN4oL6aJXuYtIjXe8+CTvrdrFrCW5LN6yD4B6taPonpxAj5QEeqbUp1tyAvVqh87F2KCWu5kNAh4HfMAzzrk/nfH6AOAxoCsw0jk3p6L3VLmLSFXatreQxVv2sTx3P8tzD7D+q8N8XX9tG9ehZ0oCPVLq0zOlPm0b16m2s1YGrdzNzAdsAK4C8oEsYJRzbk25MalAXeAXwFyVu4hUd4ePF/N
l/kGW5+5nWe4BlufuZ//RYgDq1IqkW3I9eqbUp0dKAt2T61ebi7OBlnsgj3mlATnOuc1lb5wBDAG+KXfn3Nay106dU1oRkSoWHxNF/7aN6N+2EVB65822vUdZVnZkvzxvP0/P38TJU6UHwK0axdGj7HROj5T6dGwaT2Q1nsUykHJvAeSV284HLjqXDzOzMcAYgJSUlHN5CxGRSmFmpDaKI7VRHDf0TALgaFEJK/MPfnNk/+nGPby+fDsAtaN8dE2qR4+yo/seKQk0jo/x8l/hNFU6QYNzbhowDUpPy1TlZ4uIfF+x0ZFc1LohF7VuCJQe3efvP8byvAMs27af5XkHeHbBZopPltZZUv3aZeftS4/uOzWrS3SkN0f3gZT7diC53HZS2T4RkRrFzEhuEEtyg1gGd2sOlN6Rs3rHwdJTObkHyN66j7e/2AFAdGQEXVrUKzudU5+eLRNoVq9qVp4KpNyzgHZm1orSUh8J3FSpqUREQkRMlI9eLRvQq2WDb/btPHisrOxLz9+/sGgbzyzYAkDTujE8dE1HhnRvUam5Kix351yJmU0A3qf0VsgZzrnVZvYwkO2cm2tmvYE3gPrAj83sv5xznSs1uYhINdWsXm2adanNNV2aAVBUcoq1Ow99c2dOVZyb10NMIiIhJNBbIavvfTwiInLOVO4iImFI5S4iEoZU7iIiYUjlLiIShlTuIiJhSOUuIhKGVO4iImHIs4eYzKwA2HaOv70RsCeIcUKdvo/T6fv4F30XpwuH76Olcy6xokGelfv5MLPsQJ7Qqin0fZxO38e/6Ls4XU36PnRaRkQkDKncRUTCUKiW+zSvA1Qz+j5Op+/jX/RdnK7GfB8hec5dRES+W6geuYuIyHcIuXI3s0Fmtt7McszsQa/zeMXMks3sYzNbY2arzWyi15mqAzPzmdlyM/u711m8ZmYJZjbHzNaZ2Voz6+t1Jq+Y2QNlPyerzGyWmVWflawrSUiVu5n5gHTgaqATMMrMOnmbyjMlwM+dc52APsD4GvxdlDcRWOt1iGriceA951xHoBs19HsxsxbA/YDfOXchpSvKjfQ2VeULqXIH0oAc59xm51wRkAEM8TiTJ5xzO51zy8r++TClP7han/riAAABzUlEQVSVuyhjNWdmScC1wDNeZ/GamdUDBgDPAjjnipxzB7xN5alIoLaZRQKxwA6P81S6UCv3FkBeue18anihAZhZKtADWOxtEs89BvwKOOV1kGqgFVAAzCw7TfWMmcV5HcoLzrntwP8BucBO4KBz7gNvU1W+UCt3OYOZ1QFeA37qnDvkdR6vmNmPgN3OuaVeZ6kmIoGewGTnXA+gEKiR16jMrD6lf8NvBTQH4szsFm9TVb5QK/ftQHK57aSyfTWSmUVRWux/c8697nUej/UHBpvZVkpP1/3AzF7yNpKn8oF859zXf5ubQ2nZ10RXAluccwXOuWLgdaCfx5kqXaiVexbQzsxamVk0pRdF5nqcyRNmZpSeT13rnJvkdR6vOececs4lOedSKf3/Yp5zLuyPzr6Nc24XkGdmHcp2XQGs8TCSl3KBPmYWW/ZzcwU14OJypNcBvg/nXImZTQDep/SK9wzn3GqPY3mlP3ArsNLMVpTt+41z7h0PM0n18hPgb2UHQpuBOzzO4wnn3GIzmwMso/Qus+XUgCdV9YSqiEgYCrXTMiIiEgCVu4hIGFK5i4iEIZW7iEgYUrmLiIQhlbuISBhSuYuIhCGVu4hIGPr/1mKGnKp2jjkAAAAASUVORK5CYII=\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(tr_loss_hist, label = 'train')" ] }, { "cell_type": "code", 
"execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "acc : 100.00%\n" ] } ], "source": [ "# Accuracy on the training examples: predicted class (arg-max of the model\n", "# scores) compared with the arg-max of the one-hot labels.\n", "logits = char_rnn(inputs=tf.convert_to_tensor(X_indices))\n", "yhat = np.argmax(logits, axis=-1)\n", "y_true = np.argmax(y, axis=-1)\n", "print('acc : {:.2%}'.format(np.mean(yhat == y_true)))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }