{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Gmail style smart compose with char ngram based language model.ipynb", "version": "0.3.2", "provenance": [], "collapsed_sections": [] }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "metadata": { "id": "axha_JTtHfXL", "colab_type": "text" }, "cell_type": "markdown", "source": [ "This notebook borrows a couple of ideas from the [**Original TensorFlow NMT tutorial**](https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb) . But the main focus of this noteobook is to illustrate the power of Char Ngram based Langauge Models learned using an Encoder-Decoder Model and how it is used to solve real world problems. At the first blush smart compose will look very similar to predictive keyboard. But there is a lot more to smart compose. Please look at the accompanying post for more details\n", "\n", "\n", "**This notebook is tested in tensorflow-gpu=1.13.1**\n", "\n", "\n", "\n" ] }, { "metadata": { "id": "0wRyCFFocMNT", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "outputId": "5fbd6c99-56f6-437c-be9d-f9d746d6c12e" }, "cell_type": "code", "source": [ "# Start by importing all the things we'll need.\n", "%matplotlib inline\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "from tensorflow.keras.models import Model\n", "from tensorflow.keras.layers import Input, LSTM, Dense, Embedding, CuDNNLSTM, Flatten, TimeDistributed, Dropout, LSTMCell, RNN, Bidirectional, Concatenate, Layer\n", "from tensorflow.keras.callbacks import ModelCheckpoint\n", "from tensorflow.python.keras.utils import tf_utils\n", "from tensorflow.keras import backend as K\n", "\n", "import unicodedata\n", "import re\n", "import numpy as np\n", "import os\n", "import time\n", "import shutil\n", "\n", "import pandas as pd\n", "import numpy as np\n", "import string, os \n", "tf.__version__" ], "execution_count": 2, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'1.13.1'" ] }, "metadata": { "tags": [] }, "execution_count": 2 } ] }, { "metadata": { "id": "0KhSQ0slK44g", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "file = open(\"./sample_data/dataset.txt\", 'r')\n", "corpus = [line for line in file]" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "1aiUHlDnK44n", "colab_type": "code", "outputId": "fc9a4c58-8ed3-4365-f61b-8b6fe0ec3e63", "colab": { "base_uri": "https://localhost:8080/", "height": 197 } }, "cell_type": "code", "source": [ "corpus[40:50]" ], "execution_count": 9, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "['I will be there\\n',\n", " 'Resignation Please accept my resignation.\\n',\n", " 'I need a response on this issue\\n',\n", " 'Here is the revised version.\\n',\n", " 'Please accept our offer\\n',\n", " 'this sounds acceptable to us.\\n',\n", " 'Great. 
Thanks.\\n',\n", " 'please find attached a very rough draft\\n',\n", " 'Have a nice weekend\\n',\n", " 'Christmas Party Remembered: Merry Christmas Everyone!\\n']" ] }, "metadata": { "tags": [] }, "execution_count": 9 } ] }, { "metadata": { "id": "106RyC0uMhOq", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "def clean_special_chars(text, punct):\n", "    for p in punct:\n", "        text = text.replace(p, '')\n", "    return text\n", "\n", " \n", "def preprocess(data):\n", "    output = []\n", "    punct = '#$%&*+-/<=>@[\\\\]^_`{|}~\\t\\n'\n", "    for line in data:\n", "        pline = clean_special_chars(line.lower(), punct)\n", "        output.append(pline)\n", "    return output \n", "\n", "\n", "def generate_dataset():\n", "    \n", "    processed_corpus = preprocess(corpus) \n", "    output = []\n", "    for line in processed_corpus:\n", "        token_list = line\n", "        for i in range(1, len(token_list)):\n", "            data = []\n", "            # Character-level prefix/suffix split, wrapped with <start> and <end> markers.\n", "            x_ngram = '<start> ' + token_list[:i+1] + ' <end>'\n", "            y_ngram = '<start> ' + token_list[i+1:] + ' <end>'\n", "            data.append(x_ngram)\n", "            data.append(y_ngram)\n", "            output.append(data)\n", "    print(\"Dataset prepared with prefix and suffixes for teacher forcing technique\")\n", "    dummy_df = pd.DataFrame(output, columns=['input','output'])\n", "    return output, dummy_df \n", "    " ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "Swufi2zlVmvx", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "class LanguageIndex():\n", "    def __init__(self, lang):\n", "        self.lang = lang\n", "        self.word2idx = {}\n", "        self.idx2word = {}\n", "        self.vocab = set()\n", "        self.create_index()\n", "    def create_index(self):\n", "        for phrase in self.lang:\n", "            self.vocab.update(phrase.split(' '))\n", "        self.vocab = sorted(self.vocab)\n", "        self.word2idx[\"<pad>\"] = 0\n", "        self.idx2word[0] = \"<pad>\"\n", "        for i,word in enumerate(self.vocab):\n", "            self.word2idx[word] = i + 1\n", "            self.idx2word[i+1] = word\n", "\n", "def max_length(t):\n", "    return max(len(i) for i in t)\n", "\n", "def load_dataset():\n", "    pairs,df = generate_dataset()\n", "    out_lang = LanguageIndex(sp for en, sp in pairs)\n", "    in_lang = LanguageIndex(en for en, sp in pairs)\n", "    input_data = [[in_lang.word2idx[s] for s in en.split(' ')] for en, sp in pairs]\n", "    output_data = [[out_lang.word2idx[s] for s in sp.split(' ')] for en, sp in pairs]\n", "\n", "    max_length_in, max_length_out = max_length(input_data), max_length(output_data)\n", "    input_data = tf.keras.preprocessing.sequence.pad_sequences(input_data, maxlen=max_length_in, padding=\"post\")\n", "    output_data = tf.keras.preprocessing.sequence.pad_sequences(output_data, maxlen=max_length_out, padding=\"post\")\n", "    return input_data, output_data, in_lang, out_lang, max_length_in, max_length_out, df" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "Sq7Mrs04W20y", "colab_type": "code", "outputId": "89899522-b8a6-4f86-e7ad-e01c082d5a64", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "cell_type": "code", "source": [ "input_data, teacher_data, input_lang, target_lang, len_input, len_target, df = load_dataset()\n", "\n", "\n", "target_data = [[teacher_data[n][i+1] for i in range(len(teacher_data[n])-1)] for n in range(len(teacher_data))]\n", "target_data = tf.keras.preprocessing.sequence.pad_sequences(target_data, maxlen=len_target, padding=\"post\")\n", "target_data = target_data.reshape((target_data.shape[0], target_data.shape[1], 1))\n", "\n", "# Shuffle all of the data in unison. This training set has the longest (e.g. 
most complicated) data at the end,\n", "# so a simple Keras validation split will be problematic if not shuffled.\n", "\n", "p = np.random.permutation(len(input_data))\n", "input_data = input_data[p]\n", "teacher_data = teacher_data[p]\n", "target_data = target_data[p]\n", "\n" ], "execution_count": 60, "outputs": [ { "output_type": "stream", "text": [ "Dataset prepared with prefix and suffixes for teacher forcing technique\n" ], "name": "stdout" } ] }, { "metadata": { "id": "72SmmNQEYgwW", "colab_type": "code", "outputId": "916c269e-ca87-40bf-bc49-d941f7aabcf7", "colab": { "base_uri": "https://localhost:8080/", "height": 204 } }, "cell_type": "code", "source": [ "pd.set_option('display.max_colwidth', -1)\n", "BUFFER_SIZE = len(input_data)\n", "BATCH_SIZE = 128\n", "embedding_dim = 300\n", "units = 128\n", "vocab_in_size = len(input_lang.word2idx)\n", "vocab_out_size = len(target_lang.word2idx)\n", "df.iloc[60:65]" ], "execution_count": 61, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
inputoutput
60<start> thank you fo <end><start> r your cooperation <end>
61<start> thank you for <end><start> your cooperation <end>
62<start> thank you for <end><start> your cooperation <end>
63<start> thank you for y <end><start> our cooperation <end>
64<start> thank you for yo <end><start> ur cooperation <end>
\n", "
" ], "text/plain": [ " input output\n", "60 thank you fo r your cooperation \n", "61 thank you for your cooperation \n", "62 thank you for your cooperation \n", "63 thank you for y our cooperation \n", "64 thank you for yo ur cooperation " ] }, "metadata": { "tags": [] }, "execution_count": 61 } ] }, { "metadata": { "id": "xmZbR9QhOyoG", "colab_type": "text" }, "cell_type": "markdown", "source": [ "" ] }, { "metadata": { "id": "9qSiFU8pXqL3", "colab_type": "code", "outputId": "71f4c54a-2b6e-4ee0-d4b4-c32b6fa18e95", "colab": { "base_uri": "https://localhost:8080/", "height": 647 } }, "cell_type": "code", "source": [ "# Create the Encoder layers first.\n", "encoder_inputs = Input(shape=(len_input,))\n", "encoder_emb = Embedding(input_dim=vocab_in_size, output_dim=embedding_dim)\n", "\n", "# Use this if you dont need Bidirectional LSTM\n", "# encoder_lstm = CuDNNLSTM(units=units, return_sequences=True, return_state=True)\n", "# encoder_out, state_h, state_c = encoder_lstm(encoder_emb(encoder_inputs))\n", "\n", "encoder_lstm = Bidirectional(CuDNNLSTM(units=units, return_sequences=True, return_state=True))\n", "encoder_out, fstate_h, fstate_c, bstate_h, bstate_c = encoder_lstm(encoder_emb(encoder_inputs))\n", "state_h = Concatenate()([fstate_h,bstate_h])\n", "state_c = Concatenate()([bstate_h,bstate_c])\n", "encoder_states = [state_h, state_c]\n", "\n", "\n", "# Now create the Decoder layers.\n", "decoder_inputs = Input(shape=(None,))\n", "decoder_emb = Embedding(input_dim=vocab_out_size, output_dim=embedding_dim)\n", "decoder_lstm = CuDNNLSTM(units=units*2, return_sequences=True, return_state=True)\n", "decoder_lstm_out, _, _ = decoder_lstm(decoder_emb(decoder_inputs), initial_state=encoder_states)\n", "# Two dense layers added to this model to improve inference capabilities.\n", "decoder_d1 = Dense(units, activation=\"relu\")\n", "decoder_d2 = Dense(vocab_out_size, activation=\"softmax\")\n", "decoder_out = decoder_d2(Dropout(rate=.2)(decoder_d1(Dropout(rate=.2)(decoder_lstm_out))))\n", "\n", "\n", "# Finally, create a training model which combines the encoder and the decoder.\n", "# Note that this model has three inputs:\n", "model = Model(inputs = [encoder_inputs, decoder_inputs], outputs= decoder_out)\n", "\n", "# We'll use sparse_categorical_crossentropy so we don't have to expand decoder_out into a massive one-hot array.\n", "# Adam is used because it's, well, the best.\n", "\n", "model.compile(optimizer=tf.train.AdamOptimizer(), loss=\"sparse_categorical_crossentropy\", metrics=['sparse_categorical_accuracy'])\n", "model.summary()" ], "execution_count": 62, "outputs": [ { "output_type": "stream", "text": [ "__________________________________________________________________________________________________\n", "Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", "input_5 (InputLayer) (None, 16) 0 \n", "__________________________________________________________________________________________________\n", "embedding_4 (Embedding) (None, 16, 300) 152400 input_5[0][0] \n", "__________________________________________________________________________________________________\n", "input_6 (InputLayer) (None, None) 0 \n", "__________________________________________________________________________________________________\n", "bidirectional_2 (Bidirectional) [(None, 16, 256), (N 440320 embedding_4[0][0] \n", 
"__________________________________________________________________________________________________\n", "embedding_5 (Embedding) (None, None, 300) 171600 input_6[0][0] \n", "__________________________________________________________________________________________________\n", "concatenate_4 (Concatenate) (None, 256) 0 bidirectional_2[0][1] \n", " bidirectional_2[0][3] \n", "__________________________________________________________________________________________________\n", "concatenate_5 (Concatenate) (None, 256) 0 bidirectional_2[0][3] \n", " bidirectional_2[0][4] \n", "__________________________________________________________________________________________________\n", "cu_dnnlstm_5 (CuDNNLSTM) [(None, None, 256), 571392 embedding_5[0][0] \n", " concatenate_4[0][0] \n", " concatenate_5[0][0] \n", "__________________________________________________________________________________________________\n", "dropout_5 (Dropout) (None, None, 256) 0 cu_dnnlstm_5[0][0] \n", "__________________________________________________________________________________________________\n", "dense_4 (Dense) (None, None, 128) 32896 dropout_5[0][0] \n", "__________________________________________________________________________________________________\n", "dropout_4 (Dropout) (None, None, 128) 0 dense_4[0][0] \n", "__________________________________________________________________________________________________\n", "dense_5 (Dense) (None, None, 572) 73788 dropout_4[0][0] \n", "==================================================================================================\n", "Total params: 1,442,396\n", "Trainable params: 1,442,396\n", "Non-trainable params: 0\n", "__________________________________________________________________________________________________\n" ], "name": "stdout" } ] }, { "metadata": { "id": "sFrEQmkwYDPX", "colab_type": "code", "outputId": "bf99f03d-2048-4c0d-f1f8-a6d9edb18999", "colab": { "base_uri": "https://localhost:8080/", "height": 415 } }, "cell_type": "code", "source": [ "# Note, we use 20% of our data for validation.\n", "epochs = 10\n", "history = model.fit([input_data, teacher_data], target_data,\n", " batch_size= BATCH_SIZE,\n", " epochs=epochs,\n", " validation_split=0.2)\n" ], "execution_count": 63, "outputs": [ { "output_type": "stream", "text": [ "Train on 118120 samples, validate on 29531 samples\n", "Epoch 1/10\n", "118120/118120 [==============================] - 15s 130us/sample - loss: 0.6460 - sparse_categorical_accuracy: 0.8711 - val_loss: 0.1746 - val_sparse_categorical_accuracy: 0.9520\n", "Epoch 2/10\n", "118120/118120 [==============================] - 14s 119us/sample - loss: 0.1022 - sparse_categorical_accuracy: 0.9688 - val_loss: 0.0233 - val_sparse_categorical_accuracy: 0.9916\n", "Epoch 3/10\n", "118120/118120 [==============================] - 14s 120us/sample - loss: 0.0291 - sparse_categorical_accuracy: 0.9888 - val_loss: 0.0153 - val_sparse_categorical_accuracy: 0.9918\n", "Epoch 4/10\n", "118120/118120 [==============================] - 14s 119us/sample - loss: 0.0207 - sparse_categorical_accuracy: 0.9907 - val_loss: 0.0148 - val_sparse_categorical_accuracy: 0.9923\n", "Epoch 5/10\n", "118120/118120 [==============================] - 14s 119us/sample - loss: 0.0182 - sparse_categorical_accuracy: 0.9912 - val_loss: 0.0141 - val_sparse_categorical_accuracy: 0.9922\n", "Epoch 6/10\n", "118120/118120 [==============================] - 14s 118us/sample - loss: 0.0174 - sparse_categorical_accuracy: 0.9914 - val_loss: 0.0139 - 
val_sparse_categorical_accuracy: 0.9923\n", "Epoch 7/10\n", "118120/118120 [==============================] - 14s 122us/sample - loss: 0.0165 - sparse_categorical_accuracy: 0.9915 - val_loss: 0.0136 - val_sparse_categorical_accuracy: 0.9920\n", "Epoch 8/10\n", "118120/118120 [==============================] - 14s 121us/sample - loss: 0.0150 - sparse_categorical_accuracy: 0.9918 - val_loss: 0.0132 - val_sparse_categorical_accuracy: 0.9923\n", "Epoch 9/10\n", "118120/118120 [==============================] - 14s 118us/sample - loss: 0.0155 - sparse_categorical_accuracy: 0.9916 - val_loss: 0.0132 - val_sparse_categorical_accuracy: 0.9922\n", "Epoch 10/10\n", "118120/118120 [==============================] - 14s 118us/sample - loss: 0.0147 - sparse_categorical_accuracy: 0.9918 - val_loss: 0.0132 - val_sparse_categorical_accuracy: 0.9924\n" ], "name": "stdout" } ] }, { "metadata": { "id": "eJexzgyki8vT", "colab_type": "code", "outputId": "18ee6f70-0ac0-4cb2-845c-918a65b94f12", "colab": { "base_uri": "https://localhost:8080/", "height": 269 } }, "cell_type": "code", "source": [ "# Plot the results of the training.\n", "import matplotlib.pyplot as plt\n", "\n", "plt.plot(history.history['loss'], label=\"Training loss\")\n", "plt.plot(history.history['val_loss'], label=\"Validation loss\")\n", "plt.show()" ], "execution_count": 64, "outputs": [ { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAHJlJREFUeJzt3Xt8XHWd//HXZ3JpmrSZ3gK2k9IW\nqEpIVCRb+QH2p9wsP9dWBXfLw3VlV6mylKvuLqwu/hb3tw/FfSi4v8KPCvx2dVVErv1hl+IFBVGk\n4SKlLYVQCqS0NL3QS3rJ7fP7Y2aSyXTSTJOZTM457+fjESbnzDfnfDql73NyPudi7o6IiIRLrNQF\niIhI4SncRURCSOEuIhJCCncRkRBSuIuIhJDCXUQkhBTuIiIhpHAXEQkhhbuISAiVl2rF06ZN89mz\nZ5dq9SIigfT0009vd/e6ocaVLNxnz55NS0tLqVYvIhJIZvZaPuN0WEZEJIQU7iIiIaRwFxEJIYW7\niEgIKdxFREJI4S4iEkIKdxGREApcuLds2sk3H34RPR5QRGRwgQv3Fzbv5tZfv8LWPQdLXYqIyJgV\nuHBvqo8DsKZtd4krEREZuwIX7g3T48QsuQcvIiK5BS7cx1eWceIxE1ijcBcRGVTgwh2gMRFnzeY9\naqqKiAwikOHelIizfd8h3tpzqNSliIiMSYENd0CHZkREBhHIcG+YUUvMFO4iIoMJZLhXV5ZzQt0E\nnTEjIjKIQIY7JA/NaM9dRCS3wIZ7YyJO+95DvKUrVUVEDpNXuJvZAjPbYGatZnbtIGP+zMzWmdla\nM/tRYcs8nK5UFREZ3JDhbmZlwDLgfKABuMjMGrLGzAWuA85w95OBq4pQ6wAN02sxNVVFRHLKZ899\nHtDq7hvdvRO4C1iUNeYSYJm77wJw922FLfNwNePUVBURGUw+4Z4A3siYbkvNy/RO4J1m9oSZPWlm\nCwpV4JE0JeK88KbCXUQkW6EaquXAXOBDwEXA98xsUvYgM1tiZi1m1tLe3j7ilTYm4ry15xDb9qqp\nKiKSKZ9w3wzMzJiuT83L1AascPcud38VeIlk2A/g7svdvdndm+vq6oZbc5/0lao6NCMiMlA+4b4a\nmGtmc8ysElgMrMga8wDJvXbMbBrJwzQbC1hnTifPSDVV2/YUe1UiIoEyZLi7ezewFFgFrAfudve1\nZnaDmS1MDVsF7DCzdcCjwN+6+45iFZ1WM66c46fV6IwZEZEs5fkMcveVwMqseddnfO/ANamvUdWU\niPPkxp2jvVoRkTEtsFeopjUm4mzdc5D2vbr9r4hIWijCHdRUFRHJFPhwP3lGLaArVUVEMgU+3CdW\nVaipKiKSJfDhDslDMzosIyLSLxTh3pSIs2X3QbbvU1NVRARCEu6NeqaqiMgAoQj3kxPJpuoLure7\niAgQknCvrapgjpqqIiJ9QhHuoKaqiEim0IR7U6KWN3cfZIeaqiIi4Ql3NVVFRPqFLtx1aEZEJETh\nXltVweyp1dpzFxEhROEO6aaqHtwhIhKqcG9KxNn89gF2dnSWuhQRkZIKXbiDmqoiIqEK95PVVBUR\nAUIW7vHxFcyaWs0a3YZARCIuVOEOyaaqDsuISNSFLtzTTdVdaqqKSISFMtxBTVURiba8wt3MFpjZ\nBjNrNbNrc7x/sZm1m9lzqa/PF77U/DTOULiLiJQPNcDMyoBlwLlAG7DazFa4+7qsoT9x96VFqPGo\nxKsrOG5Ktc6YEZFIy2fPfR7Q6u4b3b0TuAtYVNyyRqZJTVURibh8wj0BvJEx3Zaal+0CM3vezO4x\ns5kFqW6YGhNx2napqSoi0VWohur/A2a7+3uAnwP/kWuQmS0xsxYza2lvby/Qqg+XbqqufVP3mRGR\naMon3DcDmXvi9al5fdx9h7unn5JxO3BqrgW5+3J3b3b35rq6uuHUm5fG1DNVdWhGRKIqn3BfDcw1\nszlmVgksBlZkDjCz6RmTC4H1hSvx6E2qrmTmlPFqqopIZA15toy7d5vZUmAVUAbc6e5rzewGoMXd\nVwBXmNlCoBvYCVxcxJrzoqaqiETZkOEO4
O4rgZVZ867P+P464LrCljYyjYk4K9dsZff+LuLVFaUu\nR0RkVIXuCtW0dFP1hTe19y4i0RPacNeVqiISZaEN98k1ldRPHq9wF5FICm24Q/LQjM6YEZEoCnW4\nNybivLZjP7sPdJW6FBGRURX6cAdYq713EYmYUIe77u0uIlEV6nCfUlNJYpKaqiISPaEOd0jeZ0ZN\nVRGJmtCHe1MizqYd+9lzUE1VEYmO0Id7uqmqvXcRiZLQh3uTwl1EIij04T51wjhmxKtYs1kP7hCR\n6Ah9uEPy0Iz23EUkSiIR7k2JOK9u71BTVUQiIxLh3lifvlJVh2ZEJBoiEe5qqopI1EQi3KdNGMf0\neJWuVBWRyIhEuIOaqiISLZEJ96ZEnI3bO9irpqqIRECkwh1g7ZtqqopI+EUm3HUbAhGJkrzC3cwW\nmNkGM2s1s2uPMO4CM3Mzay5ciYVRN3Ec76hVU1VEomHIcDezMmAZcD7QAFxkZg05xk0ErgT+UOgi\nC6UxEVe4i0gk5LPnPg9odfeN7t4J3AUsyjHu68A3gYMFrK+g0leq7jvUXepSRESKKp9wTwBvZEy3\npeb1MbP3AzPd/WcFrK3gmuprcdczVUUk/EbcUDWzGPBt4Et5jF1iZi1m1tLe3j7SVR+1Rj1TVUQi\nIp9w3wzMzJiuT81Lmwg0Ar82s03AacCKXE1Vd1/u7s3u3lxXVzf8qofpmIlVHFs7TqdDikjo5RPu\nq4G5ZjbHzCqBxcCK9Jvuvtvdp7n7bHefDTwJLHT3lqJUPEJNaqqKSAQMGe7u3g0sBVYB64G73X2t\nmd1gZguLXWChNSbivNK+jw41VUUkxMrzGeTuK4GVWfOuH2Tsh0ZeVvE0JeK4w7ote/iT2VNKXY6I\nSFFE5grVtPRtCNa06dCMiIRX5ML9mNoqjpk4TrchEJFQi1y4g5qqIhJ+kQz3dFN1f6eaqiISTpEM\n96ZEnF6HdTrfXURCKprhXq8rVUUk3CIZ7sfWVlE3cZzCXURCK5LhDslDMzpjRkTCKrLh3piI07pN\nTVURCafIhnu6qbp+i5qqIhI+kQ33xkQtoCtVRSScIhvu76itYtqEStZs1p67iIRPZMPdzGhUU1VE\nQiqy4Q7J4+4vb9vLgc6eUpciIlJQkQ73xvSVqmqqikjIRDrc07f/1aEZEQmbSIf79HgVU2sqdaWq\niIROpMNdTVURCatIhzukm6r7ONilpqqIhEfkw70xEaen19VUFZFQiXy4p2//q0MzIhImkQ/3GfEq\nptRU6jYEIhIqkQ/3dFNVZ8yISJjkFe5mtsDMNphZq5ldm+P9L5rZGjN7zsx+a2YNhS+1eJoStWqq\nikioDBnuZlYGLAPOBxqAi3KE94/cvcnd3wfcCHy74JUWUVOqqarb/4pIWOSz5z4PaHX3je7eCdwF\nLMoc4O6ZqVgDeOFKLL5GXakqIiFTnseYBPBGxnQb8IHsQWZ2GXANUAmclWtBZrYEWAJw3HHHHW2t\nRZOYNJ7J1RU67i4ioVGwhqq7L3P3E4C/B746yJjl7t7s7s11dXWFWvWI9TdVdVhGRMIhn3DfDMzM\nmK5PzRvMXcDHR1JUKTQl4rz81l41VUUkFPIJ99XAXDObY2aVwGJgReYAM5ubMflR4OXClTg6mhJx\nunudDVv3lroUEZERGzLc3b0bWAqsAtYDd7v7WjO7wcwWpoYtNbO1ZvYcyePuny1axUWSbqrquLuI\nhEE+DVXcfSWwMmve9RnfX1ngukZd/eTxTKqu0BkzIhIKkb9CNc3MaNKVqiISEgr3DI2JOC+9tZdD\n3WqqikiwKdwzNCXidPWoqSoiwadwz9CkpqqIhITCPUP95PHEx6upKiLBp3DPoKaqiISFwj1LYyLO\nhq1qqopIsCncs6Sbqi9t3VfqUkREhk3hnkVNVREJA4V7lplTkk1VhbuIBJnCPUvy9r+1OmNGRAJN\n4Z5Duqna2d1b6lJERIZF4Z5DUyJOZ08vL72lK1VFJJgU7jmoqSoiQadwz+G4KdVMrCpXuItIYCnc\nczAzGmfE1VQVkcBSuA+iqT7Oi1vUVBWRYFK4D6JRTVURCTCF+yDSTVUdmhGRIFK4D2LWlGomjlNT\nVUSCSeE+iFjMOFlXqopIQCncj6ApEWf91r109aipKiLBkle4m9kCM9tgZq1mdm2O968xs3Vm9ryZ\n/dLMZhW+1NHXmIjT2a2mqogEz5DhbmZlwDLgfKABuMjMGrKGPQs0u/t7gHuAGwtdaCmoqSoiQZXP\nnvs8oNXdN7p7J3AXsChzgLs/6u77U5NPAvWFLbM0Zk+tYYKaqiISQPmEewJ4I2O6LTVvMJ8D/msk\nRY0VsZhx8oxa1mzeU+pSRESOSkEbqmb2F0Az8K1B3l9iZi1m1tLe3l7IVRdNUyLO+i171FQVkUDJ\nJ9w3AzMzputT8wYws3OArwAL3f1QrgW5+3J3b3b35rq6uuHUO+qa6pNN1dZteqaqiARHPuG+Gphr\nZnPMrBJYDKzIHGBmpwC3kQz2bYUvs3QadftfEQmgIcPd3buBpcAqYD1wt7uvNbMbzGxhati3gAnA\nT83sOTNbMcjiAmdOqqmqM2ZEJEjK8xnk7iuBlVnzrs/4/pwC1zVmxGJGw4xa7bmLSKDoCtU8pJuq\n3WqqikhABC/cuw7Ayz8f1VU2JeIc7OqltV1NVREJhuCF+2Pfgh/9Gbz+5Kitsq+p2qZDMyISDMEL\n9zOugknHwb2XwMHRCdvjp9VQU1mmpqqIBEbwwr2qFi64A/ZshoeuAfeirzJ5pWpcTVURCYzghTtA\nfTN86Dp44R54/iejssrGRJx1aqqKSEAEM9wBPngNHHc6/OzLsPPVoq+uqb6Wg129vNLeUfR1iYiM\nVHDDPVYGn1wOFoP7LoGerqKurklXqopIgAQ33AEmzYSP3QRtq+E3xb2F/JxpE6hWU1VEAiLY4Q7Q\n+El436fh8X+F135XtNWU9d3+V+EuImNf8MMd4PxvwuTZcN8SOPB20VbTmIiz7s099PQW/wwdEZGR\nCEe4j5sIF9wOe7fAQ1cV7fTIpkScA109vKIrVUVkjAtHuAMkToUP/wOsvR+e+1FRVtGkK1VFJCDC\nE+6QvHp11pmw8m9hxysFX/zxdcmmqo67i8hYF65wj5XBJ2+Dsgq49/MFPz2yLGY0TK/VGTMiMuaF\nK9wB4vWw8Lvw5jPw6L8UfPGNiThr1VQVkTEufOEO0LAITvkM/PY78OrjBV10uqm6UU1VERnDwhnu\nAAu+AVOOh/u/APt3FmyxTfW6UlVExr7whvu4CXDhHbBvW0FPjzyhbgLjK9RUFZGxLbzhDjDjFDjr\nq7DuQXj2BwVZZFnqmapqqorIWBbucAc4/QqYMx/+6+9he2tBFtk4o1ZNVREZ08If7rEYfOI2KB8H\n934OujtHvMjGRJz9nT28ul1NVREZm/IKdzNbYGYbzKzVzK7N8f58M3vGzLrN7MLClzlCtTNg4b/B\n
lufg0X8e8eLUVBWRsW7IcDezMmAZcD7QAFxkZg1Zw14HLgaKc91/IZz0MTj1Ynjiu7DxNyNa1Il1\nE6iqiLGmbU9hahMRKbB89tznAa3uvtHdO4G7gEWZA9x9k7s/D4ztZ9B95F9g6okjPj2yvCzGSdNr\n+cOrOzjY1VPAAkVECiOfcE8Ab2RMt6XmBU9lTfL0yI7tsOLyEZ0e+cn317P2zT189LuP8+zruwpY\npIjIyI1qQ9XMlphZi5m1tLe3j+aq+01/L5zzNXjxIXj634e9mM+cNosffG4eBzp7uODW33Hjwy9y\nqFt78SIyNuQT7puBmRnT9al5R83dl7t7s7s319XVDWcRhXHaZXD8h+Hh66D9pWEv5oNz63j46vlc\neGo9t/z6FRb97yd0/ruIjAn5hPtqYK6ZzTGzSmAxsKK4ZRVZLAYfvxUqxqdOjzw07EXVVlVw44Xv\n5c6Lm9nZ0cnHlz3BTb94ia6esd1+EJFwGzLc3b0bWAqsAtYDd7v7WjO7wcwWApjZn5hZG/Ap4DYz\nW1vMoguidjosWgZbn4df3jDixZ317mN55Or5/Ol7pnPTL17mE7c8wYatewtQqIjI0TMv0iPphtLc\n3OwtLS0lWfcAD10DLXfAZ+6HE84qyCIffmELX7n/BfYe7Oaqc+ey5IPHU14W/uvFRKT4zOxpd28e\napwS57x/hmnvgvsvhY4dBVnkgsbpPHL1fM4+6RhufHgDF/6f3+u5qyIyqhTuldXJ0yMP7IQVSwt2\n98ipE8Zxy6ffz82L38er2zv4Hzc/zh2/fZVe3Y9GREaBwh3gHU1wzj/BhpXJQzQFYmYsel+Cn189\nnzNPnMbXH1rH4uVP8tqOjoKtQ0QkF4V72ge+CCecDau+AtteLOiij6mt4vbPNvOtC9/D+i17OP/m\nx/nB7zdpL15EikbhnpY+PbJyQvLh2l0HC7p4M+NTzTNZdfV8Tp01mX98cC1/eedTbH77QEHXIyIC\nCveBJh4LH78F3loDv/ynoqxixqTxfP+v5/G/PtHIM6/vYsF3HuPu1W9QqrOWRCScFO7Z3vkRmLcE\nnrwFXv5FUVZhZnz6A7NYddV8GmbU8nf3Ps9f//tq3tpT2N8WRCS6FO65nHsDHNMAD1wK+4p3D5yZ\nU6r58SWn8bWPNfD7jTs47zuP8cCzm7UXLyIjpnDPpWI8XHA7HNwND15WsNMjc4nFjL86Yw4rr/gg\nJ9TVcNVPnuOL//k02/cN/5YIIiIK98EcezKc93V4eRU89b2ir+74ugn89Iunc9357+bRF9s57zuP\nsXLNlqKvV0TCSeF+JPOWwNzz4JGvwlvrir66spjxhf9+Aj+74kzqJ4/nb374DJf/+Fl2dYz8ua8i\nEi0K9yMxg0W3QFVt8u6RXaNz2uLcYydy76Wn86Vz38nDL2zhvJse4xfr3hqVdYtIOCjchzKhLnn+\n+7Z18POvjdpqK8piXH72XB647Aym1lTy+e+38KW7/8juA12jVoOIBJfCPR9zz4UPXApP3QYvPTKq\nqz55RpwVS89k6YdP5IHnNrPgpsf4zUsleoqViASGwj1f5/xPOLYRHvwb2LdtVFddWR7jyx95F/dd\nejo148r57J1Pcd19a9h3qHtU6xCR4FC456uiKnl65KG9yfPfe0f/SUvvnTmJhy4/kyXzj+eu1a+z\n4KbH+N0r20e9DhEZ+/SwjqP11Pdg5ZdhwTfgtEtLVkbLpp18+ad/ZNOO/Zw6azJTayqZUlPJ5JpK\nJldXMLk6OT0p9TqlupKJVeXEYlaymkVk5PJ9WIfC/Wi5w48Xwyu/gksehXc0lqyU/Z3d/NuvWnn2\n9V3s6uhi1/5Odu3vpKsn999pzGBy9WAbgOR0+v0pqTG1VRXaIIiMIQr3YurYDreeDh3tMH4yjJ8C\n1VMyXidnTWe9lo8rWmnuzr5D3by9v4udHZ3s3N/J2/s72dnRxa6Ozr4NwM6Ozr4x+WwQJlVXpAI/\ncwNQkdwwVFcyuaaC8RXlVJYbFWWxvq/KshgVqXnlMcNMGwqRkcg33MtHo5jQqZkGf/kgrLkn+QSn\n/TuTr7vbkg/c3r8Tuo9wTnxFTf9GYKgNQfr9cfHkbYmHYGZMrKpgYlUFM6dU5/XHcXc6Onv6wj8d\n+OnfBjI3BK/v3M8f295mV0cXnT1H33eoLItRUWZUlMcoj8WoTH3fvzHI2DiUZ02XxXJsPKxvbEVq\n2TFLf0HMDLPk55I5nTnGBsw7fEzy5w//mVgey41lzosd/n5ZLP+xIkdD4T5cx5wEZ//j4O93HegP\n/cNedw2c3tKWfD3wNjDIb1JWBuMn5Q7/iupk+mAZrzEwcszLHmeYxZiAMcGMmen3zGCcQZXBlPS8\nWN/POHCoO7lR6OjsZV9nD129Rk+v0+1Gd6/R1fc9dPdCVy9093ryPU9N9yRfB345XV1GZ2p8Zw8c\n6oW9PU5nL3SlfuZQj3OoJ/nx9BKjF8Oxvk/QSQdi8jXzl9T0e/2vR57PoPOPvIzMn881LnNe9rqz\nx8RSv/lkfpWl/q6SG4RY/0YjY6NQXpaajhnlseTY8tRGsCyW+hpkXlnqZ484L+O99Dqy58UMep2+\nm+K5g+Op14HTyfcHvtfrnvzcUq+D/Tx904f/PJBjI565UT18Q285dhT6pmO5f36wnYTMjfhJ02vz\n3vkaLoV7sVSMh3gi+ZWv3p7kzcoG3ShkvOb7W0KRGFCV+po66mtPFVCO/g9Oc6Cnf7KX1Mab9MbE\n+jcU1r/RGPjewPmeHJj1Xtar93/fy8CNUeZ7ZLyO5I+YltxvGXx5h71j6WVYxoJ84EbabcB0/nUN\nXsdg72065UpmfuILw1hb/vL6p2FmC4CbgTLgdnf/Rtb744DvA6cCO4A/d/dNhS01AmJlyb3x6inD\n+3n31O6pg/dmfJ/x6r2DzCPH+FzLGGy5We97b8a6ejPme455Rxibc372eM+aD33/gvt21zN32/3I\nY/KezvrsDxsz2PdkfH5H8z255+dYRyx7/hFfBxtHHj+b+QpOb3KvuddTe8+9qb8e789i6w/f/uiz\nw7K6b4wZjHjTADn/3vrf7P+v90+5978OnEffn4v0RiLjr8izftPoG5/65ti5s0b6pxnSkOFuZmXA\nMuBcoA1YbWYr3D3zTlqfA3a5+4lmthj4JvDnxShYjiB9OAVIbodFRlf/7wvBdPhGJ7jyuYhpHtDq\n7hvdvRO4C1iUNWYR8B+p7+8BzjZ1gERESiafcE8Ab2RMt6Xm5Rzj7t3Abkp0KFZEREb59gNmtsTM\nWsyspb1dN78SESmWfMJ9MzAzY7o+NS/nGDMrB+IkG6sDuPtyd2929+a6urrhVSwiIkPKJ9xXA3PN\nbI6ZVQKLgRVZY1YAn019fyHwK9dTnkVESmbIs2XcvdvMlgKrSJ6Ccae7rzWzG4AWd18B3AH8wMxa\ngZ0kNwAiIlIieZ3n7u4rgZVZ867P+P4g8KnCliYiIsOl+
7mLiIRQye4KaWbtwGvD/PFpgJ5S0U+f\nx0D6PPrpsxgoDJ/HLHcf8oyUkoX7SJhZSz63vIwKfR4D6fPop89ioCh9HjosIyISQgp3EZEQCmq4\nLy91AWOMPo+B9Hn002cxUGQ+j0AecxcRkSML6p67iIgcQeDC3cwWmNkGM2s1s2tLXU+pmNlMM3vU\nzNaZ2Vozu7LUNY0FZlZmZs+a2UOlrqXUzGySmd1jZi+a2Xoz+2+lrqlUzOzq1L+TF8zsx2ZWVeqa\nii1Q4Z7x4JDzgQbgIjNrKG1VJdMNfMndG4DTgMsi/FlkuhJYX+oixoibgYfd/d3Ae4no52JmCeAK\noNndG0neRiX0t0gJVLiT34NDIsHdt7j7M6nv95L8h3sUD2wNHzOrBz4K3F7qWkrNzOLAfJL3fcLd\nO9397dJWVVLlwPjUXWurgTdLXE/RBS3c83lwSOSY2WzgFOAPpa2k5G4C/g7oHWpgBMwB2oH/mzpM\ndbuZ1ZS6qFJw983AvwKvA1uA3e7+SGmrKr6ghbtkMbMJwL3AVe6+p9T1lIqZ/Smwzd2fLnUtY0Q5\n8H7gVnc/BegAItmjMrPJJH/DnwPMAGrM7C9KW1XxBS3c83lwSGSYWQXJYP+hu99X6npK7AxgoZlt\nInm47iwz+8/SllRSbUCbu6d/m7uHZNhH0TnAq+7e7u5dwH3A6SWuqeiCFu75PDgkElIPIL8DWO/u\n3y51PaXm7te5e727zyb5/8Wv3D30e2eDcfetwBtm9q7UrLOBdSUsqZReB04zs+rUv5uziUBzOa/7\nuY8Vgz04pMRllcoZwGeANWb2XGreP6TuvS8CcDnww9SO0Ebgr0pcT0m4+x/M7B7gGZJnmT1LBK5U\n1RWqIiIhFLTDMiIikgeFu4hICCncRURCSOEuIhJCCncRkRBSuIuIhJDCXUQkhBTuIiIh9P8BRsh8\nsCgJsdoAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "IBO0KTYWi9n_", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "# Create the encoder model from the tensors we previously declared.\n", "encoder_model = Model(encoder_inputs, [encoder_out, state_h, state_c])\n", "\n", "# Generate a new set of tensors for our new inference decoder. Note that we are using new tensors, \n", "# this does not preclude using the same underlying layers that we trained on. (e.g. weights/biases).\n", "\n", "inf_decoder_inputs = Input(shape=(None,), name=\"inf_decoder_inputs\")\n", "# We'll need to force feed the two state variables into the decoder each step.\n", "state_input_h = Input(shape=(units*2,), name=\"state_input_h\")\n", "state_input_c = Input(shape=(units*2,), name=\"state_input_c\")\n", "decoder_res, decoder_h, decoder_c = decoder_lstm(\n", " decoder_emb(inf_decoder_inputs), \n", " initial_state=[state_input_h, state_input_c])\n", "inf_decoder_out = decoder_d2(decoder_d1(decoder_res))\n", "inf_model = Model(inputs=[inf_decoder_inputs, state_input_h, state_input_c], \n", " outputs=[inf_decoder_out, decoder_h, decoder_c])" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "K5ER4kOfi9iP", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "# Converts the given sentence (just a string) into a vector of word IDs\n", "# Output is 1-D: [timesteps/words]\n", "\n", "def sentence_to_vector(sentence, lang):\n", "\n", " pre = sentence\n", " vec = np.zeros(len_input)\n", " sentence_list = [lang.word2idx[s] for s in pre.split(' ')]\n", " for i,w in enumerate(sentence_list):\n", " vec[i] = w\n", " return vec\n", "\n", "# Given an input string, an encoder model (infenc_model) and a decoder model (infmodel),\n", "def translate(input_sentence, infenc_model, infmodel):\n", " sv = sentence_to_vector(input_sentence, input_lang)\n", " sv = sv.reshape(1,len(sv))\n", " [emb_out, sh, sc] = infenc_model.predict(x=sv)\n", " \n", " i = 0\n", " start_vec = target_lang.word2idx[\"\"]\n", " stop_vec = target_lang.word2idx[\"\"]\n", " \n", " cur_vec = np.zeros((1,1))\n", " cur_vec[0,0] = start_vec\n", " cur_word = \"\"\n", " output_sentence = \"\"\n", "\n", " while cur_word != \"\" and i < (len_target-1):\n", " i += 1\n", " if cur_word != \"\":\n", " output_sentence = output_sentence + \" \" + cur_word\n", " x_in = [cur_vec, sh, sc]\n", " [nvec, sh, sc] = infmodel.predict(x=x_in)\n", " cur_vec[0,0] = np.argmax(nvec[0,0])\n", " cur_word = target_lang.idx2word[np.argmax(nvec[0,0])]\n", " return output_sentence" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "jzHdBkDji9cM", "colab_type": "code", "outputId": "dd56edf2-07f7-4f8a-8fe2-b3dec77823a9", "colab": { "base_uri": "https://localhost:8080/", "height": 545 } }, "cell_type": "code", "source": [ "#Note that only words that we've trained the model on will be available, otherwise you'll get an error.\n", "\n", "\n", "test = [\n", " 'hi there',\n", " 'hell',\n", " 'presentation please fin',\n", " 'resignation please find at',\n", " 'resignation please ',\n", " 'have a nice we',\n", " 'let me ',\n", " 'promotion congrats ',\n", " 'christmas Merry ',\n", " 'please rev',\n", " 'please ca',\n", " 'thanks fo',\n", " 'Let me kno',\n", " 'Let me know if y',\n", " 'this soun',\n", " 'is this call going t'\n", "]\n", " \n", "\n", "import pandas as pd\n", "output = [] \n", "for t in test: \n", " output.append({\"Input seq\":t.lower(), \"Pred. 
Seq\":translate(t.lower(), encoder_model, inf_model)})\n", "\n", "results_df = pd.DataFrame.from_dict(output) \n", "results_df.head(len(test))" ], "execution_count": 73, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Input seqPred. Seq
0hi there, how are you today?
1hello, how are you?
2presentation please find attached the presentation
3resignation please find attached my resignation letter.
4resignation pleaseaccept my resignation.
5have a nice weekend
6let meknow if you need anything else.
7promotion congratson your promotion
8christmas merryeveryone!let me know when you are leaving today
9please review.
10please call with any questions.
11thanks for the update.
12let me know if you need anything else.
13let me know if you need anything else.
14this sounds acceptable to us.
15is this call going to happen?
\n", "
" ], "text/plain": [ " Input seq \\\n", "0 hi there \n", "1 hell \n", "2 presentation please fin \n", "3 resignation please find at \n", "4 resignation please \n", "5 have a nice we \n", "6 let me \n", "7 promotion congrats \n", "8 christmas merry \n", "9 please rev \n", "10 please ca \n", "11 thanks fo \n", "12 let me kno \n", "13 let me know if y \n", "14 this soun \n", "15 is this call going t \n", "\n", " Pred. Seq \n", "0 , how are you today? \n", "1 o, how are you? \n", "2 d attached the presentation \n", "3 tached my resignation letter. \n", "4 accept my resignation. \n", "5 ekend \n", "6 know if you need anything else. \n", "7 on your promotion \n", "8 everyone!let me know when you are leaving today \n", "9 iew. \n", "10 ll with any questions. \n", "11 r the update. \n", "12 w if you need anything else. \n", "13 ou need anything else. \n", "14 ds acceptable to us. \n", "15 o happen? " ] }, "metadata": { "tags": [] }, "execution_count": 73 } ] }, { "metadata": { "id": "v6iADg_MXEdH", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "# This is to save the model for the web app to use for generation\n", "from keras.models import model_from_json\n", "from keras.models import load_model\n", "\n", "# serialize model to JSON\n", "# the keras model which is trained is defined as 'model' in this example\n", "model_json = inf_model.to_json()\n", "\n", "\n", "with open(\"./sample_data/model_num.json\", \"w\") as json_file:\n", " json_file.write(model_json)\n", "\n", "# serialize weights to HDF5\n", "inf_model.save_weights(\"./sample_data/model_num.h5\")" ], "execution_count": 0, "outputs": [] } ] }