{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CS 20 : TensorFlow for Deep Learning Research\n", "## Lecture 11 : Recurrent Neural Networks\n", "Simple example for Many to One Classification (word sentiment classification) by Bi-directional Gated Recurrent Unit.\n", "\n", "### Many to One Classification by Bi-directional GRU\n", "- Creating the **data pipeline** with `tf.data`\n", "- Preprocessing word sequences (variable input sequence length) using `padding technique` by `user function (pad_seq)`\n", "- Using `tf.nn.embedding_lookup` for getting vector of tokens (eg. word, character)\n", "- Creating the model as **Class**\n", "- Reference\n", " - https://github.com/golbin/TensorFlow-Tutorials/blob/master/10%20-%20RNN/02%20-%20Autocomplete.py\n", " - https://github.com/aisolab/TF_code_examples_for_Deep_learning/blob/master/Tutorial%20of%20implementing%20Sequence%20classification%20with%20RNN%20series.ipynb\n", " - https://pozalabs.github.io/blstm/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.8.0\n" ] } ], "source": [ "import os, sys\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "import string\n", "%matplotlib inline\n", "\n", "slim = tf.contrib.slim\n", "print(tf.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare example data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "words = ['good', 'bad', 'amazing', 'so good', 'bull shit', 'awesome']\n", "y = [[1.,0.], [0.,1.], [1.,0.], [1., 0.],[0.,1.], [1.,0.]]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'abcdefghijklmnopqrstuvwxyz *'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Character quantization\n", "char_space = string.ascii_lowercase \n", "char_space = char_space + ' ' + '*'\n", "char_space" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26, '*': 27}\n" ] } ], "source": [ "char_dic = {char : idx for idx, char in enumerate(char_space)}\n", "print(char_dic)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create pad_seq function" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def pad_seq(sequences, max_len, dic):\n", " seq_len, seq_indices = [], []\n", " for seq in sequences:\n", " seq_len.append(len(seq))\n", " seq_idx = [dic.get(char) for char in seq]\n", " seq_idx += (max_len - len(seq_idx)) * [dic.get('*')] # 27 is idx of meaningless token \"*\"\n", " seq_indices.append(seq_idx)\n", " return seq_len, seq_indices" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply pad_seq function to data" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "max_length = 10\n", "X_length, X_indices = pad_seq(sequences = words, max_len = max_length, dic = char_dic)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[4, 3, 7, 7, 9, 7]\n", "(6, 10)\n" ] } ], "source": [ "print(X_length)\n", "print(np.shape(X_indices))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define CharBiGRU class" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class CharBiGRU:\n", " def __init__(self, X_length, X_indices, y, n_of_classes, hidden_dim, dic):\n", " \n", " # data pipeline\n", " with tf.variable_scope('input_layer'):\n", " self._X_length = X_length\n", " self._X_indices = X_indices\n", " self._y = y\n", " \n", " one_hot = tf.eye(len(dic), dtype = tf.float32)\n", " self._one_hot = tf.get_variable(name='one_hot_embedding', initializer = one_hot,\n", " trainable = False) # embedding vector training 안할 것이기 때문\n", " self._X_batch = tf.nn.embedding_lookup(params = self._one_hot, ids = self._X_indices)\n", " \n", " # Bi-directional GRU\n", " with tf.variable_scope('bi-directional_gru'):\n", " gru_fw_cell = tf.contrib.rnn.GRUCell(num_units = hidden_dim, activation = tf.nn.tanh)\n", " gru_bw_cell = tf.contrib.rnn.GRUCell(num_units = hidden_dim, activation = tf.nn.tanh)\n", " _, output_states = tf.nn.bidirectional_dynamic_rnn(cell_fw = gru_fw_cell,\n", " cell_bw = gru_bw_cell,\n", " inputs = self._X_batch,\n", " sequence_length = self._X_length,\n", " dtype = tf.float32)\n", "\n", " final_state = tf.concat([output_states[0], output_states[1]], axis = 1)\n", "\n", " with tf.variable_scope('output_layer'):\n", " self._score = slim.fully_connected(inputs = final_state, num_outputs = n_of_classes,\n", " activation_fn = None)\n", " \n", " with tf.variable_scope('loss'):\n", " self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels = self._y, logits = self._score)\n", " \n", " with tf.variable_scope('prediction'):\n", " self._prediction = tf.argmax(input = self._score, axis = -1, output_type = tf.int32)\n", " \n", " def predict(self, sess, X_length, X_indices):\n", " feed_prediction = {self._X_length : X_length, self._X_indices : X_indices}\n", " return sess.run(self._prediction, feed_dict = feed_prediction)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a model of CharBiGRU" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3\n" ] } ], "source": [ "# hyper-parameter#\n", "lr = .003\n", "epochs = 10\n", "batch_size = 2\n", "total_step = int(np.shape(X_indices)[0] / batch_size)\n", "print(total_step)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<BatchDataset shapes: ((?,), (?, 10), (?, 2)), types: (tf.int32, tf.int32, tf.float32)>\n" ] } ], "source": [ "## create data pipeline with tf.data\n", "tr_dataset = tf.data.Dataset.from_tensor_slices((X_length, X_indices, y))\n", "tr_dataset = tr_dataset.shuffle(buffer_size = 20)\n", "tr_dataset = tr_dataset.batch(batch_size = batch_size)\n", "tr_iterator = tr_dataset.make_initializable_iterator()\n", "print(tr_dataset)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "X_length_mb, X_indices_mb, y_mb = tr_iterator.get_next()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "char_bi_gru = CharBiGRU(X_length = X_length_mb, X_indices = X_indices_mb, y = y_mb,\n", " n_of_classes = 2, hidden_dim = 16, dic = char_dic)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Creat training op and train model" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "## create training op\n", "opt = tf.train.AdamOptimizer(learning_rate = lr)\n", "training_op = opt.minimize(loss = char_bi_gru.ce_loss)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch : 1, tr_loss : 0.670\n", "epoch : 2, tr_loss : 0.624\n", "epoch : 3, tr_loss : 0.598\n", "epoch : 4, tr_loss : 0.550\n", "epoch : 5, tr_loss : 0.518\n", "epoch : 6, tr_loss : 0.489\n", "epoch : 7, tr_loss : 0.456\n", "epoch : 8, tr_loss : 0.431\n", "epoch : 9, tr_loss : 0.394\n", "epoch : 10, tr_loss : 0.366\n" ] } ], "source": [ "sess = tf.Session()\n", "sess.run(tf.global_variables_initializer())\n", "\n", "tr_loss_hist = []\n", "\n", "for epoch in range(epochs):\n", " avg_tr_loss = 0\n", " tr_step = 0\n", " \n", " sess.run(tr_iterator.initializer)\n", " try:\n", " while True:\n", " _, tr_loss = sess.run(fetches = [training_op, char_bi_gru.ce_loss])\n", " avg_tr_loss += tr_loss\n", " tr_step += 1\n", " \n", " except tf.errors.OutOfRangeError:\n", " pass\n", " \n", " avg_tr_loss /= tr_step\n", " tr_loss_hist.append(avg_tr_loss)\n", " \n", " print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[<matplotlib.lines.Line2D at 0x115f56f28>]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl8VPW9//HXJzs7kUS2AIEQoLLDsMuiiKJtQQUUpC1YFUERUKu1Pu793V57b12LqAVkcasFWbWlboiyyiYJsshq2MMadpAlBL6/PzLawEUzkOVMZt7PxyMPOWfOybwzwpvDfL/fM+acQ0REwkOE1wFERKT4qPRFRMKISl9EJIyo9EVEwohKX0QkjKj0RUTCiEpfRCSMqPRFRMKISl9EJIxEeR3gUgkJCS45OdnrGCIiJUp6evpB51xifscFXeknJyeTlpbmdQwRkRLFzHYEcpze3hERCSMqfRGRMKLSFxEJIyp9EZEwotIXEQkjKn0RkTCi0hcRCSMhU/rnLzj+/PEGdh0+5XUUEZGgFTKlv+PQd0z5aid3jl3C2sxjXscREQlKIVP6dRLLMnNIe2IiI7h7/FLmbTzgdSQRkaATMqUPkFq5HB881J46iWW4/29pTF6+0+tIIiJBJaRKH+Da8nFMHdSOTqkJPP3BWl6cvRHnnNexRESCQsiVPkCZ2Cgm/MZHv9Y1GD1vC49OXcXZnPNexxIR8VzQ3WWzsERFRvDnOxqTFF+aF2dvYt/xM4z7tY8KpaK9jiYi4pmQvNL/npnx8A11efnupqTvOELvsUvYffS017FERDwT0qX/vTuaJ/HOb1uz7/gZ7hi9mG92a0qniISnsCh9gPYpCcwc0p6oCOPucUuZv0lTOkUk/IRN6QPUq1yODx7uQK1KZbjvnTSmrtCUThEJLwGVvpl1N7NNZpZhZk/9yDF3mdl6M1tnZpPz7D9vZqv8X7MKK/jVqlw+jmmD29GhbgK/n7mWkZ9t0pROEQkb+c7eMbNIYDTQDcgEVpjZLOfc+jzHpAJ/ADo4546Y2bV5vsVp51yzQs5dIGVjo3hjgI//+OAbXp2bQeaR0zzXqwkxUWH1Dx8RCUOBTNlsDWQ457YCmNkUoCewPs8xDwCjnXNHAJxzQf+GeXRkBM/1akxSfCn+Mmcz+0+cYeyvWlI+TlM6RSR0BXJpWx3YlWc7078vr3pAPTNbbGbLzKx7nsfizCzNv//2yz2BmQ3yH5OWlZV1RT9AQZgZj3RN5S99mrJ862H6jF3KHk3pFJEQVljvZ0QBqUAXoB8wwcwq+h+r5ZzzAfcAo8ws5dKTnXPjnXM+55wvMTGxkCIFrlfL3Cmde46e5o4xi1m/53ixZxARKQ6BlP5uoEae7ST/vrwygVnOuXPOuW3AZnL/EsA5t9v/363AfKB5ATMXiQ51E5g+pB0RZtw1bikLNxffvzhERIpLIKW/Akg1s9pmFgP0BS6dhfMPcq/yMbMEct/u2Wpm8WYWm2d/By4eCwgqDaqU54OHOpAUX4rfvr2CaWm78j9JRKQEybf0nXM5wFBgNrABmOacW2dmz5hZD/9hs4FDZrYemAc84Zw7BPwMSDOz1f79z+Wd9ROMqlSIY/rgdrRLqcSTM9bw8pzNmtIpIiHDgq3QfD6fS0tL8zoG585f4On31zI9PZPeLZN49s7GREdqSqeIBCczS/ePn/6kkL3LZkFFR0bwQu8mVI8vxajPv2X/8TOM6d+CcprSKSIlmC5df4KZMeKmerzYuwlLtxyiz+tL2XfsjNexRESumko/AH18NXhzYCsyj+RO6dy4T1M6RaRkUukHqFO9RKY92A7noM/YpXz57UGvI4mIXDGV/hW4rlp5Pni4PdXjSzHwra+YkZ7pdSQRkSui0r9CVSuUYtrgdrSpcw2/m76aV7/4VlM6RaTEUOlfhfJx0bw1sDV3tqjOyDmbeWrmWs6dv+B1LBGRfGnK5lWKiYrgL32akhRfmle/+Ja9/imdZWP1kopI8NKVfgGYGY91q8fzvRqzOOMgd72+lP3HNaVTRIKXSr8Q3N2qJm8ObMWOQ99xx+jFbN5/wutIIiKXpdIvJJ3rJTJtcDtyLjh6jVnCxEVbOXPuvNexREQuotIvRA2rVeCDhzvQrGZF/uejDXR6YR5/W7qdszkqfxEJDir9Qla9Yineva8NUwe1JblSGf7fP9dx40sLmLpip2b4iIjndJfNIuSc48uMg/zls82s2nWUWpVKM+KmVHo0rU5khHkdT0RCSKB32dSVfhEyMzqmJvLBQ+15Y4CPMjFRPDp1NbeMWshHa/Zy4UJw/YUrIqFPpV8MzIyuP6vMh49cz9j+LTDg4ckr+flrXzJn/X6t6BWRYqPSL0YREcatjavy6YhOjLq7Gaezc3jgb2ncPnoxCzZnqfxFpMip9D0QGWHc3rw6nz/WmRd6NeHgyWwGvPkVd41byrKth7yOJyIhTAO5QSA75wJT03bx17nfsv/4WTrUrcRj3erTsla819FEpIQIdCBXpR9Ezpw7z6TlOxk7P4ODJ7O5oX4ij99cn0bVK3gdTUSCnEq/BPvubA7vLN3OuAVbOXb6HN0bVuHRbvWoX6Wc19FEJEip9EPA8TPnePPLbbyxaBsns3P4ZZNqjLgplTqJZb2OJiJBRqUfQo6eymb8wq28tTj3lg53tkhieNdUalxT2utoIhIkCnVxlpl1N7NNZpZhZk/9yDF3mdl6M1tnZpPz7B9gZt/6vwYE/iPI9yqWjuHJ7g1Y9PsbuLdDbWat3sMNL83n6Q/WsvfYaa/jiUgJku+VvplFApuBbkAmsALo55xbn+eYVGAacKNz7oiZXeucO2Bm1wBpgA9wQDrQ0jl35MeeT1f6+dt37Ayj52UwZcVOzIx7WtfkoRtSuLZcnNfRRMQjhXml3xrIcM5tdc5lA1OAnpcc8wAw+vsyd84d8O+/BZjjnDvsf2wO0D3QH0Iur0qFOP50eyPm/a4LdzSrzrvLdtDphXk8+8kGDn+X7XU8EQligZR+dWBXnu1M/7686gH1zGyxmS0zs+5XcK5cpaT40jzfuwmfP9aZ7g2rMH7hVjo+P5eRn23i2OlzXscTkSBUWCtyo4BUoAvQD5hgZhUDPdnMBplZmpmlZWVlFVKk8FE7oQyj+jZn9ohOdK6fyKtzM+j4/FzGzt+im7qJyEUCKf3dQI0820n+fXllArOcc+ecc9vIHQNIDfBcnHPjnXM+55wvMTHxSvJLHvUql2NM/5Z8NOx6WiVfw/OfbmT41FVk5+g+/iKSK5DSXwGkmlltM4sB+gKzLjnmH+Re5WNmCeS+3bMVmA3cbGbxZhYP3OzfJ0WoYbUKTBzg46lbG/Cv1Xu4750VnDyb43UsEQkC+Za+cy4HGEpuWW8Apjnn1pnZM2bWw3/YbOCQma0H5gFPOOcOOecOA38i9y+OFcAz/n1SxMyMwZ1TeKF3E5ZsOUT/Ccs4dPKs17FExGNanBUGPl+/n4cnr6R6xVL87b7WJMVrUZdIqNEnZ8kPbrquMn+/vw0HT56l19glbNp3wutIIuIRlX6YaJV8DdMGtwOgz+tLSNuud9lEwpFKP4w0qFKeGYPbk1A2lv4Tl/PFhv1eRxKRYqbSDzM1rinN9MHtqF+lHIPeTWd62q78TxKRkKHSD0OVysYy+YG2tKtTiSdmrGHcgi1eRxKRYqLSD1NlY6N4Y6CPXzSpyrOfbOR/P1qv1bsiYSDK6wDindioSF7t25xKZWKYsGgbh05m83zvJkRH6lpAJFSp9MNcRITxxx4NqVQ2lpFzNnPkVDaj+7egdIx+a4iEIl3SCWbGsK6p/O8djViwOYv+E5dz9JRu0SwSilT68oP+bWoxpn8L1u0+Tp/Xl+pTuURCkEpfLtK9UVXe/m0r9h47Q68xS8g4oNW7IqFEpS//R/uUBKYMakv2eUfv15fy9c4f/XRLESlhVPpyWY2qV2DmkHaUj4vmngnLWbBZH24jEgpU+vKjalUqw4wh7aidUIb73l7BP1f9n8+/EZESRqUvP+nacnFMebAtvuR4hk9ZxZtfbvM6kogUgEpf8lU+Lpq3721N94ZVeObD9bzw6UaC7XMYRCQwKn0JSFx0JKP7t6Bf65qMmb+Fp2auJee8PntXpKTRsksJWGSE8ec7GpFQNobX5mZw+FQ2r/VrTlx0pNfRRCRAutKXK2JmPH5zff74y+v4fMN+fvPGVxw7fc7rWCISIJW+XJWBHWrzSt/mfL3rCHePW8qB42e8jiQiAVDpy1Xr0bQabw5sxc7Dp7hz7BK2HfzO60gikg+VvhRIx9RE3nugLaeyz9N77BLWZh7zOpKI/ASVvhRY0xoVmT64HXHRkfQdv5TFGQe9jiQiP0KlL4UiJbEsM4e0Jym+NPe+tYKP1uz1OpKIXEZApW9m3c1sk5llmNlTl3l8oJllmdkq/9f9eR47n2f/rMIML8GlSoU4pj3YjiZJFRj63kreXbrd60gicol85+mbWSQwGugGZAIrzGyWc279JYdOdc4Nvcy3OO2ca1bwqFISVCgdzbv3tWHo5JX85z/XcfBkNiNuSsXMvI4mIgR2pd8ayHDObXXOZQNTgJ5FG0tKslIxkYz7dUt6t0zilS++5fcz13DijObyiwSDQEq/OrArz3amf9+lepnZGjObYWY18uyPM7M0M1tmZrdf7gnMbJD/mLSsLN3CNxRERUbwYu8mPHxDCtPTM+k2ciGffrPP61giYa+wBnL/BSQ755oAc4B38jxWyznnA+4BRplZyqUnO+fGO+d8zjlfYmJiIUUSr5kZT9zSgPeHtKdi6WgG/z2d+99JY/dRfQyjiFcCKf3dQN4r9yT/vh845w455876NycCLfM8ttv/363AfKB5AfJKCdS8Zjz/euR6nr6tAYszDtJt5AImLtqqG7aJeCCQ0l8BpJpZbTOLAfoCF83CMbOqeTZ7ABv8++PNLNb/6wSgA3DpALCEgejICAZ1SmHOY51oW6cS//PRBnqOXsyazKNeRxMJK/mWvnMuBxgKzCa3zKc559aZ2TNm1sN/2DAzW2dmq4FhwED//p8Baf7984DnLjPrR8JIUnxp3hjgY2z/FmSdOMvtoxfzx1nrNNArUkws2D4Mw+fzubS0NK9jSDE4fuYcL83exLvLdlC5XBx/7NGQWxpW1vROkatgZun+8dOfpBW54pnycdE807PRRQO9D/wtXQO9IkVIpS+e00CvSPFR6UtQ0ECvSPFQ6UtQ0UCvSNFS6UvQMTNubVyVzx/vzK/a1uKdpdt/WNEbbBMPREoalb4ELQ30ihQ+lb4EPQ30ihQelb6UCBroFSkcKn0pUTTQK1IwKn0pcTTQK3L1VPpSYmmgV+TKqfSlxNNAr0jgVPoSEjTQKxIYlb6EFA30ivw0lb6EnMsN9N40cgEfrtmjgV4Jeyp9CVnfD/R+8FAHrikTy9DJX9Nr7BLSdxz2OpqIZ1T6EvKa1ajIh49czwu9mpB55DS9xi7loUnpbD/4ndfRRIqdPjlLwsqp7BwmLNzGuIVbOHf+Ar9qW4thN6YSXybG62giBaJPzhK5jNIxUQy/KZX5v+tC75ZJvLNkO51enMf4hVs4c+681/FEipxKX8LSteXjePbOJnw6ohO+WvH8+eON3DRyAbNWa7BXQptKX8JavcrleOve1vz9vjaUi4tm2Htfc/uYJXy1TYO9EppU+iLA9akJfPjI9bzUpyn7j53hrnFLGfS3NLZmnfQ6mkih0kCuyCVOZ5/njS+3Mnb+Fs7mXKB/m5oM65pKpbKxXkcT+VGFOpBrZt3NbJOZZZjZU5d5fKCZZZnZKv/X/XkeG2Bm3/q/BlzZjyFS/ErFRDL0xlTmP3EDfVvX4O/Ld9LlxfmMna/BXin58r3SN7NIYDPQDcgEVgD9nHPr8xwzEPA554Zecu41QBrgAxyQDrR0zh35sefTlb4Em4wDJ3juk418vuEA1SuW4olb6tOjaTUiIszraCI/KMwr/dZAhnNuq3MuG5gC9Awwxy3AHOfcYX/RzwG6B3iuSFCoe205Jg5oxeQH2hBfJpoRU1fRY/SXLN1yyOtoIlcskNKvDuzKs53p33epXma2xsxmmFmNKzxXJOi1T0lg1sPX8/LdTTl8Mpt+E5Zx/zsryDhwwutoIgErrNk7/wKSnXNNyL2af+dKTjazQWaWZmZpWVlZhRRJpPBFRBh3NE9i7u+68PvuDVi+9TC3jFrEf/xjLQdPnvU6nki+Ain93UCNPNtJ/n0/cM4dcs59/zt+ItAy0HP95493zvmcc77ExMRAs4t4Ji46kiFdUpj/RBd+1aYmU77aRZcX5zN6XganszXYK8ErkNJfAaSaWW0ziwH6ArPyHmBmVfNs9gA2+H89G7jZzOLNLB642b9PJCRUKhvLf/dsxOxHO9E+pRIvzt7EDS/NZ0Z6JhcuBNd0aBEIoPSdcznAUHLLegMwzTm3zsyeMbMe/sOGmdk6M1sNDAMG+s89DPyJ3L84VgDP+PeJhJSUxLKM/42PaQ+2o3L5WH43fTW/eO1LFmcc9DqayEW0OEukkF244Phw7V6e/2Qju4+e5ob6ifzhtp9Rr3I5r6NJCNNdNkU8EhFh9GhajS8e78zTtzUgbccRuo9ayB/eX8OBE2e8jidhTqUvUkTioiMZ1CmFhU/cwID2ycxIz6TLi/N5fcEWcs5f8DqehCmVvkgRiy8Tw3/9siFzHu1M+5QEnvtkIz1HL+ab3ce8jiZhSKUvUkySE8ow4TctGdu/BfuPn6Xn6MU8+8kG3c9HipVKX6QYmRm3Nq7KF491pneLJMYt2Er3UQt1SwcpNip9EQ9UKB3N872bMPn+Nlxw0G/CMp6auYZjp895HU1CnEpfxEPt6yYwe0QnHuxUh2lpu+g2cgGffrPP61gSwlT6Ih4rFRPJH277Gf98+HoqlY1l8N/TGfxuOgeOa3qnFD6VvkiQaJxUgVlDO/Bk9/rM3XSAriMXMOWrnfqgdilUKn2RIBIdGcFDXeoye0QnrqtanqfeX8s9E5az/eB3XkeTEKHSFwlCtRPK8N4DbXn2zsZ8s/sYt4xaqEVdUihU+iJBKiLC6Ne6Jp8/3pnO9RK1qEsKhUpfJMhVLh/HuF9rUZcUDpW+SAmgRV1SWFT6IiWIFnVJQan0RUogLeqSq6XSFymhtKhLroZKX6SE06IuuRIqfZEQoEVdEiiVvkgI0aIuyY9KXyTEaFGX/BSVvkiI0qIuuRyVvkgI06IuuZRKXyQMXG5R131vryB9xxGvo0kxC6j0zay7mW0yswwze+onjutlZs7MfP7tZDM7bWar/F+vF1ZwEbly3y/qeqxbPdJ3HqHX2CXcPW4pCzdnaYpnmLD8/kebWSSwGegGZAIrgH7OufWXHFcO+AiIAYY659LMLBn40DnXKNBAPp/PpaWlXcnPICJX4buzObz31U4mLtrGvuNnaFy9Ag91SeGWhlWIiDCv48kVMrN055wvv+MCudJvDWQ457Y657KBKUDPyxz3J+B5QMsBRUqAMrFR3N+xDgue7MJzdzbmxJlzDJm0km4vL2B62i7OaZpnSAqk9KsDu/JsZ/r3/cDMWgA1nHMfXeb82mb2tZktMLOOl3sCMxtkZmlmlpaVlRVodhEpBLFRkfRtXZMvHu/Ca/2aEx0ZwRMz1tDlxfm8vXgbp7M12yeUFHgg18wigJHA45d5eC9Q0znXHHgMmGxm5S89yDk33jnnc875EhMTCxpJRK5CZITxy6bV+GR4R94a2IqqFeL447/Wc/3zcxk9L0N38gwRgZT+bqBGnu0k/77vlQMaAfPNbDvQFphlZj7n3Fnn3CEA51w6sAWoVxjBRaRomBk3NLiWGUPaM+3BdjSqXoEXZ2/i+ufm8vynG8k6cdbriFIAgQzkRpE7kNuV3LJfAdzjnFv3I8fPB37nH8hNBA47586bWR1gEdDYOXf4x55PA7kiweeb3ccYO38LH3+zl5jICPq2qsEDneqQFF/a62jiF+hAblR+BzjncsxsKDAbiATedM6tM7NngDTn3KyfOL0T8IyZnQMuAIN/qvBFJDg1ql6B0f1bsCXrJOMWbGHS8p1MWr6THs2q8VCXFOpeW87riBKgfK/0i5uu9EWC356jp5mwaCvvfbWTszkXuPm6yjzUpS5Na1T0OlrYCvRKX6UvIlft0MmzvL1kO28v2c6JMzl0TE1gSJcU2tWphJnm+hcnlb6IFJsTZ84xaXnuQq+DJ8/SvGZFHupSl64NrtVCr2Ki0heRYnfm3Hmmp2cybsEWMo+cpn7lcgzpksIvmlQlKlK3+ipKKn0R8cy58xf4cM0exszbwrcHTlLjmlI82CmF3i2TiIuO9DpeSFLpi4jnLlxwfL5hP6Pnb2H1rqMklovl/utr079tLcrG5jt5UK6ASl9EgoZzjiVbDjFmfgaLMw5RoVQ0A9onc2/7ZOLLxHgdLySo9EUkKK3adZQx8zL4bP1+ysRE8l89GtKnZZJm+xRQYd5lU0Sk0DSrUZHxv/Hx2aOdaJxUgSdnrGHo5K85dkr39ikOKn0R8US9yuWYdH9bnuxen9nr9nHrKwtZvlUf41jUVPoi4pnICOOhLnWZOaQ9MVER9JuwjL98tkn38i9CKn0R8VzTGhX5aFhHerVI4rW5GfR5fSk7Dn3ndayQpNIXkaBQJjaKF/s05a/3NGdL1klue2UR76/M1Gf3FjKVvogElV80qcanIzrRsFoFHpu2muFTVnH8jAZ5C4tKX0SCTvWKpXhvUFse71aPj9bu5dZRi0jbrruyFwaVvogEpcgI45GuqUwf3I6ICLhr3FJenrOZHA3yFohKX0SCWoua8Xw8rCO3N6vOK198y93jl7Hr8CmvY5VYKn0RCXrl4qIZeXczXunbjM37TnDbK4v456rd+Z8o/4dKX0RKjJ7NqvPx8I7Uq1KO4VNW8djUVZzQIO8VUemLSIlS45rSTB3UlhE3pfKPVbv5+atfsnLnEa9jlRgqfREpcaIiIxhxUz2mPdiO8xccfV5fymtffMv5C5rTnx+VvoiUWL7ka/hkREd+3rgqf5mzmX7jl7H76GmvYwU1lb6IlGjl46J5pW8zRt7VlHV7jtF91EI+XLPH61hBS6UvIiWemXFniyQ+Ht6RlMSyDJ38NU9MX813Z3O8jhZ0Aip9M+tuZpvMLMPMnvqJ43qZmTMzX559f/Cft8nMbimM0CIil1OrUhmmD27HIzfWZcbKTH7+6iJW7zrqdaygkm/pm1kkMBq4FbgO6Gdm113muHLAcGB5nn3XAX2BhkB3YIz/+4mIFInoyAgev7k+Ux5oS3bOBXqNXcKY+Rka5PUL5Eq/NZDhnNvqnMsGpgA9L3Pcn4DngTN59vUEpjjnzjrntgEZ/u8nIlKk2tSpxCfDO3FLoyq88Okm+k9cxt5jGuQNpPSrA7vybGf69/3AzFoANZxzH13puSIiRaVC6Wj+2q85L/RuwprMY3QftYhP1u71OpanCjyQa2YRwEjg8QJ8j0FmlmZmaVlZWQWNJCLyAzPjLl8NPhrWkVqVSjNk0kqemrmGU9nhOcgbSOnvBmrk2U7y7/teOaARMN/MtgNtgVn+wdz8zgXAOTfeOedzzvkSExOv7CcQEQlA7YQyzBjcniFdUpiatotfvPolazOPeR2r2AVS+iuAVDOrbWYx5A7Mzvr+QefcMedcgnMu2TmXDCwDejjn0vzH9TWzWDOrDaQCXxX6TyEiEoCYqAh+370Bk+5vw6ns89w5djHjFmzhQhgN8uZb+s65HGAoMBvYAExzzq0zs2fMrEc+564DpgHrgU+Bh51z5wseW0Tk6rVPSeCT4R3p2qAyz36ykXsmLuOb3eFx1W/B9vmTPp/PpaWleR1DRMKAc45pabt49pONHD11jp83qcrj3epRJ7Gs19GumJmlO+d8+R0XVRxhRESCkZlxd6ua3Nq4KhMXbmXil9v49Jt93OWrwfCuqVSpEOd1xEKnK30REb+sE2cZPS+DSct3EGHGwA7JDOmcQsXSMV5Hy1egV/oqfRGRS+w6fIqXP9/MB1/vpmxsFIM7p3Bvh2RKxwTvmyMqfRGRAtq07wQvzt7E5xv2k1A2luFd63J3q5rERAXfvSoDLf3gSy4iEiTqVynHxAE+Zg5pT53EMvznP9dx08gF/OPr3SV2mqdKX0QkHy1rxTN1UFvevrcVZWOjGDF1Fbe9uoi5G/cTbO+W5EelLyISADOjS/1r+fCR63m1X3NOnzvPb99O465xS1mx/bDX8QKm0hcRuQIREUaPptX4/LHO/M/tjdhx6BR9Xl/Kb99ewYa9x72Oly8N5IqIFMDp7PO8vWQ7Y+dncOJsDj2bVuOxbvWpWal0sebQ7B0RkWJ07NQ5xi3cwpuLt5Fz3tGvdU0e6VqXa8sVzwIvlb6IiAcOHD/Dq3O/ZcpXu4iOjOC31yczqFMKFUpFF+nzqvRFRDy0/eB3jJyzmVmr91ChVDRDuqQwoF0ypWKK5hNjVfoiIkFg3Z5jvDR7E/M2ZVG5fCzDu9ajjy+J6MjCnUejxVkiIkGgYbUKvHVva6YOaktSfGme/mAtN7+8kH+t3uPJAi+VvohIMWhTpxIzBrdj4m98xERG8Mh7X/PLv37Jgs1ZxbrAS6UvIlJMzIybrqvMx8M78vLdTTl2+hwD3vyKfhOWsXLnkWLJoNIXESlmkRHGHc2TmPt4F/67R0MyDpzkzjFLeHjSyiK/6g/e+4SKiIS4mKgIBrRPpnfLJN5avI3T585jZkX6nCp9ERGPlYmNYuiNqcXyXHp7R0QkjKj0RUTCiEpfRCSMqPRFRMKISl9EJIyo9EVEwohKX0QkjKj0RUTCSNDdWtnMsoAdBfgWCcDBQopT0um1uJhej4vp9fi3UHgtajnnEvM7KOhKv6DMLC2Qe0qHA70WF9PrcTG9Hv8WTq+F3t4REQkjKn0RkTASiqU/3usAQUSvxcX0elxMr8e/hc1rEXLv6YuIyI8LxSt9ERH5ESFT+mbW3cw2mVmGmT3ldR4vmVkNM5tnZuvNbJ2ZDfc6k9fMLNLMvjazD73O4jUzq2hmM8xso5ltMLN2Xmfykpk96v9z8o2ZvWdmcV5nKkohUfpmFgl27bA1AAACK0lEQVSMBm4FrgP6mdl13qbyVA7wuHPuOqAt8HCYvx4Aw4ENXocIEq8AnzrnGgBNCePXxcyqA8MAn3OuERAJ9PU2VdEKidIHWgMZzrmtzrlsYArQ0+NMnnHO7XXOrfT/+gS5f6ire5vKO2aWBPwcmOh1Fq+ZWQWgE/AGgHMu2zl31NtUnosCSplZFFAa2ONxniIVKqVfHdiVZzuTMC65vMwsGWgOLPc2iadGAU8CF7wOEgRqA1nAW/63uyaaWRmvQ3nFObcbeAnYCewFjjnnPvM2VdEKldKXyzCzssBMYIRz7rjXebxgZr8ADjjn0r3OEiSigBbAWOdcc+A7IGzHwMwsntx3BWoD1YAyZvYrb1MVrVAp/d1AjTzbSf59YcvMoskt/EnOufe9zuOhDkAPM9tO7tt+N5rZ372N5KlMINM59/2//GaQ+5dAuLoJ2Oacy3LOnQPeB9p7nKlIhUrprwBSzay2mcWQOxAzy+NMnjEzI/c92w3OuZFe5/GSc+4Pzrkk51wyub8v5jrnQvpK7qc45/YBu8ysvn9XV2C9h5G8thNoa2al/X9uuhLiA9tRXgcoDM65HDMbCswmd/T9TefcOo9jeakD8GtgrZmt8u972jn3sYeZJHg8AkzyXyBtBe71OI9nnHPLzWwGsJLcWW9fE+Krc7UiV0QkjITK2zsiIhIAlb6ISBhR6YuIhBGVvohIGFHpi4iEEZW+iEgYUemLiIQRlb6ISBj5/2H04vKsXcqoAAAAAElFTkSuQmCC\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(tr_loss_hist, label = 'train')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "yhat = char_bi_gru.predict(sess = sess, X_length = X_length, X_indices = X_indices)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "training acc: 100.00%\n" ] } ], "source": [ "print('training acc: {:.2%}'.format(np.mean(yhat == np.argmax(y, axis = -1))))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }