{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Notebook_extra_Recurrent_neural_networks", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "o8d93AD9wBGu", "colab_type": "text" }, "source": [ "# Neural Networks for Data Science Applications\n", "## Lab session (extra): Recurrent neural networks\n", "\n", "**Contents of the lab session:**\n", "+ Recurrent neural networks in TensorFlow.\n", "+ A simple example by implementing a counting algorithm.\n", "+ A more elaborate encoder/decoder on a sorting problem." ] }, { "cell_type": "code", "metadata": { "id": "fvU9p2EdaODu", "colab_type": "code", "outputId": "a8e137f3-2b52-4629-8c1e-96d472b13547", "colab": { "base_uri": "https://localhost:8080/", "height": 156 } }, "source": [ "# Remember to enable a GPU on Colab by:\n", "# Runtime >> Change runtime type >> Hardware accelerator (before starting the VM).\n", "!pip install -q tensorflow-gpu==2.0.0" ], "execution_count": 1, "outputs": [ { "output_type": "stream", "text": [ "\u001b[K |████████████████████████████████| 380.8MB 41kB/s \n", "\u001b[K |████████████████████████████████| 3.8MB 51.3MB/s \n", "\u001b[K |████████████████████████████████| 450kB 51.2MB/s \n", "\u001b[K |████████████████████████████████| 81kB 9.0MB/s \n", "\u001b[31mERROR: tensorflow 1.15.0 has requirement tensorboard<1.16.0,>=1.15.0, but you'll have tensorboard 2.0.2 which is incompatible.\u001b[0m\n", "\u001b[31mERROR: tensorflow 1.15.0 has requirement tensorflow-estimator==1.15.1, but you'll have tensorflow-estimator 2.0.1 which is incompatible.\u001b[0m\n", "\u001b[31mERROR: tensorboard 2.0.2 has requirement grpcio>=1.24.3, but you'll have grpcio 1.15.0 which is incompatible.\u001b[0m\n", "\u001b[31mERROR: google-colab 1.0.0 has requirement google-auth~=1.4.0, but you'll have google-auth 1.7.1 which is incompatible.\u001b[0m\n", 
"\u001b[?25h" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "EJxhPrm_sXpd", "colab_type": "code", "colab": {} }, "source": [ "import tensorflow as tf" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "twvoyNfNwazk", "colab_type": "text" }, "source": [ "### Example 1: counting with RNNs (fixed length)\n", "\n", "To show how RNNs work, we consider a simplified setup where we want to learn a counting algorithm:\n", "\n", "\n", "1. The RNN receives a sequence of symbols as input (each symbol is represented with a one-hot encoding);\n", "2. The model outputs a probability distribution over the symbols, predicting which one is the most frequent in the sequence.\n", "\n", "For this first example, all sequences have the same length.\n", "\n" ] }, { "cell_type": "code", "metadata": { "id": "QwiH1Ve4rss1", "colab_type": "code", "colab": {} }, "source": [ "n_max = 5 # How many symbols we consider\n", "max_seq_len = 10 # Length of each sequence\n", "units = 100 # Size of the RNN state vector" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "3yrE6W2WsT8U", "colab_type": "code", "colab": {} }, "source": [ "# To load data into TF, we first define a generator that outputs random training sequences.\n", "def gen_seq():\n", "  while True: # The generator is infinite\n", "    # First generate a random vector of integers in [0, n_max-1]\n", "    x = tf.random.uniform(shape=(max_seq_len,), maxval=n_max, dtype=tf.int32)\n", "    # Convert the vector to a one-hot representation for the symbols\n", "    x = tf.one_hot(x, depth=n_max)\n", "    # Count each symbol by summing over time (axis=0), then take the most frequent one\n", "    y = tf.argmax(tf.reduce_sum(x, axis=0), axis=0)\n", "    yield x, y" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Fmz0AtACtD3E", "colab_type": "code", "colab": {} }, "source": [ "# Get one sequence\n", "X, y = next(gen_seq())" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { 
"id": "M3OhyCCwxUzv", "colab_type": "code", "outputId": "71399ab8-a3ef-4d96-b67e-a30c0c507fa1", "colab": { "base_uri": "https://localhost:8080/", "height": 225 } }, "source": [ "print(X)\n", "print(y)" ], "execution_count": 6, "outputs": [ { "output_type": "stream", "text": [ "tf.Tensor(\n", "[[0. 0. 0. 1. 0.]\n", " [0. 0. 0. 0. 1.]\n", " [0. 1. 0. 0. 0.]\n", " [0. 0. 0. 0. 1.]\n", " [0. 0. 1. 0. 0.]\n", " [0. 0. 0. 1. 0.]\n", " [0. 0. 1. 0. 0.]\n", " [0. 0. 1. 0. 0.]\n", " [0. 0. 1. 0. 0.]\n", " [0. 0. 1. 0. 0.]], shape=(10, 5), dtype=float32)\n", "tf.Tensor(2, shape=(), dtype=int64)\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "5Uzof34qX3KO", "colab_type": "code", "colab": {} }, "source": [ "# We load the sequences inside a tf.data.Dataset object\n", "dataset = tf.data.Dataset.from_generator(lambda: gen_seq(), output_types=(tf.float32, tf.float32))" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "il3U39YaYEWV", "colab_type": "code", "outputId": "62ea9530-e058-4ed7-9cfb-7c025bec1482", "colab": { "base_uri": "https://localhost:8080/", "height": 52 } }, "source": [ "# Try batching!\n", "for xb, yb in dataset.batch(3):\n", " print(xb.shape)\n", " print(yb.shape)\n", " break" ], "execution_count": 8, "outputs": [ { "output_type": "stream", "text": [ "(3, 10, 5)\n", "(3,)\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "6OF7NhQIv-QO", "colab_type": "code", "colab": {} }, "source": [ "from tensorflow.keras import layers" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "elQBuPSPw_hN", "colab_type": "code", "colab": {} }, "source": [ "# SimpleRNN corresponds to the basic RNN in the slides\n", "rnn = layers.SimpleRNN(units)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "9rJLWNz_xn-e", "colab_type": "code", "colab": {} }, "source": [ "# Initialize the state vector (note: one state for each sequence in the 
mini-batch)\n", "init_state = tf.zeros(shape=(3, units))" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "vYUVbNXsxREl", "colab_type": "code", "colab": {} }, "source": [ "# Run the RNN!\n", "states = tf.identity(init_state)\n", "for i in range(max_seq_len):\n", " states = rnn(xb[:, i:i+1, :], states)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "5PqAdoabztzo", "colab_type": "code", "outputId": "a31992bf-e4da-47b4-f1e0-13360d395f9f", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "states.shape" ], "execution_count": 13, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "TensorShape([3, 100])" ] }, "metadata": { "tags": [] }, "execution_count": 13 } ] }, { "cell_type": "code", "metadata": { "id": "T1SGiuwMyEMu", "colab_type": "code", "colab": {} }, "source": [ "# Simplified notation: we can let the SimpleRNN object do the loop\n", "states2 = rnn(xb, init_state)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "B5GMv23ezmyP", "colab_type": "code", "colab": {} }, "source": [ "tf.reduce_all(states == states2)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "bkKmZynUzug2", "colab_type": "code", "colab": {} }, "source": [ "from tensorflow.keras import Sequential, losses, optimizers, metrics" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "wmUyRunNz4Ku", "colab_type": "code", "colab": {} }, "source": [ "# Define our model: \n", "# 1. The SimpleRNN computes the state update, returning the state vector after processing the entire sequence.\n", "# 2. 
The Dense layer makes the final prediction (which symbol is most frequent) from the last state.\n", "rnn_full = Sequential([\n", " layers.SimpleRNN(units, input_shape=(max_seq_len, n_max)),\n", " layers.Dense(n_max, activation='softmax')\n", "])" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "WPaI6lXL0E1Q", "colab_type": "code", "colab": {} }, "source": [ "rnn_full.compile(loss=losses.SparseCategoricalCrossentropy(), optimizer=optimizers.RMSprop(0.001), metrics=['accuracy'])" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "P11Vpthc0AL2", "colab_type": "code", "outputId": "32c04102-c4f1-4ed4-b445-2c5b8dbcc0d9", "colab": { "base_uri": "https://localhost:8080/", "height": 225 } }, "source": [ "rnn_full.summary()" ], "execution_count": 17, "outputs": [ { "output_type": "stream", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "simple_rnn_1 (SimpleRNN) (None, 100) 10600 \n", "_________________________________________________________________\n", "dense (Dense) (None, 5) 505 \n", "=================================================================\n", "Total params: 11,105\n", "Trainable params: 11,105\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "vbHC623j0DvP", "colab_type": "code", "outputId": "b333e36c-754e-49bd-8628-5095186c245f", "colab": { "base_uri": "https://localhost:8080/", "height": 364 } }, "source": [ "# Note: because the generator is infinite, we can select manually the size of an epoch.\n", "history = rnn_full.fit_generator(dataset.batch(32), steps_per_epoch=100, epochs=10)" ], "execution_count": 18, "outputs": [ { "output_type": "stream", "text": [ "Epoch 1/10\n", "100/100 
[==============================] - 9s 86ms/step - loss: 0.7964 - accuracy: 0.7094\n", "Epoch 2/10\n", "100/100 [==============================] - 8s 85ms/step - loss: 0.4963 - accuracy: 0.8191\n", "Epoch 3/10\n", "100/100 [==============================] - 9s 87ms/step - loss: 0.3976 - accuracy: 0.8597\n", "Epoch 4/10\n", "100/100 [==============================] - 9s 86ms/step - loss: 0.3503 - accuracy: 0.8737\n", "Epoch 5/10\n", "100/100 [==============================] - 8s 85ms/step - loss: 0.2986 - accuracy: 0.8944\n", "Epoch 6/10\n", "100/100 [==============================] - 9s 86ms/step - loss: 0.2695 - accuracy: 0.9097\n", "Epoch 7/10\n", "100/100 [==============================] - 9s 86ms/step - loss: 0.2683 - accuracy: 0.9044\n", "Epoch 8/10\n", "100/100 [==============================] - 9s 86ms/step - loss: 0.2462 - accuracy: 0.9125\n", "Epoch 9/10\n", "100/100 [==============================] - 9s 86ms/step - loss: 0.2314 - accuracy: 0.9191\n", "Epoch 10/10\n", "100/100 [==============================] - 9s 86ms/step - loss: 0.2202 - accuracy: 0.9250\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "MyWVL9jaz0LA", "colab_type": "text" }, "source": [ "### Example 2: counting with RNNs (variable length)\n", "\n", "This is the same example as before, but we let the sequences vary in length. This also shows an example of masking inside TensorFlow." 
] }, { "cell_type": "code", "metadata": { "id": "9zk1m_Wc32sS", "colab_type": "code", "colab": {} }, "source": [ "def gen_seq():\n", "  while True:\n", "\n", "    # This is the main modification: for each sequence, we sample a length uniformly in [1, max_seq_len-1]\n", "    seq_len = tf.random.uniform(shape=(1,), minval=1, maxval=max_seq_len, dtype=tf.int32)\n", "\n", "    x = tf.random.uniform(shape=(seq_len[0],), maxval=n_max, dtype=tf.int32)\n", "    x = tf.cast(tf.one_hot(x, depth=n_max), tf.float32)\n", "    # Count each symbol by summing over time (axis=0), as in Example 1.\n", "    # Summing over axis=1 would give 1 for every time step, so the label would always be 0.\n", "    y = tf.argmax(tf.reduce_sum(x, axis=0), axis=0)\n", "    yield x, y" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "hLMZYxdX5Bsq", "colab_type": "code", "colab": {} }, "source": [ "dataset = tf.data.Dataset.from_generator(lambda: gen_seq(), output_types=(tf.float32, tf.float32))" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "rkkcsGcu-INT", "colab_type": "code", "colab": {} }, "source": [ "# Because the sequences vary in length, we need padding to make a batch.\n", "# The second argument (padded_shapes) defines the size of the padded tensors. We use -1.0 as padding,\n", "# because 0 already appears in the one-hot encoding and is also the label of the first symbol.\n", "train_it = dataset.padded_batch(3, ([max_seq_len, n_max],[]), padding_values=(-1.0, -1.0))" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "bYp8lfew4hoC", "colab_type": "code", "outputId": "c891813d-8e93-4144-e7da-d0e37d9390b9", "colab": { "base_uri": "https://localhost:8080/", "height": 641 } }, "source": [ "# Check the result\n", "for xb, yb in train_it:\n", "  print(xb.shape)\n", "  print(yb.shape)\n", "  print(xb)\n", "  print(yb)\n", "  break" ], "execution_count": 22, "outputs": [ { "output_type": "stream", "text": [ "(3, 10, 5)\n", "(3,)\n", "tf.Tensor(\n", "[[[ 0. 1. 0. 0. 0.]\n", " [ 0. 0. 0. 1. 0.]\n", " [ 0. 1. 0. 0. 0.]\n", " [ 0. 0. 0. 1. 0.]\n", " [ 0. 0. 0. 0. 1.]\n", " [ 0. 0. 1. 0. 0.]\n", " [ 0. 0. 0. 0. 1.]\n", " [ 1. 0. 0. 0. 
0.]\n", " [-1. -1. -1. -1. -1.]\n", " [-1. -1. -1. -1. -1.]]\n", "\n", " [[ 0. 1. 0. 0. 0.]\n", " [ 0. 0. 0. 0. 1.]\n", " [ 0. 0. 1. 0. 0.]\n", " [ 0. 0. 0. 1. 0.]\n", " [-1. -1. -1. -1. -1.]\n", " [-1. -1. -1. -1. -1.]\n", " [-1. -1. -1. -1. -1.]\n", " [-1. -1. -1. -1. -1.]\n", " [-1. -1. -1. -1. -1.]\n", " [-1. -1. -1. -1. -1.]]\n", "\n", " [[ 0. 1. 0. 0. 0.]\n", " [ 0. 0. 1. 0. 0.]\n", " [ 0. 1. 0. 0. 0.]\n", " [ 0. 0. 0. 1. 0.]\n", " [ 1. 0. 0. 0. 0.]\n", " [-1. -1. -1. -1. -1.]\n", " [-1. -1. -1. -1. -1.]\n", " [-1. -1. -1. -1. -1.]\n", " [-1. -1. -1. -1. -1.]\n", " [-1. -1. -1. -1. -1.]]], shape=(3, 10, 5), dtype=float32)\n", "tf.Tensor([0. 0. 0.], shape=(3,), dtype=float32)\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "Z_vpYQcU-GrE", "colab_type": "code", "colab": {} }, "source": [ "# Because sequences are padded, we can add a Masking layer to mask the operations\n", "# of all subsequent layers, stopping the RNN update as soon as a padding value is encountered.\n", "rnn_full_masked = Sequential([\n", " layers.Masking(mask_value=-1, input_shape=(max_seq_len, n_max)),\n", " layers.SimpleRNN(units),\n", " layers.Dense(n_max, activation='softmax')\n", "])" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "0FfTyD__4i93", "colab_type": "code", "colab": {} }, "source": [ "rnn_full_masked.compile(loss=losses.SparseCategoricalCrossentropy(), optimizer=optimizers.RMSprop(0.001), metrics=['accuracy'])" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "B_P05FQw-Z56", "colab_type": "code", "outputId": "52d85531-f46d-4d0b-cb7f-e2018176e9d0", "colab": { "base_uri": "https://localhost:8080/", "height": 260 } }, "source": [ "rnn_full_masked.summary()" ], "execution_count": 25, "outputs": [ { "output_type": "stream", "text": [ "Model: \"sequential_1\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", 
"=================================================================\n", "masking (Masking) (None, 10, 5) 0 \n", "_________________________________________________________________\n", "simple_rnn_2 (SimpleRNN) (None, 100) 10600 \n", "_________________________________________________________________\n", "dense_1 (Dense) (None, 5) 505 \n", "=================================================================\n", "Total params: 11,105\n", "Trainable params: 11,105\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "ctCqPd56-dF6", "colab_type": "code", "colab": {} }, "source": [ "train_it = dataset.padded_batch(32, ([max_seq_len, n_max],[]), padding_values=(-1.0, -1.0))\n", "history = rnn_full_masked.fit_generator(train_it, steps_per_epoch=100, epochs=10)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "er4EUTXG08xa", "colab_type": "text" }, "source": [ "### Example 3: sorting with RNNs\n", "\n", "In this case, we consider a sequence-to-sequence problem, where the output sequence should be a sorted version of the input sequence.\n", "\n", "To solve the problem, we implement an encoder/decoder architecture. When training, we provide the decoder with the real output sequence (**teacher forcing**). During inference, we instead use the predicted values (autoregressive mode)." 
] }, { "cell_type": "code", "metadata": { "id": "nXbtZYHz_hcG", "colab_type": "code", "colab": {} }, "source": [ "# The generator yields three elements: the encoder input, the decoder input, and the target sequence.\n", "def gen_sequences(batch_size):\n", "  while True:\n", "\n", "    # Define the input sequence (similar to before)\n", "    x = tf.random.uniform(shape=(batch_size, max_seq_len,), maxval=n_max, dtype=tf.int32)\n", "    x_onehot = tf.cast(tf.one_hot(x, depth=n_max), tf.float32)\n", "\n", "    # This time, the output is a sorted version of the input sequence\n", "    y = tf.sort(x, axis=1)\n", "    y_onehot = tf.cast(tf.one_hot(y, depth=n_max), tf.float32)\n", "\n", "    # The input for the decoder is the target sequence shifted right by one step,\n", "    # with an all-zero vector as the first element\n", "    x_dec = tf.concat([tf.zeros((batch_size, 1, n_max), dtype=tf.float32), y_onehot[:, 0:-1, :]], axis=1)\n", "\n", "    yield [x_onehot, x_dec], y_onehot" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "7y1ZCqhZACtK", "colab_type": "code", "outputId": "278cd0df-7689-46ad-cf1a-29b5fe9f4b0f", "colab": { "base_uri": "https://localhost:8080/", "height": 589 } }, "source": [ "for xb, yb in gen_sequences(1):\n", "  print(xb[0])\n", "  print(xb[1])\n", "  print(yb)\n", "  break" ], "execution_count": 30, "outputs": [ { "output_type": "stream", "text": [ "tf.Tensor(\n", "[[[0. 1. 0. 0. 0.]\n", " [0. 0. 0. 0. 1.]\n", " [0. 1. 0. 0. 0.]\n", " [0. 0. 0. 0. 1.]\n", " [0. 1. 0. 0. 0.]\n", " [0. 0. 0. 0. 1.]\n", " [0. 0. 0. 1. 0.]\n", " [0. 0. 1. 0. 0.]\n", " [0. 1. 0. 0. 0.]\n", " [0. 0. 1. 0. 0.]]], shape=(1, 10, 5), dtype=float32)\n", "tf.Tensor(\n", "[[[0. 0. 0. 0. 0.]\n", " [0. 1. 0. 0. 0.]\n", " [0. 1. 0. 0. 0.]\n", " [0. 1. 0. 0. 0.]\n", " [0. 1. 0. 0. 0.]\n", " [0. 0. 1. 0. 0.]\n", " [0. 0. 1. 0. 0.]\n", " [0. 0. 0. 1. 0.]\n", " [0. 0. 0. 0. 1.]\n", " [0. 0. 0. 0. 1.]]], shape=(1, 10, 5), dtype=float32)\n", "tf.Tensor(\n", "[[[0. 1. 0. 0. 0.]\n", " [0. 1. 0. 0. 0.]\n", " [0. 1. 0. 0. 0.]\n", " [0. 1. 0. 
0. 0.]\n", " [0. 0. 1. 0. 0.]\n", " [0. 0. 1. 0. 0.]\n", " [0. 0. 0. 1. 0.]\n", " [0. 0. 0. 0. 1.]\n", " [0. 0. 0. 0. 1.]\n", " [0. 0. 0. 0. 1.]]], shape=(1, 10, 5), dtype=float32)\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "GLtKLJlfAXec", "colab_type": "code", "colab": {} }, "source": [ "# This time we define everything with the Keras functional API, which makes it easier to handle two inputs.\n", "# This part is the encoder. For variety, we use two layers of bidirectional GRUs instead of the SimpleRNN.\n", "\n", "# Input tensor\n", "input_seq = layers.Input(batch_shape=(None, max_seq_len, n_max))\n", "\n", "# First layer of GRUs. Note the return_sequences keyword: we need all states (not just the last one)\n", "# as input to the next layer.\n", "h = layers.Bidirectional(layers.GRU(units, return_sequences=True))(input_seq)\n", "\n", "# Output state vector\n", "out_state = layers.Bidirectional(layers.GRU(units))(h)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "r2izVQwBOBWE", "colab_type": "code", "outputId": "e17f4cdb-0bd2-498b-eaa2-1cf41f494368", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "# The size is 2*units = 200 because the last GRU is bidirectional.\n", "out_state.shape" ], "execution_count": 32, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "TensorShape([None, 200])" ] }, "metadata": { "tags": [] }, "execution_count": 32 } ] }, { "cell_type": "code", "metadata": { "id": "SqVMqc1dbeFn", "colab_type": "code", "colab": {} }, "source": [ "# During decoding, we feed the encoder's final state vector to the decoder at every time step.\n", "out_state = layers.RepeatVector(max_seq_len)(out_state)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "bdJgb4GNOrp7", "colab_type": "code", "colab": {} }, "source": [ "# The input to the decoder is the encoder's state plus the previous output symbol (during training, the true one).\n", "input_dec = 
layers.Input(batch_shape=(None, max_seq_len, n_max))\n", "out_state = layers.Concatenate()([input_dec, out_state])" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "KTh1UV4Gbpxd", "colab_type": "code", "outputId": "00b25bb1-9062-484d-85f7-b2546772383a", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "out_state.shape" ], "execution_count": 35, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "TensorShape([None, 10, 205])" ] }, "metadata": { "tags": [] }, "execution_count": 35 } ] }, { "cell_type": "code", "metadata": { "id": "PsWppJGYDgXP", "colab_type": "code", "colab": {} }, "source": [ "# The decoder is another GRU.\n", "# Note: out_state already contains input_dec (205 features, see the shape above), so this extra\n", "# concat feeds input_dec twice (210 features). Passing out_state directly would be enough.\n", "dec_state = layers.GRU(units, return_sequences=True)(tf.concat([input_dec, out_state], axis=2))" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "BcOtGK1YDrMT", "colab_type": "code", "colab": {} }, "source": [ "# We make one prediction for each time step with a TimeDistributed layer.\n", "ypred = layers.TimeDistributed(layers.Dense(n_max, activation='softmax'), batch_input_shape=(None, max_seq_len, n_max))(dec_state)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "LnZP3Ou9D5o7", "colab_type": "code", "colab": {} }, "source": [ "from tensorflow.keras import Model\n", "model = Model(inputs=[input_seq, input_dec], outputs=[ypred])" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Lrw1qSDKLBXJ", "colab_type": "code", "outputId": "0f36f960-e2ea-4d89-d451-bbdec57e7e95", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "model.predict(xb).shape" ], "execution_count": 39, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(1, 10, 5)" ] }, "metadata": { "tags": [] }, "execution_count": 39 } ] }, { "cell_type": "code", "metadata": { "id": "zjHZCuOXELbz", "colab_type": "code", "colab": {} }, 
"source": [ "model.compile(loss=losses.CategoricalCrossentropy(), optimizer=optimizers.RMSprop(0.001), metrics=['accuracy'])" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "4KkiVw8rJX1-", "colab_type": "code", "outputId": "7d6d906d-acd1-4269-8e21-60b7112b641c", "colab": { "base_uri": "https://localhost:8080/", "height": 503 } }, "source": [ "model.summary()" ], "execution_count": 41, "outputs": [ { "output_type": "stream", "text": [ "Model: \"model\"\n", "__________________________________________________________________________________________________\n", "Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", "input_1 (InputLayer) [(None, 10, 5)] 0 \n", "__________________________________________________________________________________________________\n", "bidirectional (Bidirectional) (None, 10, 200) 64200 input_1[0][0] \n", "__________________________________________________________________________________________________\n", "bidirectional_1 (Bidirectional) (None, 200) 181200 bidirectional[0][0] \n", "__________________________________________________________________________________________________\n", "input_2 (InputLayer) [(None, 10, 5)] 0 \n", "__________________________________________________________________________________________________\n", "repeat_vector (RepeatVector) (None, 10, 200) 0 bidirectional_1[0][0] \n", "__________________________________________________________________________________________________\n", "concatenate (Concatenate) (None, 10, 205) 0 input_2[0][0] \n", " repeat_vector[0][0] \n", "__________________________________________________________________________________________________\n", "tf_op_layer_concat (TensorFlowO [(None, 10, 210)] 0 input_2[0][0] \n", " concatenate[0][0] \n", "__________________________________________________________________________________________________\n", "gru_2 
(GRU) (None, 10, 100) 93600 tf_op_layer_concat[0][0] \n", "__________________________________________________________________________________________________\n", "time_distributed (TimeDistribut (None, 10, 5) 505 gru_2[0][0] \n", "==================================================================================================\n", "Total params: 339,505\n", "Trainable params: 339,505\n", "Non-trainable params: 0\n", "__________________________________________________________________________________________________\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "xQiJ0KdNJZfl", "colab_type": "code", "outputId": "cf0a0405-13fe-4937-e757-93b6a7c2e0cd", "colab": { "base_uri": "https://localhost:8080/", "height": 364 } }, "source": [ "history = model.fit_generator(gen_sequences(32), steps_per_epoch=100, epochs=10)" ], "execution_count": 42, "outputs": [ { "output_type": "stream", "text": [ "Epoch 1/10\n", "100/100 [==============================] - 10s 98ms/step - loss: 0.5363 - accuracy: 0.8075\n", "Epoch 2/10\n", "100/100 [==============================] - 10s 99ms/step - loss: 0.1849 - accuracy: 0.9456\n", "Epoch 3/10\n", "100/100 [==============================] - 10s 104ms/step - loss: 0.1179 - accuracy: 0.9677\n", "Epoch 4/10\n", "100/100 [==============================] - 10s 103ms/step - loss: 0.0765 - accuracy: 0.9818\n", "Epoch 5/10\n", "100/100 [==============================] - 10s 104ms/step - loss: 0.0638 - accuracy: 0.9843\n", "Epoch 6/10\n", "100/100 [==============================] - 11s 105ms/step - loss: 0.0497 - accuracy: 0.9882\n", "Epoch 7/10\n", "100/100 [==============================] - 11s 106ms/step - loss: 0.0492 - accuracy: 0.9861\n", "Epoch 8/10\n", "100/100 [==============================] - 11s 108ms/step - loss: 0.0386 - accuracy: 0.9910\n", "Epoch 9/10\n", "100/100 [==============================] - 11s 108ms/step - loss: 0.0348 - accuracy: 0.9905\n", "Epoch 10/10\n", "100/100 [==============================] - 
11s 109ms/step - loss: 0.0412 - accuracy: 0.9902\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "R9KBvJzwPc37", "colab_type": "code", "outputId": "bf86f64f-97ce-498b-e414-2944710ff2f4", "colab": { "base_uri": "https://localhost:8080/", "height": 287 } }, "source": [ "import matplotlib.pyplot as plt\n", "plt.plot(history.history['accuracy'])" ], "execution_count": 43, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[]" ] }, "metadata": { "tags": [] }, "execution_count": 43 }, { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3df3xcdZ3v8dcnv5MmaZMmLTVp2gKl\nWAr2R7agLMLCFWt1BcGr7aKIV2G9CruruLuw66Jb5bp7Hz5014fIbtHCgij2Fn/0Km7XFRDlgttM\nW1paLJTQSZMCTcmkTdo0aTKf+8ecpNOQNtN2kjOZeT8fj3nMOd9zzjffM9DznvM9Z77H3B0REck9\neWE3QEREwqEAEBHJUQoAEZEcpQAQEclRCgARkRylABARyVEpBYCZrTGzfWb2/AmWm5l908x2mdlW\nM1uctOxjZvZS8PpYUvkSM9sWbPNNM7Mz3x0REUlVqmcADwDLTrL8PcDc4HULcC+AmVUDXwQuBpYC\nXzSzqmCbe4Gbk7Y7Wf0iIpJmKQWAuz8FdJxklWuABz3hWWCKmc0A3g380t073D0G/BJYFiyrdPdn\nPfFLtAeBa89oT0RE5JQUpKmeOmBP0nxrUHay8tYRyk+qpqbGZ8+efaZtFRHJKZFIZL+71w4vT1cA\njBkzu4VEtxINDQ00NTWF3CIRkYnFzKIjlafrLqA2YGbSfH1QdrLy+hHK38TdV7t7o7s31ta+KcBE\nROQ0pSsA1gM3BncDXQIccPdXgQ3A1WZWFVz8vRrYECw7aGaXBHf/3Aj8NE1tERGRFKTUBWRmPwCu\nAGrMrJXEnT2FAO7+L8BjwHJgF3AY+HiwrMPMvgxsDKpa5e6DF5M/TeLuolLgF8FLRETGiU2k4aAb\nGxtd1wBERE6NmUXcvXF4uX4JLCKSoxQAIiI5SgEgIpKjMv53ACKSG/r647R0HGLXvkM07++m92ic\ngjwjP98S73l55Bvk5+cF8zb0fmx65GWJ+bzEe35QZnbc/NDyYXVm8zBlCgARGVedh/t4ub2bl/cd\n4uX9iffm9m6iHYcZiGfmTSl5Bnlm5JlhQ9PBe96xaUsuNxLzeSfedsT1R/pbefC/P/g26qaUpnW/\nFAAiknYDcact1pM40A++9h3i5fZu3jjUN7ReUX4ec2omcf6MCt570QzOqS3nnNpy5tROoqwwnwF3\nBuJOfzzxnpiOJ94HgrLBdQaOX35s/ZPUkVR2bD5Of9yJx52jA44D7k7cnbiTeI8fm/bBsmC5uxOP\nc/y8j7K+j7B+UMfAYDvG4I5NBYBISA739dMW66G1s4fWWA+He/uZXFo49KocnC4rpKK4ICO7Ig73\n9dPcfi
g4wHfzcjDdvP8Qff3xofWqJxVxTu0k3jV/euIgP20S59SWU19VRn7eifcrD6Mwfzz2JDcp\nAETGyIGeo4kDfOwwbZ09wXRPYrqzh46kb8KjyTOOBcKwgJgyrDzd4eHu7OvqDQ7wxw7yL+/rZu+B\nI8e1saG6jHNqy3nnebWcU5s4yJ9dW071pKLT/vsydhQAIqfB3ek41Edb8O29LTiwt8YOD8139fYf\nt01xQR71VaXUVZWxoG4y9VWlifkppdRXlTGpOJ+DR/o5cPgoB3oSr4M9x6YHX53Be2usZ6jsZH3n\nJwqPkYJjcmkhB48cTRzkkw743Un7Mqkon3OmlXPx2VOHDvLnTCtn1tQyigv0dX0iUQCIjCAed9q7\ne48d0Icd6NtiPfQcHThum/LigqED+sVzqqmrKqVuSllw0C9l6qSiUb+JV5QUnvKFPnfnUN9AIgxO\nEh6dSdOphMeMySWcU1vO9YvrOGda+VD//PTK4ozsjpJTpwCQnNPTN0DH4T46uvvoONzH/q7eoW/v\ngwf3vZ1H6BuIH7ddVVkhdVWlnFtbzuXn1Qbf3BMH9/opZVSWhtNPb2aUFxdQXlyQlvCYVJzP2bXl\nlBfr8JDt9F9YJrS+/jidhxMH8o5DfcQOHaXjcB+xQ4n5jkN9xIaWJdY7cjQ+Yl21FcXUV5WyoG4y\n715wFvVVZdRPKQ2+yZcyKQsPiGcSHjLxZd//0TJhxePOgZ6jvDHCQTtxQD86VD64bHg/e7KKkgKq\nJxVRVVbEWZUlvHVG5dB89aTC4L2IqeXFzJhcQoluN5EcowCQcRePO4//fh8/3tzGvq4jwbf0o3Qe\n7uNE1zKLC/KYOqmI6vLEAXzW1LKhA3jVpCKqg+nEfCFTSosoKtBIJyInowCQcXPk6AA/2tTGd37b\nTHP7IaZVFHN27STmnVVx7GBeVsTU8qI3HdxLi/TtXCTdFAAy5t7o7uWhZ6M89EyUNw71saCukn9e\nsZDlF86gMF/f0kXCogCQMdPc3s13f/sK6yKt9PbHufL8adx82dlccna1biMUyQAKAEkrd6cpGmP1\nU8385wuvU5iXx3WL6/jkZXM4d1pF2M0TkSSpPhN4GfDPQD7wHXf/h2HLZwFrgFqgA/iIu7ea2R8B\n30ha9Xxghbv/xMweAC4HDgTLbnL3LWeyMxKe/oE4G7a/zurfNPPcnk6mlBVy6x+dy0ffPotpFSVh\nN09ERjBqAJhZPnAP8C6gFdhoZuvdfUfSal8DHnT3fzOzK4GvAh919yeAhUE91SQeGv8fSdv9pbuv\nS8+uSBgO9faztmkPa55+hT0dPcyaWsaXr7mA65fUU1akE0yRTJbKv9ClwC53bwYws0eAa4DkAJgP\nfC6YfgL4yQj1fBD4hbsfPv3mSqbYd/AID/y/3Xzv2SgHj/SzZFYVf7t8Pu+aP/2kozuKSOZIJQDq\ngD1J863AxcPWeQ64jkQ30QeACjOb6u5vJK2zAvj6sO3uNrO7gF8Bd7h77/A/bma3ALcANDQ0pNBc\nGUs7X+vivt8089MtbfTHnWUXnMUnLzubJbOqwm6aiJyidJ2jfx74lpndBDwFtAFDI2WZ2QzgQmBD\n0jZ3Aq8BRcBq4K+BVcMrdvfVwXIaGxsz83FBWc7deXrXG9z3m2Z+/WI7pYX5rFzawP+4dA6zayaF\n3TwROU2pBEAbMDNpvj4oG+Lue0mcAWBm5cD17t6ZtMqHgB+7+9GkbV4NJnvN7H4SISIZ5OhAnJ9t\n3cvqp17hhVcPUlNezOevPo8bLp5FlcZ3F5nwUgmAjcBcM5tD4sC/AviT5BXMrAbocPc4iW/2a4bV\nsTIoT95mhru/aokbwq8Fnj+9XZB0O3jkKD/4XQv3P72b1w4e4dxp5fzj9RdyzcI6jZcjkkVGDQB3\n7zezW0l03+QDa9x9u5mtAprcfT1wBfBVM3MSXUCfGdzezGaTOIP49bCq
HzazWsCALcCnznhv5Iy0\nxg5z/9O7+eHGPXT39vP2s6fy1esu5PLzasnThV2RrGNj8aDhsdLY2OhNTU1hNyPrbGs9wH2/aebn\n2xK9cu+7aAY3X3Y2C+omh9wyEUkHM4u4e+Pwct2onaPicefJF/ex+qlmnm3uoLy4gI+/YzYf/8M5\nGhdeJEcoAHJMPO48uqmVf32qmV37ujmrsoS/WX4+K5Y2UFlSGHbzRGQcKQByyGsHjvD5//Mcv921\nn/kzKvmnDy/kvRdpRE6RXKUAyBGPbXuVO3+0jb7+OP/rAxeyculMjcgpkuMUAFmu68hRvrR+B49u\nauVt9ZP5xocXcnZtedjNEpEMoADIYk27O/js2i20xXr4syvP5bar5qq7R0SGKACy0NGBON/81Uvc\n88Qu6qpKWfunb6dxdnXYzRKRDKMAyDLN7d38xQ+3sLX1AB9cUs8X/3g+Fbq7R0RGoADIEu7O9/+r\nha/87AWKCvL49g2LWX7hjLCbJSIZTAGQBfZ393LHo1v5zxf28Yfn1vC1//42zpqsp3CJyMkpACa4\nx3//On+1bisHj/Tzd++bz8ffMVvj9ohIShQAE1RP3wB3P7aD7z3bwvlnVfC9T17M+WdVht0sEZlA\nFAAT0LbWA/z5DzfT3H6Imy+bw+1Xz9MwzSJyyhQAE8hA3PmXX7/MN375IjXlxTz8yYu59NyasJsl\nIhOUAmCC2NNxmM+t3cLG3THee9EM7r52AVPK9FQuETl9CoAM5+78eHMbd/10OwBf/9Db+MCiOo3j\nIyJnTAGQwToP9/G3P3men299lT+YXcXXP7SQmdVlYTdLRLKEAiBDPb1rP7evfY793b385bvn8anL\nzyFft3eKSBqlNDKYmS0zs51mtsvM7hhh+Swz+5WZbTWzJ82sPmnZgJltCV7rk8rnmNnvgjp/aGbq\n0AZ6+we4++c7uOE7v6OsOJ8ff/pSPvNH5+rgLyJpN2oAmFk+cA/wHmA+sNLM5g9b7WvAg+5+EbAK\n+GrSsh53Xxi83p9U/o/AN9z9XCAGfOIM9iMr7Hyti2u+9TT3/eYVPnrJLH5+22VcWK/n8orI2Ejl\nDGApsMvdm929D3gEuGbYOvOBx4PpJ0ZYfhxLXMG8ElgXFP0bcG2qjc428bjz3d++wh9/67fs7+5l\nzU2NfPnaBZQW6d5+ERk7qQRAHbAnab41KEv2HHBdMP0BoMLMpgbzJWbWZGbPmtngQX4q0Onu/Sep\nEwAzuyXYvqm9vT2F5k4srx04wo1r/osv/2wH75xbw7//xTu58vzpYTdLRHJAui4Cfx74lpndBDwF\ntAEDwbJZ7t5mZmcDj5vZNuBAqhW7+2pgNUBjY6Onqb0ZQY9pFJEwpRIAbcDMpPn6oGyIu+8lOAMw\ns3LgenfvDJa1Be/NZvYksAh4FJhiZgXBWcCb6sxmXUeO8vf/dwfrInpMo4iEJ5UuoI3A3OCunSJg\nBbA+eQUzqzGzwbruBNYE5VVmVjy4DnApsMPdncS1gg8G23wM+OmZ7sxE0LS7g+Xf/A0/2tTKbVee\ny7r/+Q4d/EUkFKMGQPAN/VZgA/ACsNbdt5vZKjMbvKvnCmCnmb0ITAfuDsrfCjSZ2XMkDvj/4O47\ngmV/DXzOzHaRuCbw3TTtU8a676lmPvSvzwCw9k/fzu1Xz9MzekUkNJb4Mj4xNDY2elNTU9jNOC19\n/XEWfGkDl5w9lXv+ZJEe0ygi48bMIu7eOLxcXz/Hyfa9B+jrj7PyD2bq4C8iGUEBME4i0RgAS2ZV\nhdwSEZEEBcA4iURjzKwuZVqlntUrIplBATAO3J2maIwlDfr2LyKZQwEwDlpjPbR39ar7R0QyigJg\nHGxqSfT/L1YAiEgGUQCMg6bdMSYV5XP+WZVhN0VEZIgCYBxEojEWNVRpTH8RySgKgDHW3dvP7187\nqO4fEck4CoAx9tyeTuKu+/9FJPMo
AMZYJBrDDBY1TAm7KSIix1EAjLGmaIx50yuo1PAPIpJhFABj\nKB53Nkdj6v8XkYykABhDL+3rpqu3X78AFpGMpAAYQxoATkQymQJgDEWiMaZOKmLW1LKwmyIi8iYK\ngDEUiXawZFaVHvQuIhlJATBG9nf3svuNw+r+EZGMlVIAmNkyM9tpZrvM7I4Rls8ys1+Z2VYze9LM\n6oPyhWb2jJltD5Z9OGmbB8zsFTPbErwWpm+3wrdJ/f8ikuFGDQAzywfuAd4DzAdWmtn8Yat9DXjQ\n3S8CVgFfDcoPAze6+wXAMuCfzCz5F1F/6e4Lg9eWM9yXjBJpiVGYbyyomxx2U0RERpTKGcBSYJe7\nN7t7H/AIcM2wdeYDjwfTTwwud/cX3f2lYHovsA+oTUfDM92maIwFdZMpKcwPuykiIiNKJQDqgD1J\n861BWbLngOuC6Q8AFWY2NXkFM1sKFAEvJxXfHXQNfcPMikf642Z2i5k1mVlTe3t7Cs0NX2//AM+1\nHqBR3T8iksHSdRH488DlZrYZuBxoAwYGF5rZDOAh4OPuHg+K7wTOB/4AqAb+eqSK3X21uze6e2Nt\n7cQ4edi+9yB9/XH1/4tIRitIYZ02YGbSfH1QNiTo3rkOwMzKgevdvTOYrwR+Dvytuz+btM2rwWSv\nmd1PIkSywuAF4MX6BbCIZLBUzgA2AnPNbI6ZFQErgPXJK5hZjZkN1nUnsCYoLwJ+TOIC8bph28wI\n3g24Fnj+THYkk0SiMWZWlzKtsiTspoiInNCoAeDu/cCtwAbgBWCtu283s1Vm9v5gtSuAnWb2IjAd\nuDso/xDwTuCmEW73fNjMtgHbgBrgK+naqTC5O03RmMb/EZGMl0oXEO7+GPDYsLK7kqbXAetG2O57\nwPdOUOeVp9TSCaI11kN7Vy9LZleH3RQRkZPSL4HTbGgAOJ0BiEiGUwCkWSQaY1JRPvPOqgi7KSIi\nJ6UASLNINMaihiry8zQAnIhkNgVAGnX39vP71w7qCWAiMiEoANJoS0sncdcAcCIyMSgA0igSjWEG\nixqmjL6yiEjIFABpFGmJMW96BZUlhWE3RURkVAqANInHnc3RmPr/RWTCUACkyUv7uunq7df9/yIy\nYSgA0iSiJ4CJyASjAEiTpmgHNeVFzJpaFnZTRERSogBIk03RGIsbqkgMbioikvkUAGmwv7uX3W8c\nVvePiEwoCoA02KT+fxGZgBQAaRBpiVGUn8eCuslhN0VEJGUKgDSI7I6xoK6SksL8sJsiIpIyBcAZ\n6u0fYGvbAXX/iMiEowA4Q9v3HqSvP64AEJEJJ6UAMLNlZrbTzHaZ2R0jLJ9lZr8ys61m9qSZ1Sct\n+5iZvRS8PpZUvsTMtgV1ftMm6P2TgxeAF+sXwCIywYwaAGaWD9wDvAeYD6w0s/nDVvsa8KC7XwSs\nAr4abFsNfBG4GFgKfNHMBo+U9wI3A3OD17Iz3psQRKIxZlaXMq2yJOymiIicklTOAJYCu9y92d37\ngEeAa4atMx94PJh+Imn5u4FfunuHu8eAXwLLzGwGUOnuz7q7Aw8C157hvow7d6cpGqNxlh4ALyIT\nTyoBUAfsSZpvDcqSPQdcF0x/AKgws6kn2bYumD5ZnQCY2S1m1mRmTe3t7Sk0d/y0xnpo7+rVCKAi\nMiGl6yLw54HLzWwzcDnQBgyko2J3X+3uje7eWFtbm44q02ZoADj1/4vIBFSQwjptwMyk+fqgbIi7\n7yU4AzCzcuB6d+80szbgimHbPhlsXz+s/Lg6J4JINMakonzmnVURdlNERE5ZKmcAG4G5ZjbHzIqA\nFcD65BXMrMbMBuu6E1gTTG8ArjazquDi79XABnd/FThoZpcEd//cCPw0DfszriLRGIsaqsjPm5A3\nMIlIjhs1ANy9H7iVxMH8BWCtu283s1Vm9v5gtSuAnWb2IjAduDvYtgP4MokQ2QisCsoAPg18B9gF\n
vAz8Il07NR66e/v5/WsHdf+/iExYqXQB4e6PAY8NK7sraXodsO4E267h2BlBcnkTsOBUGptJtrR0\nEncNACciE5d+CXyaItEYZrCwYUrYTREROS0KgNMUaYkxb3oFlSWFYTdFROS0KABOQzzubI7GdP+/\niExoCoDT8OK+Lrp6+2lUAIjIBKYAOA0RPQFMRLKAAuA0RKIxasqLaKguC7spIiKnTQFwGjZFYyxu\nqGKCjmAtIgIoAE7Z/u5edr9xWN0/IjLhKQBO0WD/f+NsBYCITGwKgFO0KRqjKD+PC94yOeymiIic\nEQXAKYpEYyyoq6SkMD/spoiInBEFwCno7R9ga9sB9f+LSFZQAJyC7XsP0tcfVwCISFZQAJyCyO7E\nBWANASEi2UABcAoi0RgN1WVMqygJuykiImdMAZAidyfSElP3j4hkDQVAilpjPbR39ar7R0SyhgIg\nRUMDwDUoAEQkO6QUAGa2zMx2mtkuM7tjhOUNZvaEmW02s61mtjwov8HMtiS94ma2MFj2ZFDn4LJp\n6d219GqKdlBeXMC8syrCboqISFqM+kxgM8sH7gHeBbQCG81svbvvSFrtCyQeFn+vmc0n8fzg2e7+\nMPBwUM+FwE/cfUvSdjcEzwbOeJFoJ4sappCfpwHgRCQ7pHIGsBTY5e7N7t4HPAJcM2wdByqD6cnA\n3hHqWRlsO+F0HTnKztcOsljdPyKSRVIJgDpgT9J8a1CW7EvAR8yslcS3/9tGqOfDwA+Gld0fdP/8\nnZ1gbGUzu8XMmsysqb29PYXmpt9zew4Qdz0ARkSyS7ouAq8EHnD3emA58JCZDdVtZhcDh939+aRt\nbnD3C4HLgtdHR6rY3Ve7e6O7N9bW1qapuacmEo1hBgsbpoTy90VExkIqAdAGzEyarw/Kkn0CWAvg\n7s8AJUBN0vIVDPv27+5twXsX8H0SXU0ZqSnawbzpFVSWFIbdFBGRtEklADYCc81sjpkVkTiYrx+2\nTgtwFYCZvZVEALQH83nAh0jq/zezAjOrCaYLgfcBz5OBBuLOlpZOdf+ISNYZ9S4gd+83s1uBDUA+\nsMbdt5vZKqDJ3dcDtwP3mdlnSVwQvsndPajincAed29OqrYY2BAc/POB/wTuS9tepdFL+7ro6u1X\nAIhI1hk1AADc/TESF3eTy+5Kmt4BXHqCbZ8ELhlWdghYcoptDcXQD8AUACKSZfRL4FFEojFqyoto\nqC4LuykiImmlABhFJBpjcUMVJ7hLVURkwlIAnER7Vy/RNw7rAfAikpUUACexqUX9/yKSvRQAJ7Ep\nGqMoP48L3jI57KaIiKSdAuAkItEYC+oqKSnMD7spIiJppwA4gd7+Aba2HVD3j4hkLQXACTzfdpC+\n/jhLZlWH3RQRkTGhADiBTcEPwBbP0gBwIpKdFAAnEInGaKguY1pFSdhNEREZEwqAEbg7kZaY+v9F\nJKspAEawp6OH9q5eFisARCSLKQBGEGnpAKBRASAiWUwBMIJINEZ5cQHnTa8IuykiImNGATCCSLST\nRQ1TyM/TAHAikr0UAMN0HTnKztcOsrhB3T8ikt0UAMNs2dNJ3DUAnIhkPwXAMJFoDDNY1KAfgIlI\ndkspAMxsmZntNLNdZnbHCMsbzOwJM9tsZlvNbHlQPtvMesxsS/D6l6RtlpjZtqDOb1qGPHElEo0x\nb3oFFSWFYTdFRGRMjRoAZpYP3AO8B5gPrDSz+cNW+wKw1t0XASuAbycte9ndFwavTyWV3wvcDMwN\nXstOfzfSYyDubGnpVPePiOSEVM4AlgK73L3Z3fuAR4Brhq3jQGUwPRnYe7IKzWwGUOnuz7q7Aw8C\n155Sy8fAS/u66OrtVwCISE5IJQDqgD1J861BWbIvAR8xs1bgMeC2pGVzgq6hX5vZZUl1to5SJwBm\ndouZNZlZU3t7ewrNPX1Nu/UEMBHJHem6CLwSeMDd64HlwENmlg
e8CjQEXUOfA75vZpUnqedN3H21\nuze6e2NtbW2amjuyTdEYNeVFNFSXjenfERHJBAUprNMGzEyarw/Kkn2CoA/f3Z8xsxKgxt33Ab1B\necTMXgbOC7avH6XOcTc4AFyGXI8WERlTqZwBbATmmtkcMysicZF3/bB1WoCrAMzsrUAJ0G5mtcFF\nZMzsbBIXe5vd/VXgoJldEtz9cyPw07Ts0Wlq7+ol+sZhdf+ISM4Y9QzA3fvN7FZgA5APrHH37Wa2\nCmhy9/XA7cB9ZvZZEheEb3J3N7N3AqvM7CgQBz7l7h1B1Z8GHgBKgV8Er9BsalH/v4jkllS6gHD3\nx0hc3E0uuytpegdw6QjbPQo8eoI6m4AFp9LYsRSJxijKz+OCt0wOuykiIuNCvwQORKIxFtRVUlKY\nH3ZTRETGhQIA6O0fYFvrARpn6wHwIpI7FADA820H6RuIawRQEckpCgAS9/8DLJ6lAeBEJHcoAICm\naAcN1WVMqygJuykiIuMm5wPA3YlENQCciOSenA+APR097O/uVQCISM7J+QCItCR+l6YAEJFcowCI\nxigvLuC86RVhN0VEZFzlfAA07Y6xqGEK+XkaAE5EcktOB0DXkaPsfL1L9/+LSE7K6QDYsqcTd2ic\nrQAQkdyT0wEQicYwg4Uz9QMwEck9OR8A86ZXUFFSGHZTRETGXc4GwEDc2dyiH4CJSO7K2QB48fUu\nunv7FQAikrNyNgAiwQBwjbM0BLSI5KacDYBN0Rg15cXMrC4NuykiIqFIKQDMbJmZ7TSzXWZ2xwjL\nG8zsCTPbbGZbzWx5UP4uM4uY2bbg/cqkbZ4M6twSvKalb7dGF2mJsWTWFBLPpBcRyT2jPhPYzPKB\ne4B3Aa3ARjNbHzwHeNAXgLXufq+ZzSfx/ODZwH7gj919r5ktIPFg+bqk7W4Ing08rtq7eom+cZgb\nLm4Y7z8tIpIxUjkDWArscvdmd+8DHgGuGbaOA5XB9GRgL4C7b3b3vUH5dqDUzIrPvNlnZrD/XxeA\nRSSXpRIAdcCepPlWjv8WD/Al4CNm1kri2/9tI9RzPbDJ3XuTyu4Pun/+zk7QF2Nmt5hZk5k1tbe3\np9Dc0W1qiVGUn8eCuslpqU9EZCJK10XglcAD7l4PLAceMrOhus3sAuAfgT9N2uYGd78QuCx4fXSk\nit19tbs3untjbW1tWhobica4sH4yxQX5aalPRGQiSiUA2oCZSfP1QVmyTwBrAdz9GaAEqAEws3rg\nx8CN7v7y4Abu3ha8dwHfJ9HVNOZ6+wfY1npA3T8ikvNSCYCNwFwzm2NmRcAKYP2wdVqAqwDM7K0k\nAqDdzKYAPwfucPenB1c2swIzGwyIQuB9wPNnujOpeL7tAH0DcY0AKiI5b9QAcPd+4FYSd/C8QOJu\nn+1mtsrM3h+sdjtws5k9B/wAuMndPdjuXOCuYbd7FgMbzGwrsIXEGcV96d65kQxeAF48SwPAiUhu\nG/U2UAB3f4zExd3ksruSpncAl46w3VeAr5yg2iWpNzN9ItEYDdVlTKsoCePPi4hkjJz6JbC7E4l2\n0qj+fxGR3AqAPR097O/uZbECQEQktwKgKdoB6AdgIiKQYwEQicYoLy7gvOkVYTdFRCR0ORcAixqm\nkJ+nAeBERHImALqOHGXn613q/hERCeRMAGzZ04m7+v9FRAblTAA07Y5hBgtn6gdgIiKQQwGwqSXG\nvOkVVJQUht0UEZGMkBMBMBB3Nrd0qvtHRCRJTgTAi6930d3bT+NsBYCIyKCcCIChJ4A1VIfcEhGR\nzJEzAVBTXszM6tKwmyIikjFSGg10ops7vZyzJpdwgqdOiojkpJwIgE9fcW7YTRARyTg50QUkIiJv\npgAQEclRCgARkRyVUgCY2TIz22lmu8zsjhGWN5jZE2a22cy2mtnypGV3BtvtNLN3p1qniIiMrVED\nwMzygXuA9wDzgZVmNn/Yau
9KupcAAAPgSURBVF8g8bD4RcAK4NvBtvOD+QuAZcC3zSw/xTpFRGQM\npXIGsBTY5e7N7t4HPAJcM2wdByqD6cnA3mD6GuARd+9191eAXUF9qdQpIiJjKJUAqAP2JM23BmXJ\nvgR8xMxagceA20bZNpU6ATCzW8ysycya2tvbU2iuiIikIl0XgVcCD7h7PbAceMjM0lK3u69290Z3\nb6ytrU1HlSIiQmo/BGsDZibN1wdlyT5Boo8fd3/GzEqAmlG2Ha3ON4lEIvvNLJpCm0dSA+w/zW2z\nkT6PY/RZHE+fx/Gy4fOYNVJhKgGwEZhrZnNIHKRXAH8ybJ0W4CrgATN7K1ACtAPrge+b2deBtwBz\ngf8CLIU638TdT/sUwMya3L3xdLfPNvo8jtFncTx9HsfL5s9j1ABw934zuxXYAOQDa9x9u5mtAprc\nfT1wO3CfmX2WxAXhm9zdge1mthbYAfQDn3H3AYCR6hyD/RMRkROwxHE6+2Vzip8OfR7H6LM4nj6P\n42Xz55FLvwReHXYDMow+j2P0WRxPn8fxsvbzyJkzABEROV4unQGIiEiSnAgAjTuUYGYzgzGbdpjZ\ndjP787DblAmC4Uk2m9nPwm5L2MxsipmtM7Pfm9kLZvb2sNsUFjP7bPDv5Hkz+0Fwe3tWyfoA0LhD\nx+kHbnf3+cAlwGdy+LNI9ufAC2E3IkP8M/Dv7n4+8DZy9HMxszrgz4BGd19A4m7FFeG2Kv2yPgDQ\nuEND3P1Vd98UTHeR+Mc94hAcucLM6oH3At8Juy1hM7PJwDuB7wK4e5+7d4bbqlAVAKVmVgCUcWyM\ns6yRCwGQ8rhDucTMZgOLgN+F25LQ/RPwV0A87IZkgDkkfsB5f9Al9h0zmxR2o8Lg7m3A10j8yPVV\n4IC7/0e4rUq/XAgAGcbMyoFHgb9w94NhtycsZvY+YJ+7R8JuS4YoABYD9wZDux8CcvKamZlVkegp\nmENiFINJZvaRcFuVfrkQAKmMZZQzzKyQxMH/YXf/UdjtCdmlwPvNbDeJrsErzex74TYpVK1Aq7sP\nnhWuIxEIuei/Aa+4e7u7HwV+BLwj5DalXS4EwNBYRmZWROJCzvqQ2xQKMzMS/bsvuPvXw25P2Nz9\nTnevd/fZJP6/eNzds+5bXqrc/TVgj5nNC4quIjGMSy5qAS4xs7Lg381VZOEF8VQGg5vQTjSWUcjN\nCsulwEeBbWa2JSj7G3d/LMQ2SWa5DXg4+LLUDHw85PaEwt1/Z2brgE0k7p7bTBb+Ili/BBYRyVG5\n0AUkIiIjUACIiOQoBYCISI5SAIiI5CgFgIhIjlIAiIjkKAWAiEiOUgCIiOSo/w9Zp/76T7ei4AAA\nAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "NuayaautPh4S", "colab_type": "code", "colab": {} }, "source": [ "# Test the architecture\n", "xb, yb = next(gen_sequences(1))" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "ZC8v1_n1PjWX", "colab_type": "code", "outputId": "9ecb1992-47ce-4105-a006-27cac15e4e36", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "tf.argmax(xb[0], axis=2)" ], "execution_count": 45, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": { "tags": [] }, "execution_count": 45 } ] }, { "cell_type": "code", "metadata": { "id": "Vo416SulPm_n", "colab_type": "code", "outputId": "6b57f3fa-dbce-4b3f-accd-e40320a66ec6", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "tf.argmax(model.predict(xb), axis=2)" ], "execution_count": 46, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": { "tags": [] }, "execution_count": 46 } ] }, { "cell_type": "code", "metadata": { "id": "E-K9cUsdPvl_", "colab_type": "code", "outputId": "c1d7b74a-6b4b-4f48-f8a5-89b756eda560", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "xb[0].shape" ], "execution_count": 47, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "TensorShape([1, 10, 5])" ] }, "metadata": { "tags": [] }, "execution_count": 47 } ] }, { "cell_type": "code", "metadata": { "id": "hpZunGdycVA4", "colab_type": "code", "colab": {} }, "source": [ "# TODO: replicate the evaluation, without teacher forcing" ], "execution_count": 0, "outputs": [] } ] }