{ "cells": [ { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "# Spam or Ham?: RNN Remix" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "## Lab Assignment Nine: RNNs" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Justin Ledford, Luke Wood, Traian Pop \n", "___" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "## Business Understanding" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Data Background\n", "SMS messages play a huge role in a person's life, and the confidentiality and integrity of those messages are a top priority for mobile carriers around the world. Because of this, many unlawful individuals and groups try to take advantage of the average consumer by flooding their inbox with spam. While the majority of people successfully avoid it, some are harmed by falling for fraudulent messages. \n", "\n", "The data we selected is a compilation of 5574 SMS messages acquired from a variety of sources, broken down as follows: 425 messages came from the Grumbletext Web Site, 3375 were taken from the NUS SMS Corpus (a database of legitimate messages from the National University of Singapore), 450 were collected from Caroline Tagg's PhD thesis, and the last 1324 came from the SMS Spam Corpus v.0.1 Big. \n", "\n", "Overall there were 4827 \"ham\" messages and 747 \"spam\" messages, and about 92,000 words.\n", "\n", "### Purpose\n", "This data was initially collected for studies on distinguishing spam from ham (legitimate) messages. Uses for this research include advanced spam filtering technology and improved data sets for machine learning programs. 
However, a slight problem with this data set, as with most localized language-based data sets, is that because of the relatively small sampling area it contains a lot of regional data points (such as slang and acronyms) that can be considered \"useless\" if a more generalized data set is wanted. For our specific project, however, we keep all of this data so that we can analyze it and get a better understanding of our data.\n", "___" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true, "deletable": true, "editable": true }, "source": [ "## Preparation" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import requests\n", "import re\n", "from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer\n", "import matplotlib.pyplot as plt\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")\n", "%matplotlib inline\n", "\n", "descriptors_url = 'https://raw.githubusercontent.com/LukeWoodSMU/TextAnalysis/master/data/SMSSpamCollection'\n", "descriptors = requests.get(descriptors_url).text\n", "texts = []\n", "\n", "\n", "for line in descriptors.splitlines():\n", "    texts.append(line.rstrip().split(\"\\t\"))" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "After a first look at the data we noticed a lot of phone numbers. Since almost every number was unique, we concluded that the individual numbers were not useful as words. We considered dropping the numbers entirely, but as a first step we replace anything that looks like a phone number with a single placeholder token."
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "# Replace phone-number-like patterns with a single placeholder token\n", "texts = list(zip([a for a,b in texts], [re.sub(r'((\\(\\d{3}\\) ?)|(\\d{3}-))?\\d{3}-\\d', 'PHONE_NUMBER', b) for a,b in texts]))" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "Citation: regex adapted from top Google search results / Stack Overflow" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "import numpy as np\n", "from keras.preprocessing import sequence" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "X = [x[1] for x in texts]\n", "y = [x[0] for x in texts]\n", "X = np.array(X)\n", "print(type(X))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "data": { "text/plain": [ "array([[ 0., 1.],\n", " [ 0., 1.],\n", " [ 1., 0.],\n", " ..., \n", " [ 0., 1.],\n", " [ 0., 1.],\n", " [ 0., 1.]])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import keras\n", "y = [0 if y_ == \"spam\" else 1 for y_ in y]\n", "y_ohe = keras.utils.to_categorical(y)\n", "y_ohe" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "We assign spam a value of 0 and ham a value of 1 so that we can use the precision score to measure false positives."
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "import keras\n", "from keras.preprocessing.text import Tokenizer\n", "from keras.preprocessing.sequence import pad_sequences\n", "\n", "NUM_TOP_WORDS = None\n", "\n", "tokenizer = Tokenizer(num_words=NUM_TOP_WORDS)\n", "tokenizer.fit_on_texts(X)\n", "word_index = tokenizer.word_index\n", "\n", "sequences = tokenizer.texts_to_sequences(X)\n", "sequences = pad_sequences(sequences)\n", "\n", "MAX_TEXT_LEN = len(sequences[0]) # after padding, every sequence has the length of the longest message" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "We tokenize the messages with Keras' tokenizer, pad every sequence to the length of the longest message, and record that length as MAX_TEXT_LEN." ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Cross Validation Method" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "We now have a padded integer sequence for every message.\n", "\n", "Next, we split our data into training data and testing data. We stratify the split on y_ohe so that both subsets keep a fair representation of the spam and ham messages. We believe this to be appropriate because each model needs to see a fair number of both spam messages and ham messages to ensure it does not overfit to either class."
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "data": { "text/plain": [ "2" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.model_selection import train_test_split\n", "# Split it into train / test subsets\n", "X_train, X_test, y_train_ohe, y_test_ohe = train_test_split(sequences, y_ohe, test_size=0.2,\n", " stratify=y_ohe, \n", " random_state=42)\n", "NUM_CLASSES = len(y_train_ohe[0])\n", "NUM_CLASSES" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Evaluation Metrics\n", "Because our business goal is a potential spam filter, we decided that our largest cost should be false positives. It would be incredibly frustrating to have a real text filtered out, so we evaluate our models with this in mind. To do so, we must implement a precision score, which has been removed from Keras. Luckily, the old code is available in one of Keras' older versions."
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "# Old version of keras had precision score, copied the code to re-implement it.\n", "import keras.backend as K\n", "def precision(y_true, y_pred):\n", " true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))\n", " predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))\n", " precision = true_positives / (predicted_positives + K.epsilon())\n", " return precision" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "Citation: old keras version" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "## Modeling\n", "To avoid the need for training our own embedding layer which is incredibly computationally expensive, we load up a pretrained glove embedding." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 400000 word vectors.\n", "(9008, 100)\n" ] } ], "source": [ "EMBED_SIZE = 100\n", "# the embed size should match the file you load glove from\n", "embeddings_index = {}\n", "f = open('GLOVE/glove.6B/glove.6B.100d.txt')\n", "# save key/array pairs of the embeddings\n", "# the key of the dictionary is the word, the array is the embedding\n", "for line in f:\n", " values = line.split()\n", " word = values[0]\n", " coefs = np.asarray(values[1:], dtype='float32')\n", " embeddings_index[word] = coefs\n", "f.close()\n", "\n", "print('Found %s word vectors.' 
% len(embeddings_index))\n", "\n", "# now fill in the matrix, using the ordering from the\n", "# keras word tokenizer from before\n", "embedding_matrix = np.zeros((len(word_index) + 1, EMBED_SIZE))\n", "for word, i in word_index.items():\n", "    embedding_vector = embeddings_index.get(word)\n", "    if embedding_vector is not None:\n", "        # words not found in embedding index will be all-zeros.\n", "        embedding_matrix[i] = embedding_vector\n", "\n", "print(embedding_matrix.shape)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "from keras.layers import Embedding\n", "\n", "embedding_layer = Embedding(len(word_index) + 1,\n", " EMBED_SIZE,\n", " weights=[embedding_matrix],\n", " input_length=MAX_TEXT_LEN,\n", " trainable=False)\n", "metrics=[precision,\"accuracy\"]" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "embedding_1 (Embedding) (None, 189, 100) 900800 \n", "_________________________________________________________________\n", "lstm_1 (LSTM) (None, 100) 80400 \n", "_________________________________________________________________\n", "dense_1 (Dense) (None, 2) 202 \n", "=================================================================\n", "Total params: 981,402\n", "Trainable params: 80,602\n", "Non-trainable params: 900,800\n", "_________________________________________________________________\n", "None\n" ] } ], "source": [ "from keras.models import Sequential\n", "from keras.layers import Dense\n", "from keras.layers import LSTM\n", "\n", "rnn = Sequential()\n", "rnn.add(embedding_layer)\n", "rnn.add(LSTM(100,dropout=0.2, 
recurrent_dropout=0.2))\n", "rnn.add(Dense(NUM_CLASSES, activation='sigmoid'))\n", "rnn.compile(loss='categorical_crossentropy', \n", " optimizer='rmsprop', \n", " metrics=metrics)\n", "print(rnn.summary())" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 19s - loss: 0.1908 - precision: 0.9525 - acc: 0.9325 - val_loss: 0.1071 - val_precision: 0.9852 - val_acc: 0.9578\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.0982 - precision: 0.9902 - acc: 0.9684 - val_loss: 0.1500 - val_precision: 0.9794 - val_acc: 0.9471\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0742 - precision: 0.9926 - acc: 0.9751 - val_loss: 0.0779 - val_precision: 0.9885 - val_acc: 0.9731\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rnn.fit(X_train, y_train_ohe, validation_data=(X_test, y_test_ohe), epochs=3, batch_size=64)" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Comparing Different Model Types" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "To begin, we will evaluate a network using an LSTM cell, a GRU cell, and a SimpleRNN cell. We will use a standard hyperparameter set to evaluate the results and decide which two architectures we want to explore in depth based on the results." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "from keras.layers import LSTM, GRU, SimpleRNN\n", "\n", "rnns = []\n", "\n", "for func in [SimpleRNN, LSTM, GRU]:\n", " rnn = Sequential()\n", " rnn.add(embedding_layer)\n", " rnn.add(func(100,dropout=0.2, recurrent_dropout=0.2))\n", " rnn.add(Dense(NUM_CLASSES, activation='sigmoid'))\n", "\n", " rnn.compile(loss='categorical_crossentropy', \n", " optimizer='rmsprop', \n", " metrics=metrics)\n", " rnns.append(rnn)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Testing Cell Type: simple ========\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 7s - loss: 0.2723 - precision: 0.8913 - acc: 0.8924 - val_loss: 0.1902 - val_precision: 0.9341 - val_acc: 0.9318\n", "Epoch 2/3\n", "4459/4459 [==============================] - 7s - loss: 0.1705 - precision: 0.9464 - acc: 0.9394 - val_loss: 0.1373 - val_precision: 0.9591 - val_acc: 0.9525\n", "Epoch 3/3\n", "4459/4459 [==============================] - 7s - loss: 0.1444 - precision: 0.9565 - acc: 0.9457 - val_loss: 0.1132 - val_precision: 0.9702 - val_acc: 0.9641\n", "\n", "Testing Cell Type: lstm ========\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 19s - loss: 0.1813 - precision: 0.9784 - acc: 0.9365 - val_loss: 0.1089 - val_precision: 0.9966 - val_acc: 0.9561\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.0931 - precision: 0.9958 - acc: 0.9682 - val_loss: 0.0859 - val_precision: 0.9957 - val_acc: 0.9713\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0740 - precision: 0.9958 - acc: 0.9749 - val_loss: 0.0805 - val_precision: 
0.9946 - val_acc: 0.9686\n", "\n", "Testing Cell Type: gru ========\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 19s - loss: 0.2118 - precision: 0.9403 - acc: 0.9141 - val_loss: 0.0982 - val_precision: 0.9931 - val_acc: 0.9695\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.0956 - precision: 0.9921 - acc: 0.9699 - val_loss: 0.0729 - val_precision: 0.9939 - val_acc: 0.9749\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0720 - precision: 0.9947 - acc: 0.9771 - val_loss: 0.0668 - val_precision: 0.9950 - val_acc: 0.9758\n" ] } ], "source": [ "for rnn, name in zip(rnns,['simple','lstm','gru']):\n", "    print('\\nTesting Cell Type: ',name,'========')\n", "    rnn.fit(X_train, y_train_ohe, epochs=3, batch_size=64, validation_data=(X_test, y_test_ohe))" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "As we can see, the GRU model performs best after three epochs, with a validation accuracy of .9758 and a validation precision of .9950. It is narrowly ahead of the LSTM (.9686 and .9946) and clearly ahead of the SimpleRNN (.9641 and .9702). If we continue to train the GRU model it seems likely that we will get strong results, so we will also try to find the best hyperparameters for it.\n", "\n", "After we find the best GRU configuration we will repeat the search with an LSTM and compare the results."
] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Gridsearch on GRU Model" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Hyper Paramater Set:\n", "\tdropout=0.1\n", "\trecurrent_dropout=0.1\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 19s - loss: 0.2019 - precision: 0.9571 - acc: 0.9186 - val_loss: 0.0949 - val_precision: 0.9967 - val_acc: 0.9668\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.0783 - precision: 0.9955 - acc: 0.9715 - val_loss: 0.0870 - val_precision: 0.9937 - val_acc: 0.9677\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0579 - precision: 0.9955 - acc: 0.9818 - val_loss: 0.0671 - val_precision: 0.9938 - val_acc: 0.9758\n", "Hyper Paramater Set:\n", "\tdropout=0.1\n", "\trecurrent_dropout=0.2\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 20s - loss: 0.1992 - precision: 0.9645 - acc: 0.9258 - val_loss: 0.1475 - val_precision: 0.9853 - val_acc: 0.9507\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.0890 - precision: 0.9959 - acc: 0.9711 - val_loss: 0.0786 - val_precision: 0.9957 - val_acc: 0.9713\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0635 - precision: 0.9965 - acc: 0.9798 - val_loss: 0.0646 - val_precision: 0.9958 - val_acc: 0.9776\n", "Hyper Paramater Set:\n", "\tdropout=0.1\n", "\trecurrent_dropout=0.3\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 20s - loss: 0.2042 - precision: 0.9616 - acc: 0.9285 - val_loss: 0.0963 - val_precision: 0.9939 - val_acc: 0.9650\n", "Epoch 2/3\n", "4459/4459 
[==============================] - 19s - loss: 0.0931 - precision: 0.9927 - acc: 0.9695 - val_loss: 0.0827 - val_precision: 0.9969 - val_acc: 0.9722\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0645 - precision: 0.9951 - acc: 0.9787 - val_loss: 0.0673 - val_precision: 0.9950 - val_acc: 0.9785\n", "Hyper Paramater Set:\n", "\tdropout=0.2\n", "\trecurrent_dropout=0.1\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 19s - loss: 0.1992 - precision: 0.9469 - acc: 0.9273 - val_loss: 0.0970 - val_precision: 0.9943 - val_acc: 0.9695\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.0866 - precision: 0.9919 - acc: 0.9715 - val_loss: 0.0823 - val_precision: 0.9931 - val_acc: 0.9704\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0606 - precision: 0.9946 - acc: 0.9814 - val_loss: 0.0660 - val_precision: 0.9933 - val_acc: 0.9776\n", "Hyper Paramater Set:\n", "\tdropout=0.2\n", "\trecurrent_dropout=0.2\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 20s - loss: 0.2040 - precision: 0.9423 - acc: 0.9206 - val_loss: 0.1004 - val_precision: 0.9920 - val_acc: 0.9632\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.0927 - precision: 0.9902 - acc: 0.9704 - val_loss: 0.0780 - val_precision: 0.9950 - val_acc: 0.9740\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0675 - precision: 0.9935 - acc: 0.9782 - val_loss: 0.0660 - val_precision: 0.9932 - val_acc: 0.9794\n", "Hyper Paramater Set:\n", "\tdropout=0.2\n", "\trecurrent_dropout=0.3\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 20s - loss: 0.2293 - precision: 0.9248 - acc: 0.9065 - val_loss: 0.1219 - val_precision: 0.9816 - val_acc: 0.9561\n", "Epoch 2/3\n", "4459/4459 
[==============================] - 19s - loss: 0.1008 - precision: 0.9860 - acc: 0.9661 - val_loss: 0.0805 - val_precision: 0.9859 - val_acc: 0.9740\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0813 - precision: 0.9886 - acc: 0.9744 - val_loss: 0.0698 - val_precision: 0.9885 - val_acc: 0.9767\n", "Hyper Paramater Set:\n", "\tdropout=0.3\n", "\trecurrent_dropout=0.1\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 20s - loss: 0.2037 - precision: 0.9509 - acc: 0.9222 - val_loss: 0.1467 - val_precision: 0.9826 - val_acc: 0.9471\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.0921 - precision: 0.9910 - acc: 0.9697 - val_loss: 0.0935 - val_precision: 0.9918 - val_acc: 0.9668\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0771 - precision: 0.9942 - acc: 0.9758 - val_loss: 0.0611 - val_precision: 0.9959 - val_acc: 0.9803\n", "Hyper Paramater Set:\n", "\tdropout=0.3\n", "\trecurrent_dropout=0.2\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 20s - loss: 0.2233 - precision: 0.9236 - acc: 0.9114 - val_loss: 0.1063 - val_precision: 0.9796 - val_acc: 0.9641\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.1017 - precision: 0.9836 - acc: 0.9648 - val_loss: 0.0777 - val_precision: 0.9834 - val_acc: 0.9722\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0765 - precision: 0.9870 - acc: 0.9733 - val_loss: 0.0643 - val_precision: 0.9906 - val_acc: 0.9785\n", "Hyper Paramater Set:\n", "\tdropout=0.3\n", "\trecurrent_dropout=0.3\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 20s - loss: 0.2367 - precision: 0.9417 - acc: 0.8998 - val_loss: 0.1063 - val_precision: 0.9955 - val_acc: 0.9632\n", "Epoch 2/3\n", "4459/4459 
[==============================] - 19s - loss: 0.1165 - precision: 0.9897 - acc: 0.9628 - val_loss: 0.0787 - val_precision: 0.9969 - val_acc: 0.9722\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0857 - precision: 0.9927 - acc: 0.9717 - val_loss: 0.0883 - val_precision: 0.9950 - val_acc: 0.9677\n" ] } ], "source": [ "dropouts=[.1,.2,.3]\n", "recurrent_dropouts=[.1,.2,.3]\n", "\n", "for dropout in dropouts:\n", "    for recurrent_dropout in recurrent_dropouts:\n", "        rnn = Sequential()\n", "        rnn.add(embedding_layer)\n", "        # GRU is named explicitly rather than reusing the leftover loop variable func\n", "        rnn.add(GRU(100,dropout=dropout, recurrent_dropout=recurrent_dropout))\n", "        rnn.add(Dense(NUM_CLASSES, activation='sigmoid'))\n", "\n", "        rnn.compile(loss='categorical_crossentropy', \n", "                    optimizer='rmsprop', \n", "                    metrics=metrics)\n", "        print(\"Hyper Paramater Set:\\n\\tdropout=%.1f\\n\\trecurrent_dropout=%.1f\" % (dropout,recurrent_dropout))\n", "        rnn.fit(X_train,y_train_ohe,epochs=3, batch_size=64, validation_data=(X_test,y_test_ohe))" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "###### We get very high accuracy with both of our hyperparameters set to .1" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "As we can see, with dropout and recurrent dropout at .1 we get strong results: training accuracy reaches 98.2% after three epochs, and on the validation set the model peaks at .997 precision and finishes at .976 accuracy.\n", "\n", "Several other hyperparameter sets reach similar precision and accuracy in the run shown above, so the differences between them are small. We carry the .1 and .1 set forward as our GRU model."
] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "best_model = Sequential()\n", "best_model.add(embedding_layer)\n", "best_model.add(GRU(100,dropout=.1, recurrent_dropout=.1))\n", "best_model.add(Dense(NUM_CLASSES, activation='sigmoid'))\n", "best_model.compile(loss='categorical_crossentropy', \n", " optimizer='rmsprop', \n", " metrics=metrics)" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Running Our Best Model With More Epochs" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/10\n", "4459/4459 [==============================] - 20s - loss: 0.2039 - precision: 0.9513 - acc: 0.9197 - val_loss: 0.0924 - val_precision: 0.9918 - val_acc: 0.9677\n", "Epoch 2/10\n", "4459/4459 [==============================] - 19s - loss: 0.0836 - precision: 0.9930 - acc: 0.9724 - val_loss: 0.0715 - val_precision: 0.9932 - val_acc: 0.9740\n", "Epoch 3/10\n", "4459/4459 [==============================] - 19s - loss: 0.0611 - precision: 0.9945 - acc: 0.9800 - val_loss: 0.1282 - val_precision: 0.9846 - val_acc: 0.9552\n", "Epoch 4/10\n", "4459/4459 [==============================] - 19s - loss: 0.0507 - precision: 0.9942 - acc: 0.9854 - val_loss: 0.0607 - val_precision: 0.9932 - val_acc: 0.9803\n", "Epoch 5/10\n", "4459/4459 [==============================] - 19s - loss: 0.0440 - precision: 0.9959 - acc: 0.9865 - val_loss: 0.0525 - val_precision: 0.9933 - val_acc: 0.9857\n", "Epoch 6/10\n", "4459/4459 [==============================] - 19s - loss: 0.0294 - precision: 0.9970 - acc: 0.9924 - val_loss: 0.0623 - val_precision: 0.9917 - val_acc: 0.9839\n", "Epoch 7/10\n", "4459/4459 [==============================] - 19s - loss: 0.0280 
- precision: 0.9974 - acc: 0.9922 - val_loss: 0.0431 - val_precision: 0.9942 - val_acc: 0.9865\n", "Epoch 8/10\n", "4459/4459 [==============================] - 19s - loss: 0.0214 - precision: 0.9986 - acc: 0.9933 - val_loss: 0.0513 - val_precision: 0.9962 - val_acc: 0.9848\n", "Epoch 9/10\n", "4459/4459 [==============================] - 19s - loss: 0.0189 - precision: 0.9979 - acc: 0.9957 - val_loss: 0.0443 - val_precision: 0.9935 - val_acc: 0.9883\n", "Epoch 10/10\n", "4459/4459 [==============================] - 19s - loss: 0.0132 - precision: 0.9986 - acc: 0.9964 - val_loss: 0.0467 - val_precision: 0.9963 - val_acc: 0.9883\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_model.fit(X_train,y_train_ohe,epochs=10, batch_size=64, validation_data=(X_test,y_test_ohe))" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "###### After ten epochs we reach 99.6% accuracy and a precision score of .9986 on the training set, and 98.8% accuracy and a precision score of .9963 on the validation set. This is a very good score on this dataset, and strong enough to consider as the basis for a spam filter." ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Grid Search Using LSTM" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "Now that we know the GRU network can reach 98.8% accuracy and 99.6% precision on the validation set, we will try to see how high we can get our LSTM's score."
] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": false, "deletable": true, "editable": true, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Hyper Paramater Set:\n", "\tdropout=0.1\n", "\trecurrent_dropout=0.1\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 20s - loss: 0.1763 - precision: 0.9702 - acc: 0.9354 - val_loss: 0.1679 - val_precision: 0.9856 - val_acc: 0.9417\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.0878 - precision: 0.9947 - acc: 0.9726 - val_loss: 0.0861 - val_precision: 0.9949 - val_acc: 0.9713\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0679 - precision: 0.9976 - acc: 0.9778 - val_loss: 0.0822 - val_precision: 0.9921 - val_acc: 0.9749\n", "Hyper Paramater Set:\n", "\tdropout=0.1\n", "\trecurrent_dropout=0.2\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 21s - loss: 0.1810 - precision: 0.9694 - acc: 0.9363 - val_loss: 0.1288 - val_precision: 0.9869 - val_acc: 0.9534\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.0921 - precision: 0.9930 - acc: 0.9693 - val_loss: 0.0974 - val_precision: 0.9931 - val_acc: 0.9659\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0737 - precision: 0.9945 - acc: 0.9765 - val_loss: 0.0780 - val_precision: 0.9912 - val_acc: 0.9713\n", "Hyper Paramater Set:\n", "\tdropout=0.1\n", "\trecurrent_dropout=0.3\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 21s - loss: 0.1841 - precision: 0.9628 - acc: 0.9325 - val_loss: 0.1113 - val_precision: 0.9862 - val_acc: 0.9614\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.0970 - precision: 0.9905 - acc: 0.9711 - val_loss: 0.0891 - val_precision: 
0.9929 - val_acc: 0.9695\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0752 - precision: 0.9939 - acc: 0.9760 - val_loss: 0.0708 - val_precision: 0.9932 - val_acc: 0.9740\n", "Hyper Paramater Set:\n", "\tdropout=0.2\n", "\trecurrent_dropout=0.1\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 21s - loss: 0.1779 - precision: 0.9579 - acc: 0.9365 - val_loss: 0.1004 - val_precision: 0.9857 - val_acc: 0.9650\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.0956 - precision: 0.9876 - acc: 0.9702 - val_loss: 0.0923 - val_precision: 0.9865 - val_acc: 0.9695\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0719 - precision: 0.9908 - acc: 0.9785 - val_loss: 0.1162 - val_precision: 0.9747 - val_acc: 0.9525\n", "Hyper Paramater Set:\n", "\tdropout=0.2\n", "\trecurrent_dropout=0.2\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 21s - loss: 0.1794 - precision: 0.9593 - acc: 0.9318 - val_loss: 0.1534 - val_precision: 0.9805 - val_acc: 0.9462\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.0928 - precision: 0.9902 - acc: 0.9684 - val_loss: 0.1130 - val_precision: 0.9835 - val_acc: 0.9596\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0789 - precision: 0.9897 - acc: 0.9744 - val_loss: 0.0789 - val_precision: 0.9893 - val_acc: 0.9740\n", "Hyper Paramater Set:\n", "\tdropout=0.2\n", "\trecurrent_dropout=0.3\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 21s - loss: 0.1878 - precision: 0.9565 - acc: 0.9352 - val_loss: 0.1434 - val_precision: 0.9823 - val_acc: 0.9516\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.0992 - precision: 0.9890 - acc: 0.9657 - val_loss: 0.1280 - val_precision: 0.9792 
- val_acc: 0.9525\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0827 - precision: 0.9919 - acc: 0.9702 - val_loss: 0.0790 - val_precision: 0.9873 - val_acc: 0.9722\n", "Hyper Paramater Set:\n", "\tdropout=0.3\n", "\trecurrent_dropout=0.1\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 21s - loss: 0.1891 - precision: 0.9688 - acc: 0.9325 - val_loss: 0.1157 - val_precision: 0.9906 - val_acc: 0.9587\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.0940 - precision: 0.9949 - acc: 0.9713 - val_loss: 0.0992 - val_precision: 0.9921 - val_acc: 0.9650\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0796 - precision: 0.9932 - acc: 0.9724 - val_loss: 0.0767 - val_precision: 0.9921 - val_acc: 0.9740\n", "Hyper Paramater Set:\n", "\tdropout=0.3\n", "\trecurrent_dropout=0.2\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 21s - loss: 0.1928 - precision: 0.9556 - acc: 0.9289 - val_loss: 0.1007 - val_precision: 0.9891 - val_acc: 0.9623\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.1024 - precision: 0.9899 - acc: 0.9666 - val_loss: 0.0880 - val_precision: 0.9901 - val_acc: 0.9668\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0882 - precision: 0.9905 - acc: 0.9693 - val_loss: 0.0866 - val_precision: 0.9937 - val_acc: 0.9704\n", "Hyper Paramater Set:\n", "\tdropout=0.3\n", "\trecurrent_dropout=0.3\n", "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 21s - loss: 0.1986 - precision: 0.9723 - acc: 0.9318 - val_loss: 0.1089 - val_precision: 0.9974 - val_acc: 0.9614\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.1016 - precision: 0.9962 - acc: 0.9673 - val_loss: 0.1256 - val_precision: 0.9879 - 
val_acc: 0.9578\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.0874 - precision: 0.9949 - acc: 0.9706 - val_loss: 0.0956 - val_precision: 0.9929 - val_acc: 0.9668\n" ] } ], "source": [ "dropouts=[.1,.2,.3]\n", "recurrent_dropouts=[.1,.2,.3]\n", "\n", "for dropout in dropouts:\n", "    for recurrent_dropout in recurrent_dropouts:\n", "        rnn = Sequential()\n", "        rnn.add(embedding_layer)\n", "        rnn.add(LSTM(100,dropout=dropout, recurrent_dropout=recurrent_dropout))\n", "        rnn.add(Dense(NUM_CLASSES, activation='sigmoid'))\n", "\n", "        rnn.compile(loss='categorical_crossentropy', \n", "                    optimizer='rmsprop', \n", "                    metrics=metrics)\n", "        print(\"Hyper Paramater Set:\\n\\tdropout=%.1f\\n\\trecurrent_dropout=%.1f\" % (dropout,recurrent_dropout))\n", "        rnn.fit(X_train,y_train_ohe,epochs=3, batch_size=64, validation_data=(X_test,y_test_ohe))" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "###### As we can see, our best LSTM hyperparameter set uses a dropout of 0.1 and a recurrent dropout of 0.2. We will create this network and train it with more epochs." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "best_lstm = Sequential()\n", "best_lstm.add(embedding_layer)\n", "best_lstm.add(LSTM(100,dropout=.1, recurrent_dropout=.2))\n", "best_lstm.add(Dense(NUM_CLASSES, activation='sigmoid'))\n", "best_lstm.compile(loss='categorical_crossentropy', \n", "                  optimizer='rmsprop', \n", "                  metrics=metrics)" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Comparison of models" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Split #1\n", "0.996007984032\n", "[[ 55 2]\n", " [ 2 499]]\n", "0.996\n", "[[ 55 2]\n", " [ 3 498]]\n", "Split #2\n", "0.987421383648\n", "[[ 81 6]\n", " [ 0 471]]\n", "0.991011235955\n", "[[ 83 4]\n", " [ 30 441]]\n", "Split #3\n", "0.991786447639\n", "[[ 71 4]\n", " [ 0 483]]\n", "0.989733059548\n", "[[ 70 5]\n", " [ 1 482]]\n", "CPU times: user 42min 51s, sys: 10min 58s, total: 53min 50s\n", "Wall time: 20min 33s\n" ] } ], "source": [ "%%time\n", "\n", "from sklearn.model_selection import StratifiedShuffleSplit\n", "from sklearn.metrics import confusion_matrix, precision_score\n", "\n", "sss = StratifiedShuffleSplit(n_splits=3)\n", "\n", "gru_scores = []\n", "gru_cms = []\n", "lstm_scores = []\n", "lstm_cms = []\n", "\n", "# NOTE: best_model (GRU) and best_lstm are created once, outside this loop, so their\n", "# weights carry over from split to split; test messages in later splits may already\n", "# have been seen during training on earlier splits.\n", "split_num = 1\n", "for train_index, test_index in sss.split(sequences, y_ohe):\n", "    print('Split #{}'.format(split_num))\n", "    split_num += 1\n", "    X_train, X_test = sequences[train_index], sequences[test_index]\n", "    y_train_ohe, y_test_ohe = y_ohe[train_index], y_ohe[test_index]\n", "    \n", "    # one hot decode for scoring: class index = position of the 1 in each row\n", "    y_test = np.argmax(y_test_ohe, axis=1)\n", "    \n", "    best_model.fit(X_train,y_train_ohe,epochs=3,\n", "                   batch_size=64,validation_data=(X_train,y_train_ohe),verbose=0)\n", "    y_hat = best_model.predict(X_test)\n", "    \n", "    # predicted class = index of the largest output for each sample\n", "    y_hat = np.argmax(y_hat, axis=1)\n", "    \n", "    gru_scores.append(precision_score(y_test, y_hat))\n", "    gru_cms.append(confusion_matrix(y_test, y_hat))\n", "    \n", "    print(gru_scores[-1])\n", "    print(gru_cms[-1])\n", "    \n", "    best_lstm.fit(X_train,y_train_ohe,epochs=3,\n", "                  batch_size=64,validation_data=(X_train,y_train_ohe),verbose=0)\n", "    y_hat = best_lstm.predict(X_test)\n", "    \n", "    # predicted class = index of the largest output for each sample\n", "    y_hat = np.argmax(y_hat, axis=1)\n", "    \n", "    lstm_scores.append(precision_score(y_test, y_hat))\n", "    lstm_cms.append(confusion_matrix(y_test, y_hat))\n", "    \n", "    print(lstm_scores[-1])\n", "    print(lstm_cms[-1])" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGkJJREFUeJzt3Xu4XHV97/H3R0gIShQkESShSVrRGlQuJ1KKQqO0lVia\neFpKia3ghUZq8YJatbZaiqc9B6Ha2uIF1CJWoGB9bIpRyjlKL2AooVJqQGyMIDtibgiGS8Do9/wx\na8fJZmfvSbKz9yLzfj3PPM+stX6z1nfWzJ7P/Nb67TWpKiRJapsnTXQBkiQNx4CSJLWSASVJaiUD\nSpLUSgaUJKmVDChJUisZUNJulOR/J3nLBGz3i0nO6KHdg0l+ejxqGk9JfjXJ3010Hdo1BpQeJ8lp\nSW5K8lCSdc39NyRJs/zSJI81H273Jbkuyc92Pf7cJH87zHorybN2Y93DbrdZ9uIkNyZ5oKn5hiQv\nTPLu5nk8mGRzkh91Ta/sqntdkr271jepmbfdfyRMMh04HfjYWD/X0VTVgqr6VA/t9quq1eNR03iq\nqn8EDk/ygomuRTvPgNI2krwN+EvgAuBg4CDgLOBFwOSupu+vqv2AGcAa4BPjXGrPkjwVuAb4K+Dp\ndGr+E+DRqvqz5kN6PzrP86uD01V1eNdqvg8s6Jpe0MwbyauBZVX1yE7WvdfOPG5Pl45ePruuAJbs\n7nq0+xhQ2irJ04DzgDdU1WeralN1fK2qfquqHh36mObD9yrgyF3Y7juTfHbIvL9M8qHm/quTrE6y\nKcm3k/zWDm7i2U2tV1TVj6rqkar6p6q6bQfW8Wk6vaFBpwOXjfKYBcA/D04kmZ9koOm1bUhyV/dz\naXqmH0myLMlDwEuS7JPkwiTfSbI2yUeT7Nv1mEVJbk3ygyTfSnJSM//6JGc295+V5J+b3uOG7kNf\n3b3aJE9LclmS9UnuTvJHg0HQvAb/1tTy/eZ16A7sbTSv6ZrmNbszyYnN/L2a5/+tZtktSQ5tlh2X\n5OamzpuTHNe1vuuT/GmSG4CHgZ9u6v1Eknubbf2vIaF+PfAro7xGajEDSt1+HtgH+IdeH5DkKcBi\nYNUubPdK4OVJpjbr3As4Fbi8Wf+HgAVVNRU4Drh1B9f/TeBHST6VZEGSA3aixs8DJyTZv3n88Yy+\nn54P3Dlk3sHANDq9uDOAi5M8p2v5K4E/BaYC/wb8HzoBeyTwrOZx7wVIcgydkPx9YH/gBOCuYep4\nH/BPwAHATDo9yeH8FfA04KeBX6ATwq/pWv5zzfOZBrwf+ETSOezbrXk+ZwMvbF6zl3XV9VY675eX\nA08FXgs8nOTpwBfovNYHAh8AvpDkwK5Vv4pOj2gqcDdwKbCl2S9HAb8MnNnV/g5gdtOD1hOQAaVu\n04ANVbVlcEZz3ub+JI8kOaGr7duT3A9sAl5M58Njp1TV3cB/AP+zmfVS4OGqWt5M/xh4XpJ9q+re\nqlq5g+v/QVNjAZcA65MsTXLQDqxmM/CPwG82t6XNvJHsT2f/DPWeqnq0qv6ZzofyqV3L/qGqbqiq\nHwOP0vlAPqeq7quqTcCfAac1bV8HfLKqrquqH1fVmqr6xjDb+yEwCzikqjZX1b8NbdB8KTgN+IOm\n53wX8Ods+7reXVWXVNWPgE8Bz6RzCHioH9H5ojM3yaSququqvtUsOxP4o6q6s+md/2dVbaTT0/nv\nqvp0VW2pqiuAbwC/2rXeS6tqZfP+fDqdkHtLVT1UVeuAD3btG/jJvt9/mBr1BGBAqdtGYFr3YICq\nOq6q9m+Wdb9fLmzmzwYeAbp7AVuASd0rTjI4/cPtbPtyOt+sodOLuLzZ/kN0AuEs4N4kX0jXgIxe\nVdUdVfXqqpoJPA84BPiLHVzNZXR6Fb0c3oPOOaqpQ+c1z2nQ3U0tg+7puj8deDJwS/Ml4X7gS818\ngEOBbzG6dwAB/j3JyiSvHabNNDqv2d1DapvRNf29wTtV9XBzd
7+hK6qqVcBbgHOBdUmuTDL4HLdX\n8yFDtj3c9rv3zaym3nu79s3HgGd0tRnc9/cPsz09ARhQ6vZVOt/aF/X6gKr6DvBm4C+7zo18h05w\ndZtDJ7jWbGdVVwPzk8yk05O6vGsb11bVL9H5xv4NOr2gndb0Mi6lE1Q74l/5Sa/hcb2QYdxGc/6r\nywHNYctBPwV8t7u8rvsb6IT/4VW1f3N7WjOgAzof2D8zWhFV9b2q+p2qOgR4PfDhPH405QZ+0tPq\nrm17r9do27y8ql7crK+A80ep+btDtj3c9rv3zT103qvTuvbNU4cMbHkucFfTg9YTkAGlrarqfjqj\n2z6c5JQkU5M8KcmRwFNGeNx1dD5gBkdMfQn42SSvSmc49tPpHJr6++7Dh0PWsZ7OSe2/Ab5dVXcA\nJDmoGQjwFDofSA/SOeS3PU9KMqXrtk+Sn03ytib8aE7KLwaWj7Ce4WosOoecFlZvv1OzjM65nKH+\nJMnkJMcDJ9MJ5+G292M6YfzBJM9oap+R5GVNk08Ar0lyYvM6zRiud5nkNwafO51eXTFkHzaH7a4C\n/rR53WfROV807LD9kSR5TpKXJtmHzmHQR7q293HgfUkOS8cLmvNMy4BnJ3llkr2T/CYwl87oy+H2\nzb10zqv9eZKnNs//Z5J07+9fAL64o/WrPQwobaOq3k/ng+kdwNrm9jHgncCNIzz0AuAdSfZpzgcs\noPNtfR3wdTqHWX53lM1fDvwiXb0nOu/Rt9IJwPvofOiMtJ7FdD4QB2/fonMu4ueAm9IZHbe8qelt\no9TzOM05kF7PgV1GZ/DHvl3zvkcnJL4LfAY4azvnjQa9k84AlOVJfgD8X5rDqVX173QGMXwQeIDO\niMGhvRCAF9J57g/SOXf25u3879MbgYeA1XR6iJcDn+ztqW5jHzqDOzbQeb7PAP6gWfYBOkH4T8AP\n6ITsvs15qJPpvCYb6bz/Tq6qDSNs53Q6//pwO519+lk6PdxBi5mA/0HT2Ik/WCjtPkn+DFhXVX+R\nZD7wt815MO1GSX4VeFVVnTpqY7WWASWNEwNK2jEe4pMktZI9KElSK9mDkiS10t6jN9k9pk2bVrNn\nz56ozUuSJsgtt9yyoaqmj9ZuwgJq9uzZrFixYqI2L0maIEmGXjVkWB7ikyS1kgElSWolA0qS1EoT\ndg5KkvrJD3/4QwYGBti8ebRfadlzTJkyhZkzZzJp0qTRGw/DgJKkcTAwMMDUqVOZPXs2w/zO4x6n\nqti4cSMDAwPMmTNnp9Yx6iG+JJ9Msi7J17ezPEk+lGRVktuSHL1TlUjSHmzz5s0ceOCBfRFOAEk4\n8MADd6nH2Ms5qEuBk0ZYvgA4rLktAT6y09VI0h6sX8Jp0K4+31EDqqr+hc7PHGzPIuCy5ueblwP7\nJ3nmCO0lSRrVWJyDmsG2P8U80My7dwzWLUl7pIsvHtv1LVkyepu1a9dyzjnnsHz5cg444AAmT57M\nO97xDg444AAWLVrEnDlz2Lx5MyeffDIXXnghAOeeey777bcfb3/727euZ/BCC9OmTRvbJzHEuA4z\nT7IkyYokK9avXz+em5akvlZVvOIVr+CEE05g9erV3HLLLVx55ZUMDAwAcPzxx3Prrbfyta99jWuu\nuYYbbrhhgisemx7UGuDQrumZzbzHqaqLgYsB5s2bNyaXUR/rbyG9WMJEbLSHr0eSJkQvn0NHHgnd\n38s3bdr17U6d2nvbL3/5y0yePJmzzjpr67xZs2bxxje+keuvv37rvH333ZcjjzySNWuG/RgfV2MR\nUEuBs5NcSedntR+oKg/v6QnDLzl6wupOufUjj5ZbedNNHP3c526bkoPuvx8eewzWr+f799/Pf99x\nByfMndtp+9BDzfqbx00f9RqvY6aXYeZXAF8FnpNkIMnrkpyVZDCGlwGrgVXAJcAbdlu1kqQx8Xvv\nfCdHzJ/PC3/5lwH41+XLO
WL+fGYccQQve8lLOPigg4Dtj8QbjxGJo/agqmrxKMsL+L0xq0iSNOYO\nf85z+Ptrrtk6fdH557Nh40bmNQF1/LHHcs1nPsO3776bYxcs4NSFCzny+c/nwKc/nXvXrt1mXZs2\nbWL//fff7TV7LT5J6gMvPf54Nj/6KB/5m7/ZOu/hRx55XLs5s2bxrje9ifP/+q8BOOHYY1l67bVs\nevBBAD73uc9xxBFHsNdee+32mr3UkSRNgNNP3/V1TKf3qzQk4fOf+hTnvOc9vP+ii5h+4IE85clP\n5vz3vOdxbc864wwu/PCHues73+EFhx/O2a99LS8++WSS8IxDDuHjH//4rhffAwNKkvrEMw86iCu3\nMypo/otetPX+vvvuy5rbbts6/fozzuD1Z5zRmWjTIAlJkiaCASVJaiUDSpLUSgaUJKmVDChJUisZ\nUJKkVnKYuSRNgCmXjcX1GLuuxdfDP1btN3s2D9511zbz7ly1ite//e3c/8ADPPrYYxx/7LH8+q/8\nCu983/sAWPXtbzPjmc9k3ylTeMHcubz2DW/gJS95CZdccglnnnkmALfeeitHHXUUF1xwwTY/y7Gr\nDChJ6mNveve7Oef1r2fRggUA/Nftt/P8uXN52UtfCsD8V7yCC889l3lHHgnA9StX8rznPY+rrrpq\na0BdccUVHHHEEWNem4f4JKmP3bt2LTMPOWTr9PPnzh31MbNmzWLz5s2sXbuWquJLX/oSC5qAG0sG\nlCT1sXPOOouX/tqvseC00/jgRz/K/Q880NPjTjnlFK6++mpuvPFGjj76aPbZZ58xr82AkqQ+9prF\ni7njhhv4jYULuf7GGzl2wQIeffTRUR936qmncvXVV3PFFVewePGIP3qx0wwoSepzhxx8MK995Sv5\nh8suY++99+br3/jGqI85+OCDmTRpEtdddx0nnnjibqnLQRKS1Me+9OUvc+LxxzNp0iS+t3YtG++7\njxkHH9zTY8877zzWrVu32356w4CSpAmw+fQlu7yOqQzz8+0jePiRR5jZNdrurWedxcB3v8ub//AP\nmdKcQ7rgj/9466/pjua4447boe3vKANKkvrEj4f8Mu6gDzT/8zSc6z//+W2m58+fz/z58x/X7txz\nz92V0oblOShJUisZUJKkVjKgJGmcVNVElzCudvX5GlCSNA4efngKmzZt7JuQqio2btzIlClTdnod\nDpKQpHGwevVMYIAnP3nHRt6NZEP3xWLHy4YNPTedMmUKM2fO3OlNGVCSNA62bJnEN785Z0zXuYSx\nuCL6jm5014fH98pDfJKkVjKgJEmtZEBJklrJgJIktZIBJUlqJQNKktRKBpQkqZUMKElSKxlQkqRW\nMqAkSa1kQEmSWsmAkiS1Uk8BleSkJHcmWZXkXcMs/6kkX0nytSS3JXn52JcqSeonowZUkr2Ai4AF\nwFxgcZK5Q5r9EXBVVR0FnAZ8eKwLlST1l156UMcAq6pqdVU9BlwJLBrSpoCnNvefBnx37EqUJPWj\nXgJqBnBP1/RAM6/bucBvJxkAlgFvHG5FSZYkWZFkxfr1Y/ejXZKkPc9YDZJYDFxaVTOBlwOfTvK4\ndVfVxVU1r6rmTZ8+fYw2LUnaE/USUGuAQ7umZzbzur0OuAqgqr4KTAGmjUWBkqT+1EtA3QwclmRO\nksl0BkEsHdLmO8CJAEmeSyegPIYnSdppowZUVW0BzgauBe6gM1pvZZLzkixsmr0N+J0k/wlcAby6\nqmp3FS1J2vPt3UujqlpGZ/BD97z3dt2/HXjR2JYmSepnXklCktRKBpQkqZUMKElSKxlQkqRWMqAk\nSa1kQEmSWsmAkiS1kgElSWolA0qS1EoGlCSplQwoSVIrGVCSpFYyoCRJrWRASZJayYCSJLWSASVJ\naiUDSpLUSgaUJKmVDChJUisZUJKkVjKgJEmtZEBJklrJgJIktZIBJUlqJQNKktRKBpQkqZU
MKElS\nKxlQkqRWMqAkSa1kQEmSWsmAkiS1kgElSWolA0qS1EoGlCSplQwoSVIrGVCSpFbqKaCSnJTkziSr\nkrxrO21OTXJ7kpVJLh/bMiVJ/Wbv0Rok2Qu4CPglYAC4OcnSqrq9q81hwB8AL6qq7yd5xu4qWJLU\nH3rpQR0DrKqq1VX1GHAlsGhIm98BLqqq7wNU1bqxLVOS1G96CagZwD1d0wPNvG7PBp6d5IYky5Oc\nNNyKkixJsiLJivXr1+9cxZKkvjBWgyT2Bg4D5gOLgUuS7D+0UVVdXFXzqmre9OnTx2jTkqQ9US8B\ntQY4tGt6ZjOv2wCwtKp+WFXfBr5JJ7AkSdopvQTUzcBhSeYkmQycBiwd0ubzdHpPJJlG55Df6jGs\nU5LUZ0YNqKraApwNXAvcAVxVVSuTnJdkYdPsWmBjktuBrwC/X1Ubd1fRkqQ936jDzAGqahmwbMi8\n93bdL+CtzU2SpF3mlSQkSa1kQEmSWsmAkiS1kgElSWolA0qS1EoGlCSplQwoSVIrGVCSpFYyoCRJ\nrWRASZJayYCSJLWSASVJaiUDSpLUSgaUJKmVDChJUisZUJKkVjKgJEmtZEBJklrJgJIktZIBJUlq\nJQNKktRKBpQkqZUMKElSKxlQkqRWMqAkSa1kQEmSWsmAkiS1kgElSWolA0qS1EoGlCSplQwoSVIr\nGVCSpFYyoCRJrWRASZJayYCSJLVSTwGV5KQkdyZZleRdI7T79SSVZN7YlShJ6kejBlSSvYCLgAXA\nXGBxkrnDtJsKvBm4aayLlCT1n156UMcAq6pqdVU9BlwJLBqm3fuA84HNY1ifJKlP9RJQM4B7uqYH\nmnlbJTkaOLSqvjCGtUmS+tguD5JI8iTgA8Dbemi7JMmKJCvWr1+/q5uWJO3BegmoNcChXdMzm3mD\npgLPA65PchdwLLB0uIESVXVxVc2rqnnTp0/f+aolSXu8XgLqZuCwJHOSTAZOA5YOLqyqB6pqWlXN\nrqrZwHJgYVWt2C0VS5L6wqgBVVVbgLOBa4E7gKuqamWS85Is3N0FSpL60969NKqqZcCyIfPeu522\n83e9LElSv/NKEpKkVjKgJEmtZEBJklrJgJIktZIBJUlqJQNKktRKBpQkqZUMKElSKxlQkqRWMqAk\nSa1kQEmSWsmAkiS1kgElSWolA0qS1EoGlCSplQwoSVIrGVCSpFYyoCRJrWRASZJayYCSJLWSASVJ\naiUDSpLUSgaUJKmVDChJUisZUJKkVjKgJEmtZEBJklrJgJIktZIBJUlqJQNKktRKBpQkqZUMKElS\nKxlQkqRWMqAkSa1kQEmSWsmAkiS1Uk8BleSkJHcmWZXkXcMsf2uS25PcluT/JZk19qVKkvrJqAGV\nZC/gImABMBdYnGTukGZfA+ZV1QuAzwLvH+tCJUn9pZce1DHAqqpaXVWPAVcCi7obVNVXqurhZnI5\nMHNsy5Qk9ZteAmoGcE/X9EAzb3teB3xxuAVJliRZkWTF+vXre69SktR3xnSQRJLfBuYBFwy3vKou\nrqp5VTVv+vTpY7lpSdIeZu8e2qwBDu2antnM20aSXwT+EPiFqnp0bMqTJPWrXnpQNwOHJZmTZDJw\nGrC0u0GSo4CPAQurat3YlylJ6jejBlRVbQHOBq4F7gCuqqqVSc5LsrBpdgGwH3B1kluTLN3O6iRJ\n6kkvh/ioqmXAsiHz3tt1/xfHuC5JUp/zShKSpFYyoCRJrWRASZJayYCSJLWSASVJaiUDSpLUSgaU\nJKmVDChJUisZUJKkVjKgJEmtZEBJklrJgJIktZIBJUlqJQNKktRKBpQkqZUMKElSKxlQkqRWMqAk\nSa1kQEmSWsmAkiS1kgElSWolA0qS1EoGlCSplQwoSVIrGVCSpFYyoCRJrWRASZJayYCSJLWSASVJ\naiUDSpLUSgaUJKmVDChJUisZUJKkVjKgJEmtZEBJklr
JgJIktVJPAZXkpCR3JlmV5F3DLN8nyd81\ny29KMnusC5Uk9ZdRAyrJXsBFwAJgLrA4ydwhzV4HfL+qngV8EDh/rAuVJPWXXnpQxwCrqmp1VT0G\nXAksGtJmEfCp5v5ngROTZOzKlCT1m1TVyA2SU4CTqurMZvpVwM9V1dldbb7etBlopr/VtNkwZF1L\ngCXN5HOAO8fqieyEacCGUVv1N/fR6NxHI3P/jK4f99Gsqpo+WqO9x6OSQVV1MXDxeG5ze5KsqKp5\nE11Hm7mPRuc+Gpn7Z3Tuo+3r5RDfGuDQrumZzbxh2yTZG3gasHEsCpQk9adeAupm4LAkc5JMBk4D\nlg5psxQ4o7l/CvDlGu3YoSRJIxj1EF9VbUlyNnAtsBfwyapameQ8YEVVLQU+AXw6ySrgPjoh1nat\nONTYcu6j0bmPRub+GZ37aDtGHSQhSdJE8EoSkqRWMqAkSa20xweUl2kaWQ/759VJ1ie5tbmdORF1\nTqQkn0yyrvl/v+GWJ8mHmn14W5Kjx7vGidTD/pmf5IGu99B7x7vGiZbk0CRfSXJ7kpVJ3jxMm75+\nHw1njw4oL9M0sh73D8DfVdWRze3j41pkO1wKnDTC8gXAYc1tCfCRcaipTS5l5P0D8K9d76HzxqGm\nttkCvK2q5gLHAr83zN9av7+PHmePDii8TNNoetk/fa+q/oXO6NTtWQRcVh3Lgf2TPHN8qpt4Peyf\nvldV91bVfzT3NwF3ADOGNOvr99Fw9vSAmgHc0zU9wOPfFFvbVNUW4AHgwHGpbuL1sn8Afr055PDZ\nJIcOs7zf9bof+9nPJ/nPJF9McvhEFzORmtMIRwE3DVnk+2iIPT2gtOv+EZhdVS8AruMnvU2pV/9B\n59prRwB/BXx+guuZMEn2A/4eeEtV/WCi62m7PT2gvEzTyEbdP1W1saoebSY/DvyPcartiaSX91nf\nqqofVNWDzf1lwKQk0ya4rHGXZBKdcPpMVX1umCa+j4bY0wPKyzSNbNT9M+QY+EI6x861raXA6c0o\nrGOBB6rq3okuqi2SHDx4XjfJMXQ+d/rlSyDQGaFH54o7d1TVB7bTzPfREON6NfPxtgdfpmlM9Lh/\n3pRkIZ1RSPcBr56wgidIkiuA+cC0JAPAHwOTAKrqo8Ay4OXAKuBh4DUTU+nE6GH/nAL8bpItwCPA\naX30JXDQi4BXAf+V5NZm3ruBnwLfR9vjpY4kSa20px/ikyQ9QRlQkqRWMqAkSa1kQEmSWsmAkiS1\nkgElSWolA0qS1Er/HySFzny311/vAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot bar graphs\n", "bar_width = 0.20\n", "index = np.arange(3)\n", "opacity=0.4\n", "\n", "plt.bar(index, gru_scores, bar_width, align='center',\n", " color='b', label='GRU', alpha=opacity)\n", "plt.bar(index + bar_width, lstm_scores, bar_width,\n", " align='center', color='r', label='LSTM', alpha=opacity)\n", "plt.title('GRU vs LSTM (precision score)')\n", "\n", "plt.legend()\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "Both models perform extremely well, however the GRU model 
made fewer errors overall. Mean precision across the three splits is about 0.992 for both models, but the GRU misclassified 14 test messages in total versus 45 for the LSTM, which mislabeled 30 messages of one class in split 2 alone.\n", " \n", "Heatmaps of the confusion matrices give a more granular look at how our models classify each class." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAb0AAAFXCAYAAAA28ZCgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XtcVXW+//H33uBOc4tGGt6CRMPbaY6Sp7HSzBRTxEIp
SlJWV5Vv48e9cLpfv53+GxLku196/h4ZhGBd9798ZhqGSkpIyz7Ft27basGGDoqKiJEkjR46U\nJO3fv1+9evW66GcSEhL03XffadiwYVqwYIGCg4N13XXX6Ycffrjg2F9++UV16tQpsw+AHbF6E1VO\nhw4d9N5778nj8cjr9erFF1/U66+/rry8PBUUFGjo0KG65557tGHDBt8x0tmQulzYnOty7W3fvl3b\nt2+XdHbeLjo6WiEhIRf0dfny5b5VowsWLFCdOnUUERFR5ncnJycrMzNTq1at8gXnmTNn9OWXX8rp\nvPRfy+TkZOXn5+uvf/2rJOmOO+7Q119/rd27d/uOWbx4sdxutyIjI/3+swDsgkoPVc7gwYP16quv\nqnfv3iotLVXLli01YsQIXXvttbr77rvVo0cPhYSEKDw8XM2aNVNeXp7Cw8MVExOjhx9+WG+99ZZf\n39O8efNLtudyuVS3bl1NnDhRBw4cUGhoqMaNG3dBG3feeacGDhyoAQMGyOv1KjQ0VJmZmWUGlyS1\nbNlSs2fP1ptvvqn09HQ5nU55PB5FR0frww8/vOTnateureeee05jx45VXFycmjRpojFjxmj48OEq\nKSmRx+NR48aNNX369Mv2AbAjh3GpsSHAxnJycjRmzBgtXbq0srsCoBzxT0EAgG1Q6QEAbINKDwBg\nG4QeAMA2CD0AgG0E9JKFDs3jAtk8UCE+Xzm5srsAlIuajQJ37eZvIjpd8We/z/uqHHtSNq7TAwCY\ndqlb4lU1DG8CAGyDSg8AYJrDYY0ayhq9BACgHFDpAQBMc8oac3qEHgDANKssZCH0AACmOS0yp0fo\nAQBMs0qlZ41oBgCgHBB6AADbYHgTAGCag9WbAAC7YCELAMA2rLKQhdADAJjmtEjoWaMeBQCgHBB6\nAADbYHgTAGCawyI1FKEHADCNhSwAANuwykIWQg8AYJpVLk63xiAsAADlgNADANgGw5sAANO4DRkA\nwDZYvQkAsA1WbwIAbIPVmwAAVDFUegAA06yykMUavQQAoBxQ6QEATGP1JgDANli9CQCwDVZvAgBQ\nxVDpAQBMY04PAGAbVpnTY3gTAGAbVHoAANOsspCF0AMAmMYdWQAAqGKo9AAAprF6EwBgG1ZZvUno\nAQBMs8pCFub0AAC2QaUHADDNKsObVHoAANug0gMAmMbqTQCAbVhleJPQAwCYZpXVm4QeAMA0q1R6\nLGQBANgGoQcAsA2GNwEAprF6EwBgG1aZ0yP0AACmsXoTAGAbVqn0WMgCAKiyvF6vRo0apX79+ikp\nKUl5eXnn7Z8xY4b69OmjhIQEffnll5dtj0oPAFBlZWVlyePxaN68ecrNzVVaWpoyMjIkSSdOnNC7\n776rL774QkVFRYqPj1dMTEyZ7RF6AADTArV6c+PGjerYsaMkqU2bNtqyZYtvX40aNdSwYUMVFRWp\nqKjIrz4QegAA0wI1p1dQUCC32+3bDgoKUklJiYKDz8ZXgwYN1LNnT5WWlurJJ5+8fD8D0ksAgK04\nHI4rfpXF7XarsLDQt+31en2Bt3r1auXn52vFihVatWqVsrKy9P3335fZHqEHADDNYeK/skRHR2v1\n6tWSpNzcXEVFRfn21a5dW9WrV5fL5dI111yjWrVq6cSJE2W2x/AmAKDKiomJ0dq1a5WYmCjDMJSa\nmqqZM2cqPDxcXbp00bp169S3b185nU5FR0frzjvvLLM9h2EYRqA626F5XKCaBirM5ysnV3YXgHJR\ns1FkwNp+qtPQK/7sG19NLMeelI3hTQCAbTC8CQAwjRtOAwBswyq3ISP0AACmWaXSY04PAGAbVHoA\nANOcPFoIZtzeqZ2eTB4gl6uadu/Yo7EjJ+lUYdF5xyQ8EqeER+J05rRHebv3KT0lQyePF0iSPvn6\nPR05/Ivv2PffWagvP1lVkacAm1qzfoOmvD1TxZ5i3RzZRKP+Z6jcN
Wv6fcw9vfvphrp1fcf+rl+C\nYrveo70HftLYiW/o2K/HVVxSrPge9yqpb0KFnhsuzSrDm4ReFVTnuhCNHDtUf3xouPbn/aQ/PjdQ\nf3xuoNJfzvAd0/a3t6j/oAf0ZN9k/Xz4F917f2cNTxmiF58ZqxubNFLB8QI9Gv90JZ4F7OjYr79q\n9LjXNXNyusIbN9Kkae9oyvSZ+tPQp/w6Zs/e/QqpVUsfTH/zgrZHv5quXvfGqHfP7jpZUKikwc+o\nebOmui26TUWeIiyOOb0q6L86RGvb5v/V/ryfJEmL5i5XTK+7zzumRetm+vu6XP38/6u5r75Ypzvv\nuU3B1YJ1S9uWKvV6NfndVM1aMkUD/ztRTif/qxF4X/99k1o3j1J440aSpAfvi9OnK7J17j0wyjrm\nu63/kNPp1BPDnlffx/+oae++p9LSUknS/T3uVfcud0uSarlr6saGDXTwcH7FniAuyelwXPGrQvtZ\nod8Gv4TVr6v8Q0d82z8fOiJ3rZq6tmYN33v/+H6nbm3/G4U1rCdJiu0TI5ermmrXqaWgoCB9s/Zb\nJf9+lJ7qP0K3dYhWQhJ3x0HgHc4/orAb6vm2b6hXVwWFp1R46pRfx5SWlqr9rW31RtoYvTPxNX39\nzSZ9sGiJJOn+Ht1Uo3p1SdLaDX/Xd1u36Y7b2lXQmeFyHI4rf1Ukv4Y39+3bp+zsbJ05c8b33qBB\ngwLWKbtzOC/+W+D1en0/f/f3rZrx5lylvvGCDMPQsgVf6vixEyouLtEnH33uO664uETzZi7WA0m9\n9NHsJQHvO+zNa3gv+n6QM8ivY/rE9fBtu1wuPfJgb81d+LH6P9Db9/4nn3+pCRlv67XRL6je9aHl\n1HPYhV+V3uDBg3X8+HG5XC7fC4Fz+ODPur7edb7tumHX68SvJ3W66F//6KhRs4ZyN2zR7/sM1eMJ\nz2rV5+skSSd+Pal77++sps1v8h3rcDhUWlJaYf2HfdW/4QYd+eWobzv/5yMKqeVWjRrV/Tpm6Rcr\ntHP3j759hmH4HiNjGIZez5iujJl/Vcb4VP321rYVcEbw11U1vNmgQQMNGTJEAwYM8L0QOBv+9q1a\n/2dzNY5oKEmKT4zVmhXrzzum7g2hmjJnrG/Ic+DgRGUtO/v4jcibI/T7p/vL6XTKdY1LCf3jtGL5\nmoo9CdjS7e2itXnbdu3df0CStOCT5ep0x+1+H7N7zx5NnTVHpaWlOn3mjOYt/kTd7r5LkvTaG1O1\n6fvNem/qJDVv1rQCzwr+CNSjhcq9n/48ZWHu3Lk6cOCAmjVr5nsvPj7+so3zlIUr1/6udvpD8gAF\nVwvWgb0H9crzr6vhjfU14pWnfasy+/SPU5/+PeV0OvT9xn/o9ZSp8pzx6Jrq12jYqD+o1X82V3Bw\nsLI/+5umTXi3ks/IunjKwv/N39Zv0JS3Z6m4pESNGzbQmBHP6cDBg0oZP8m3KvNix9QOqaWi06f1\n6uS3tHnbdpWUlKprp4566vcDdPjnI+r50AA1CLvhvMsfHupzv+7v0a2yTtVyAvmUhZH3/umKP5v6\n+dhy7EnZ/Aq9pKQkRUZGKiQk5OyHHA4NGzbsso0TergaEHq4WhB6fi5kcblcevnllwPdFwCARV1V\nN5xu2LChMjMz1apVK99V9x06dAhoxwAA1mGRzPMv9EpKSrRnzx7t2bPH9x6hBwCwGr9Cb+zY88db\n8/O5CwIA4F+uquHNSZMmae7cuSouLtbp06d10003admyZYHuGwDAIir60oMr5dd1eitXrtTq1avV\nq1cvLV++XGFhYYHuFwDAQq6qi9Pr1asnl8ulwsJCRUREqLi4OND9AgCg3Pk1vFm/fn3Nnz9fNWrU\nUHp6uk6cOBHofgEALMQiU3r+hV5KSooOHTqk7t27a9GiRUpPTw90vwAAKHd+hd6xY8c0Y8YM7dmz\nRzfffLPq1at3+Q8BAGzDKk9O9
2tOb+jQoYqMjNRzzz2nxo0ba/jw4YHuFwDAQqyykMWvSk+SHn74\nYUlSixYt9NlnnwWsQwAA67FIoedf6EVGRurjjz9W+/bttXXrVtWpU0c//nj2mVdNmjQJaAcBAFXf\nVXVx+g8//KAffvhBM2bMUFBQkGrWrKlRo0bJ4XDo3Xd5ZA0AwBrKnNPbunWr4uPj9c477ygpKUn5\n+fkqLCzUwIEDNWfOHAIPAGApZYbeuHHjlJaWJpfLpYkTJ+rtt9/WggULNH369IrqHwDAAqzy5PQy\nhze9Xq9atGihw4cPq6ioSK1bt5ZknaWpAICKYZVcKDP0goPP7l6zZo1uv/12SVJxcbFOnToV+J4B\nACzDaY3MKzv0br/9diUmJurQoUPKyMjQ3r17lZKSotjY2IrqHwDAAq6KSu+JJ55Qly5d5Ha7FRYW\npr1796pfv36KiYmpqP4BAFBuLnvJQtOmTX0/h4eHKzw8PKAdAgAgUPy+IwsAAJdyVQxvAgDgj6ti\nIQsAAP6g0gMA2IZFMs+/RwsBAHA1oNIDAJhmlacsUOkBAGyDSg8AYFpF3zj6ShF6AADTLDK6SegB\nAMxjTg8AgCqGSg8AYBoXpwMAbMMimcfwJgDAPqj0AACmMbwJALANnrIAAIBJXq9Xo0eP1o4dO+Ry\nufTKK68oIiLCt/+rr77Sm2++KcMw1Lp1a7300ktlVp3M6QEATHM4HFf8KktWVpY8Ho/mzZun5ORk\npaWl+fYVFBTotdde09SpU/XRRx+pUaNGOnbsWJntUekBAEwL1JTexo0b1bFjR0lSmzZttGXLFt++\nb7/9VlFRUXr11Ve1b98+PfjggwoNDS2zPUIPAGBaoO7IUlBQILfb7dsOCgpSSUmJgoODdezYMeXk\n5Gjx4sW69tpr1b9/f7Vp00ZNmjS5dD8D0ksAAMqB2+1WYWGhb9vr9So4+Gy9VqdOHd1yyy2qV6+e\natasqXbt2mnbtm1ltkfoAQBMC9ScXnR0tFavXi1Jys3NVVRUlG9f69attXPnTh09elQlJSX67rvv\n1KxZszLbY3gTAFBlxcTEaO3atUpMTJRhGEpNTdXMmTMVHh6uLl26KDk5WY8//rgkqXv37ueF4sUQ\negAA0wK1kMXpdColJeW895o2ber7uWfPnurZs6ff7RF6AADTuCMLAMA2LJJ5hB4AwDweIgsAQBVD\n6AEAbIPhTQCAaRYZ3ST0AADmsXoTAGAbFsk8Qg8AYJ5VKj0WsgAAbIPQAwDYBsObAADTLDK6SegB\nAMyzyh1ZCD0AgGkWyTxCDwBgHqs3AQCoYqj0AACmWaTQo9IDANgHlR4AwDSrzOkRegAA0yySeYQe\nAMA8q1R6zOkBAGyDSg8AYJpFCj1CDwBgHsObAABUMVR6AADTLFLoBTb0Vn4zO5DNAxWi3S19KrsL\nQLn4Pu+rgLXNUxYAALZhkcxjTg8AYB9UegAA06yyepPQAwCYZpHMY3gTAGAfVHoAANMcTmuUeoQe\nAMA0hjcBAKhiqPQAAKaxehMAYBsWyTxCDwBgnlUqPeb0AAC2QaUHADDNIoUelR4AwD6o9AAA5lmk\n1CP0AACmWWUhC6EHADDNIplH6AEAzLPKvTdZyAIAsA1CDwBgGwxvAgBMY04PAGAbrN4EANiGRTKP\n0AMAmGeVSo+FLAAA2yD0AABVltfr1ahRo9SvXz8lJSUpLy/vosc8/vjjmjt37mXbI/QAAKY5HFf+\nKktWVpY8Ho/mzZun5ORkpaWlXXDMxIkTdeLECb/6yZweAMC0QM3pbdy4UR07dpQktWnTRlu2bDlv\n/2effSaHw+E75nKo9AAA5jlNvMpQUFAgt9vt2w4KClJJSYkkaefOnVq6dKmeeeYZv7tJpQcAMC1Q\nlZ7b7VZhYaFv2+v1Kjj4bHQtXrxYhw8f1oABA3TgwAFVq1ZNjRo10l133XXJ9gg9AECVFR0drez
s\nbMXGxio3N1dRUVG+fcOHD/f9PGXKFNWtW7fMwJMIPQBAFRYTE6O1a9cqMTFRhmEoNTVVM2fOVHh4\nuLp06fJ/bo/QAwCYFqhr051Op1JSUs57r2nTphccN2TIEL/aI/QAAKZZ5Y4shB4AwDSLZB6hBwAo\nBxZJPa7TAwDYBpUeAMA0h5NKDwCAKoVKDwBgmkWm9Ag9AIB5XLIAALANi2Qec3oAAPug0gMAmGeR\nUo/QAwCYxiULAABUMVR6AADTLDK6SegBAMqBRVKP4U0AgG1Q6QEATLNIoUfoAQDMs8rqTUIPAGCa\nVW5DxpweAMA2qPQAAOZZo9Cj0gMA2AeVHgDANKvM6RF6AADTCD0AgH1YZLKM0AMAmGaVSs8i2QwA\ngHmEHgDANhjeBACYZpXhTUIPAGCeNTKP0AMAmMcNpwEA9mGR4U0WsgAAbIPQAwDYBsObAADTLDK6\nSehVttV/W6uJb05VsadYN9/cVCl/Him3u6Zfx5SWluq1CZO1dn2OSktLNfCRh9U3ofd5n120ZKlW\nZH+lNya8JkkyDENTpk7TiuyvJEn/0aql/jzif1SjevWKOWHgEsaMH6FdO3/U7GnzKrsruAJWuWSB\n4c1KdPTYMb2Y8hdNeDVVnyz4QI0bNdTEN97y+5iPFi5W3r79WvTBXzV39juaM3eeNm/9hyTp+PET\nShk7TmNfe12GDF97K7K/0tfrN2j+e7O1eN57Kjp9Wu/N/bDiThr4N02aRejtuRPULa5zZXcFZjgd\nV/6qyG5W6LfhPOvWb1DrVi0VEX6jJKlfQh8t++wLGYbh1zErVq1WfK+eCg4OVu2QEPXo1lVLP/1M\nkvR51grVq3u9kp956rzv7HrP3Xr3nUxVq1ZNhYWndPToMdWuHVIxJwxcROLv4rX4w0/1xdLsyu4K\nTHA4HFf8qkiEXiU6dPiw6oeF+bbDbqingsJCFRae8uuYs/tuOGffDTp8+GdJUt+E3vrjoN/rmmuu\nueB7qwUH6/0P56tbr9769dfj6tK5UyBOD/DL2FGTtHTRF5XdDdiEX3N6+/btU3Z2ts6cOeN7b9Cg\nQQHrlF2cW9Gdyxnk9OuYi+0797NlebjvA3rowQRNmTpNw55/QbOmvXX5DwHApVhjSs+/Sm/w4ME6\nfvy4XC6X7wXz6oeF6ecjR3zb+T//rJCQWrq2Rg2/jqkfFqYjR345b1/YDf+q/C5mx87/1bYdOySd\nHY5IuP8+bduxs7xOCQCqNL9Cr0GDBhoyZIgGDBjge8G8O9rfpu+3bFXe3n2SpA8XLFbnuzr6fUzn\nTh21aMlSlZSU6MTJk/r0iyzd0+muMr9z565devHlVBWdPi1JWrLsU93WLrq8Tw2AzVhlTs+v4c3O\nnTtr/Pjxatasme+9+Pj4gHXKLq4PDdWYUS9o2IgXVFxcrBsbN1Lq6FHa+o9teumVNM1/f/Ylj5Gk\nfgm9tX//AT3w8AAVlxTrwd7x+q9b25b5nb1ie2jvvgNK/N1jCgoKUrPIJkp5cWRFnC6Aq5hV7r3p\nMC41aXSOpKQkRUZGKiTk7Co/h8OhYcOGXbZxz4lfLnsMUNW1u6VPZXcBKBff530VsLb3Lfv0ij97\nY88e5diTsvlV6blcLr388suB7gsAwKKscnG6X6HXsGFDZWZmqlWrVr4T69ChQ0A7BgBAefMr9EpK\nSrRnzx7t2bPH9x6hBwDwsUah51/ojR079rzt/Pz8gHQGAIBA8iv0Jk2apLlz56q4uFinT5/WTTfd\npGXLlgW6bwAAi7DK6k2/rtNbuXKlVq9erV69emn58uUKO+e2WAAAyOG48lcF8qvSq1evnlwulwoL\nCxUREaHi4uJA9wsAYCFWWb3pV6VXv359zZ8/XzVq1FB6erpOnjwZ6H4BAFDuyqz0Fi9eLElq27at\ngoKCFBUVJcMw1Ldv3wrpHADAIgI0p+f1ejV69Gjt2LFDLpd
Lr7zyiiIiInz7Z82a5Vtj0qlTJz31\n1FOXakrSZUJv9+7dvp+XLVumuLg4GYZhmTIWAFAxApULWVlZ8ng8mjdvnnJzc5WWlqaMjAxJZ58A\ntGTJEn300UdyOp166KGH1LVrV7Vo0eKS7ZUZesnJyb6fc3Nz/br1GAAA5WXjxo3q2PHsTfbbtGmj\nLVu2+PbVr19fb7/9toKCgiSdvab8Ys8QPZdfC1kk60xSAgAqQYAioqCgQG6327cdFBSkkpISBQcH\nq1q1agoNDZVhGBo3bpxatWqlJk2alNme36EHAMClBKowcrvdKiws9G17vV4FB/8rus6cOaORI0eq\nZs2aeumlly7bXpmhN2zYMDkcDhmGoV27dp033Jmenn4l/QcAwG/R0dHKzs5WbGyscnNzFRUV5dtn\nGIYGDx6s3/72t3riiSf8aq/M0EtMTLzozwAAnCdAqzdjYmK0du1aJSYmyjAMpaamaubMmQoPD5fX\n69WGDRvk8Xi0Zs0aSWeLtbZtL/1cUb+ep3eleJ4ergY8Tw9Xi0A+T+/wmlVX/NmwjneXWz8uhzk9\nAIB5Flns6NcdWQAAuBpQ6QEATLPKZW1UegAA26DSAwCYZ5Hn6RF6AADTrDK8SegBAMwj9AAAduGw\nyPAmC1kAALZB6AEAbIPhTQCAeczpAQDsgtWbAAD7IPQAAHbB6k0AAKoYQg8AYBsMbwIAzGNODwBg\nG4QeAMAuuGQBAGAfrN4EAKBqodIDAJjmcFijhrJGLwEAKAdUegAA81jIAgCwC1ZvAgDsg9WbAABU\nLVR6AADTGN4EANiHRUKP4U0AgG1Q6QEAzLPIxemEHgDANJ6cDgBAFUOlBwAwzyILWQg9AIBpXLIA\nALAPiyxksUYvAQAoB1R6AADTWL0JAEAVQ6UHADCPhSwAALtg9SYAwD4ssnqT0AMAmMdCFgAAqhZC\nDwBgGwxvAgBMYyELAMA+WMgCALALKj0AgH1YpNKzRi8BACgHhB4AwDYY3gQAmGaVpywQegAA81jI\nAgCwC4dFFrIQegAA8yxS6TkMwzAquxMAAFQEa9SjAACUA0IPAGAbhB4AwDYIPQCAbRB6AADbIPQA\nALZB6FVB06ZN08CBA/XII48oKSlJW7ZsqewuAVcsJydHzz777HnvjR8/XgsXLqykHsHOuDi9itm1\na5dWrlypuXPnyuFwaNu2bXr++ee1ZMmSyu4aAFgeoVfF1KpVSz/99JPmz5+vu+66Sy1bttT8+fOV\nlJSkJk2a6Mcff5RhGJowYYJCQ0M1atQoHTp0SPn5+brnnnv07LPPasSIEQoODtZPP/0kj8ej2NhY\nZWdn6+DBg3rrrbcUHh5e2acJqLS0VC+88AK/v6hQDG9WMWFhYcrIyNCmTZvUr18/de/eXdnZ2ZKk\n6OhozZkzRz169FBmZqYOHjyoNm3a6J133tH8+fP1wQcf+Npp1KiRZsyYocjISO3fv1/Tp09Xt27d\ntHLlyso6NdjY+vXrlZSU5HstXbpUQUFB/P6iwlHpVTF5eXlyu90aO3asJGnz5s0aNGiQ6tWrp/bt\n20s6G34rV65UnTp1tHnzZq1fv15ut1sej8fXTqtWrSRJISEhioyM9P187jFARWnfvr0mTJjg2x4/\nfrwKCgq0a9cufn9Roaj0qpgdO3YoJSXF95e7SZMmCgkJUVBQkG9By6ZNm9SsWTMtXLhQtWrVUnp6\nuh577DFHTzZXAAAAqElEQVSdPn1a/7yVqsMiN3+FvfH7i4pGpVfFdOvWTbt379YDDzyga6+9VoZh\naPjw4Zo9e7YWLVqkWbNmqUaNGho3bpyOHDmi5ORk5ebmyuVyKSIiQvn5+ZV9CoBfgoKCtGbNGn5/\nUaF4yoJFJCUlafTo0WratGlldwUALIvhTQCAbVDpAQBsg0oPAGAbhB4AwDYIPQCAbRB6AADbIPQA\nALZB6AEAbOP/AXjcuuj
dgSkYAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot heatmap\n", "import seaborn as sns\n", "labels = ['Spam', 'Ham']\n", "gru_cm_avg = np.zeros((2,2))\n", "for cm in gru_cms:\n", " # turn cm into percentages\n", " totals = np.repeat(np.sum(cm, axis=1), 2, axis=0).reshape(2,2)\n", " cm_ = cm / totals / 3\n", " gru_cm_avg = np.sum([gru_cm_avg, cm_], axis=0)\n", " \n", "sns.heatmap(gru_cm_avg, annot=True, xticklabels=labels, yticklabels=labels)\n", "plt.title('Heatmap of GRU')" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAb0AAAFXCAYAAAA28ZCgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl4FFW+xvG3u5OGQIiIxLBIIgRDAL2DER1UEAXCZRdc\nAMUYZAQcFEVQFFHAyBKWCLgBooAyyoBswwAiIsEwYVgGiCyDKAphJyhLSFiS0HX/4Lk9ZjShpdKd\nFPX9+PTzdFVXn5yjkZffqVNVDsMwDAEAYAPO0u4AAACBQugBAGyD0AMA2AahBwCwDUIPAGAbhB4A\nwDYIPZSaevXq6cSJE4X2LVy4UH379jXVbq9evX7Vbmk6cuSIOnTooE6dOmnr1q2FPmvRooW2b99e\n7PfL2ngAKwsq7Q4AJS09Pb20u1DIhg0bVLVqVc2aNeuKvl/WxgNYGaGHMisvL08TJkzQpk2bdPHi\nRTVo0ECvvvqqQkNDlZqaqmnTpikvL08nTpxQ586dNWDAAA0ZMkSSlJiYqPfff189evRQhw4dtGbN\nGp06dUr9+/fXli1btHPnTgUFBWnKlCmKiIgosr0NGzZo3LhxioiI0IEDB1S+fHklJycrOjr6V/2d\nO3euZs+eLafTqapVq+q1117TsWPHNGnSJJ05c0YJCQmaPXv2b441NzdXQ4YMUWZmppxOpxo2bKik\npCQNHTr0isYDoAgGUEpiYmKMDh06GJ06dfK+mjdvbvTp08cwDMN4++23jeTkZMPj8RiGYRgpKSnG\n8OHDDY/HYzz22GPG3r17DcMwjKNHjxr169c3fv75Z2+7///+vvvuM0aPHm0YhmEsW7bMiI2NNXbt\n2mUYhmH069fPmDJlSrHtrV+/3oiNjTU2bdpkGIZhfPrpp0aXLl1+NZZ169YZrVq18v7cBQsWGG3b\ntjU8Ho+xYMEC75j+23333Wds27bNWLRokdGrVy/DMAyjoKDAGDp0qLFv374rGg+AolHpoVR99NFH\nqlKlind74cKF+uKLLyRJa9as0ZkzZ7Ru3TpJUn5+vq677jo5HA5NnTpVa9as0dKlS/XDDz/IMAyd\nO3fuN39G69atJUm1atVS1apVFRsbK0mKjIzU6dOnL9tebGysGjduLEl68MEHlZSUpJMnT+raa6/1\n/oy1a9eq
Xbt23rE88MADGjVqlA4ePOjTv4fbbrtNEydOVEJCgu666y4lJiYqKirqisYDoGiEHsos\nj8ejV155Rc2bN5d0aQrwwoULOnv2rLp06aJWrVqpcePGevDBB7Vq1SoZRdxG1u12e98HBwf/6vPL\ntedyuQodbxjGb+77b4ZhqKCgwKex1qpVS19++aU2bNig9evX64knntCrr76qNm3a/O7xACgaqzdR\nZjVt2lSffPKJ8vLy5PF49Nprr+nNN99UZmamcnJyNGDAALVo0UIbN270HiNdCilfw0bSZdv79ttv\n9e2330q6dN4uLi5OYWFhv+rr8uXLvassFyxYoMqVKxdZrf23Tz/9VEOGDFHTpk314osvqmnTpvr+\n+++vaDwAikalhzKrX79+Gjt2rLp06aKLFy+qfv36evnll1WhQgXde++9atu2rcLCwhQZGam6desq\nMzNTkZGRio+P16OPPqr33nvPp59Tr169Ittzu92qWrWqJk2apEOHDqlKlSoaN27cr9q4++671bNn\nTyUmJsrj8ahKlSqaNm2anE7f/l7ZuXNnbdy4Ue3atVNISIhq1Kihxx9/XJJ+93gAFM1hFDUnBEAb\nNmzQG2+8oaVLl5Z2VwCUAKY3AQC2QaUHALANKj0AgG0QegAA2yD0AAC24ddLFm6Pbu3P5oGAWLth\nVml3ASgR5avW8Fvb/xPV/Iq/uy3z6xLsSfG4Tg8AYJrD4SjtLviE6U0AgG1Q6QEATHM4rFFDWaOX\nAACUACo9AIBpTlnjnB6hBwAwzSoLWQg9AIBpTouc0yP0AACmWaXSs0Y0AwBQAgg9AIBtML0JADDN\nwepNAIBdsJAFAGAbVlnIQugBAExzWiT0rFGPAgBQAgg9AIBtML0JADDNYZEaitADAJjGQhYAgG1Y\nZSELoQcAMM0qF6dbYxIWAIASQOgBAGyD6U0AgGnchgwAYBus3gQA2AarNwEAtsHqTQAAyhgqPQCA\naVZZyGKNXgIAUAKo9AAAprF6EwBgG6zeBADYBqs3AQAoY6j0AACmcU4PAGAbVjmnx/QmAMA2qPQA\nAKZZZSELoQcAMI07sgAAUMZQ6QEATGP1JgDANqyyepPQAwCYZpWFLJzTAwDYBpUeAMA0q0xvUukB\nAGyDSg8AYBqrNwEAtmGV6U1CDwBgmlVWbxJ6AADTrFLpsZAFAGAbhB4AwDaY3gQAmOav1Zsej0cj\nRozQ7t275Xa7NXLkSEVFRXk/nzFjhpYuXSqHw6GnnnpK8fHxxbZH6AEATPPXOb1Vq1YpLy9Pc+fO\nVUZGhpKTkzVlyhRJUnZ2tj7++GOtXLlS586dU+fOnQk9AID/+Wv15ubNm9WsWTNJUqNGjbRjxw7v\nZyEhIapRo4bOnTunc+fO+VRtEnoAANP8Venl5OQoNDTUu+1yuVRQUKCgoEvxVb16dbVv314XL15U\n3759L99Pv/QSAIASEBoaqtzcXO+2x+PxBl5aWpqysrL01Vdfac2aNVq1apW2bdtWbHuEHgCgzIqL\ni1NaWpokKSMjQzExMd7PrrnmGpUvX15ut1vlypVTpUqVlJ2dXWx7TG8CAEzz1+rN+Ph4paenq3v3\n7jIMQ6NHj9bMmTMVGRmpli1bat26derataucTqfi4uJ09913F99PwzAMv/RU0u3Rrf3VNBAwazfM\nKu0uACWifNUafmv7qWbPXvF3p659qwR7UjwqPQCAaTxlAQBgG1a54TQLWQAAtkGlBwAwzWmNQo9K\nDwBgH1R6AADTWMgCALANqzxEltADAJhmlUqPc3oAANug0gMAmOa0yHV6hF4Zdfe9d+jpF3vJ7Q7W\n99/u1cghbyo352yhY7o+fr+6JnTShfN52vvDfo0b/o6yT5+RJK3cOE/Hj/3sPXb29M+0YsnqgI4B\n9pW27p96a+oHysvLV0zdOhox5EWFVqz4u495fsgwhVe9Tq8Mek6StPv7PR
qdMklncnIVWrGinu7T\nS3+8LS5g40LRmN7EFatc5RoNG/eCXno6SQ/F/0mHDhzRMy/+qdAxtzX5gx7v01X9El5Sj45/Vvqa\njXpl1ABJUlTtG5SdnaMeHf/sfRF4CJQTJ09p2KhxShn1upb89WPVrFFdk6e8/7uPmfnJHG39r8fE\nDHj5VXXp2F4L/zJTb45O0qjxk/TTzyf8PiZcPQi9MqhJ09v07227dWDfYUnSgk+Wqs39LQodU//m\nm7Rp3VZlHf1JkpT6RbqatfijgoKD9D9xDeS56NGUT8bp02VT9eQzPeR08p8agfHPjZt0c/16iqp1\ngySpa5f7tXzlV/rlve0vd8zGzVuVvn6THrq/k/c7J0+d1tGs4+rY5tKN7KteV0Uxdesoff3GQA0N\nxXA6HFf8Cmg/A/rT4JOI6uE6duS4dzvr6HGFVqqoiqEVvPt2frNbje9spGo1rpckdXyotdzl3Lqm\ncphcQS5tSN+iZ58Yqj7dB6nJPY3V7fH7Az4O2NPRrOOKuP5673ZEeLhycnOVe/asT8dkHf9J4ya/\nozHDh8rl+s8fUddWvkY1q1fTks+/kCQdPHRYW77Zpp9+/s80PkqPw3Hlr0Dy6ZzegQMHlJqaqgsX\nLnj39e7d22+dsjtHEffzuXjR432/ddN2TX/rLxo/Zbg8hqG/f/aFTp3MVkF+vhbP/dx7XH5evj75\ncIG6J3bWnFmL/N53wPB4fnP/L2cbijrGMAy9NPwNvfjc0wqvet2vPp88dpTefGeq/jJ3vmLqRqvZ\nXU0UFBxcMh2HLfgUev369VPr1q0VFhbm7/5A0rHDx3XzH2K92+ERVXX6VLbOnzvv3VehYoi2bNym\nJZ+tkCRVua6y+j6fqNOnzqht55b6fteP2rN7r6RLJ5gLCgoCOwjYVrVqEdr+713e7ayfjiusUiVV\nCAm57DE/7svUoSNHlPLWe5Kkn06ckMfjUV5enkYMeVEew9DksaMUFOSSJPUb9JLubXpXgEaG4ljl\n4nSfpjerV6+u/v37KzEx0fuC/6z/x2bdfGt91brx0gMfH3y0g9JW/bPQMeHXX6epn4z3Tnn+6Zke\nWvn3VElSdMyN6vt8opxOp8qVc6trQid9uezrwA4CtnXnHY21becuZR44KEn6bNHfdW+zu3065g83\nN9TKRfM076MPNO+jD/Rw505q3eI+jRjyoiTpjbEpSl37D0lSxvYd2vPjXv2x8W0BHB2K4jDxT0D7\n6cuT0+fMmaNDhw6pbt263n2dO3e+bOM8Of3K3XXv7Xr6hV4KDg7Wwf2HNeKF8aoZWU2vjh6oHh3/\nLEl6OKGTHn6sk5xOhzL+tVPjR7yjCxfyVK58OQ0e8bRublRfQcFB+mp5mt5LmVnKI7Iunpz++61d\nt15vTZuu/PwC3VCzhka9NkQHDx3R68njNe+jD4o85pr/mk2a8uEsnTx12nvJwvc/7tXrY8br3Pnz\nqhASoiEDn1WD2HoBH59V+fPJ6a/875Ar/u7oL8aUYE+K51PoJSQkqE6dOt7pTYfDoYEDB162cUIP\nVwNCD1cLQs/Hc3put1uvv/66v/sCALAoq5zT8yn0atSooWnTpqlBgwbeq+6bNm3q144BAKzDIpnn\nW+gVFBRo37592rdvn3cfoQcAsBqfQm/MmMLzrVlZWX7pDADAmq6q6c3Jkydrzpw5ys/P1/nz53Xj\njTdq2bJl/u4bAMAiAn3pwZXy6Tq91atXKy0tTR07dtTy5csVERHh734BACzkqrr3Znh4uNxut3Jz\ncxUVFaX8/Hx/9wsAgBLn0/RmtWrVNH/+fIWEhCglJUXZ2dn+7hcAwEIsckrPt9BLSkrS0aNH1aZN\nGy1atEgpKSn+7hcAACXOp9A7efKkZsyYoX379ummm25SeHi4v/sFALCQq+rJ6QMGDFCdOnX0wgsv\n6IYbbtDgwYP93S8AgIVYZSGLT5WeJD
366KOSpNjYWK1YscJvHQIAWI9FCj3fQq9OnTr629/+piZN\nmmjnzp2qXLmy9u699Ky22rVr+7WDAICy76q6OP3HH3/Ujz/+qBkzZsjlcqlixYoaNmyYHA6HPv74\nY3/3EQCAElHsOb2dO3eqc+fO+vDDD5WQkKCsrCzl5uaqZ8+emj17NoEHALCUYkNv3LhxSk5Oltvt\n1qRJk/TBBx9owYIFmj59eqD6BwCwAKs8Ob3Y6U2Px6PY2FgdO3ZM586dU8OGDSVZZ2kqACAwrJIL\nxYZeUNClj9euXas777xTkpSfn6+zZ8/6v2cAAMtwWiPzig+9O++8U927d9fRo0c1ZcoU7d+/X0lJ\nSWrXrl2g+gcAsICrotLr06ePWrZsqdDQUEVERGj//v3q1q2b4uPjA9U/AABKzGUvWYiOjva+j4yM\nVGRkpF87BACAv/h8RxYAAIpyVUxvAgDgi6tiIQsAAL6g0gMA2IZFMs+3RwsBAHA1oNIDAJhmlacs\nUOkBAGyDSg8AYFqgbxx9pQg9AIBpFpndJPQAAOZxTg8AgDKGSg8AYBoXpwMAbMMimcf0JgDAPqj0\nAACmMb0JALANfz1lwePxaMSIEdq9e7fcbrdGjhypqKgo7+dff/213n33XRmGoYYNG2r48OHFBjDT\nmwCAMmvVqlXKy8vT3LlzNWjQICUnJ3s/y8nJ0fjx4zV16lR99tlnqlmzpk6ePFlse1R6AADT/DW9\nuXnzZjVr1kyS1KhRI+3YscP72datWxUTE6OxY8fqwIEDevjhh1WlSpVi2yP0AACm+euUXk5OjkJD\nQ73bLpdLBQUFCgoK0smTJ7VhwwYtXrxYFSpUUI8ePdSoUSPVrl27yPaY3gQAmOZ0OK74VZzQ0FDl\n5uZ6tz0ej4KCLtVrlStX1i233KLw8HBVrFhRjRs31q5du4rvp/mhAgDgH3FxcUpLS5MkZWRkKCYm\nxvtZw4YN9d133+nEiRMqKCjQN998o7p16xbbHtObAADT/HVOLz4+Xunp6erevbsMw9Do0aM1c+ZM\nRUZGqmXLlho0aJCefPJJSVKbNm0KheJvIfQAAGWW0+lUUlJSoX3R0dHe9+3bt1f79u19bo/QAwCY\nZpFr0wk9AIB53JEFAGAbFsk8Qg8AYB4PkQUAoIwh9AAAtsH0JgDANIvMbhJ6AADzWL0JALANi2Qe\noQcAMM8qlR4LWQAAtkHoAQBsg+lNAIBpFpndJPQAAOZZ5Y4shB4AwDSLZB6hBwAwj9WbAACUMVR6\nAADTLFLoUekBAOyDSg8AYJpVzukRegAA0yySeYQeAMA8q1R6nNMDANgGlR4AwDSLFHqEHgDAPKY3\nAQAoY6j0AACmWaTQ82/opW+d48/mgYBofMsDpd0FoERsy/zab23zlAUAgG1YJPM4pwcAsA8qPQCA\naVZZvUnoAQBMs0jmMb0JALAPKj0AgGkOpzVKPUIPAGAa05sAAJQxVHoAANNYvQkAsA2LZB6hBwAw\nzyqVHuf0AAC2QaUHADDNIoUelR4AwD6o9AAA5lmk1CP0AACmWWUhC6EHADDNIplH6AEAzLPKvTdZ\nyAIAsA1CDwBgG0xvAgBM45weAMA2WL0JALANi2QeoQcAMM8qlR4LWQAAtkHoAQBsg+lNAIBpFpnd\npNIDAJjncDiu+FUcj8ejYcOGqVu3bkpISFBmZuZvHvPkk09qzpw5l+0noQcAMM9p4lWMVatWKS8v\nT3PnztWgQYOUnJz8q2MmTZqk7Oxsn7rJ9CYAwDR/rd7cvHmzmjVrJklq1KiRduzYUejzFStWyOFw\neI+5HCo9AECZlZOTo9DQUO+2y+VSQUGBJOm7777T0qVL9dxzz/ncHpUeAKDMCg0NVW5urnfb4/Eo\nKOhSdC1evFjHjh1TYmKiDh06pODgYNWsWVP33HNPke0RegAA0/y1ejMuLk6pqalq166dMjIyFBMT\n4/
1s8ODB3vdvv/22qlatWmzgSYQeAKAE+OucXnx8vNLT09W9e3cZhqHRo0dr5syZioyMVMuWLX93\ne4QeAMA0f1V6TqdTSUlJhfZFR0f/6rj+/fv71B6hBwAwzyJXp7N6EwBgG1R6AADTHE4qPQAAyhQq\nPQCAaRY5pUfoAQDMs8pDZAk9AIBpFsk8zukBAOyDSg8AYJ5FSj1CDwBgGpcsAABQxlDpAQBMs8js\nJqEHACgBFkk9pjcBALZBpQcAMM0ihR6hBwAwzyqrNwk9AIBpVrkNGef0AAC2QaUHADDPGoUelR4A\nwD6o9AAAplnlnB6hBwAwjdADANiHRU6WEXoAANOsUulZJJsBADCP0AMA2AbTmwAA06wyvUnoAQDM\ns0bmEXoAAPO44TQAwD4sMr3JQhYAgG0QegAA22B6EwBgmkVmNwm9siTtH+ma9O5U5efl66abopX0\n6isKDa3o0zHnz1/QqHETtOPfu2R4DN1ycwMNHfyCypcv5/3u6exsdUvopYHP9lPrli0CPTzYVLMW\nTfTc4D5yu4P13bc/avjgscrNOVvomEd6PqBHHu+i8+cvaO+e/Rr12kRlnz4jp9OpV5IG6LYmf5Ak\n/SN1vVJGTSmNYeAyrHLJAtObZcSJkyf1WtIoTRw7Wn9f8FfdULOGJr3zns/HvD9zli5evKgFn36s\nBXM+1oULF/TBrI+93zUMQ0OHv6Gc3JyAjgv2dm2Va/TG+Jc18KnX1KlFgg7uP6wBL/ctdMztd96q\nXk89ot6PDlTXdk9qbep6DU9+QZLU4YHWujG6lh5s/YQebtNLt/2xkeLb3VsKI8FlOR1X/gpkNwP6\n01Ckdes3qmGD+oqKrCVJ6vbgA1q2YqUMw/DpmMa3NlKfXj3ldDrlcrkUWy9GR44e9X532oezFHNT\ntG6Kjg7swGBrd95zu3Zs+1b79x2SJM37y9/U7v5WhY5pcEuM1v9js44dPS5J+mpFmpq3vEtBwUFy\nuZwKCSkvtztYwW63goODlHchL+DjwOU5HI4rfgUSoVdGHD12TNUiIrzbEdeHKyc3V7m5Z3065q4m\nf9SNUZGSpMNHjugvc+Z5pzDXrd+gf23Zqqf79g7QaIBLqlW/XkcPZ3m3jx05rkphoaoYWsG7b0fG\nLt1xV5yq17z0u31/17Zyl3Or8rVh+ttnK5R9+oxWbVyg1ZsW6kDmIX391bqAjwNXD5/O6R04cECp\nqam6cOGCd1/v3vwBWpJ+WdH9ktPl/F3H7Nz1rQa8OESPdH1QzZvdrSNHj2r8pLc1/Z3JcrlcJdtp\n4DKczt/+e7Xnosf7fvPGbZo6eZYmvT9SHo9Hi+d9rlMnTys/r0BPDeipEydO697bOqt8+XKaNH2U\nHu/dVR9PnxeoIcBX1jil51ul169fP50+fVput9v7QsmqFhGh4z/95N3OOn5cYWGVVCEkxOdjPl/5\npfo885wGPPNn9X4iUZK0ctVqnT9/Xk89O1APPZqonbu+1Ztvvat5CxYFaGSwsyOHjyn8+uu829dX\nq6rTp7J17tx5774KFUP0r/XfqFv73nqkY199+fnXkqTTp7LVqk0zLZ63XAX5Bco5k6sl81fo9jtv\nDfg4cPXwqdKrXr26+vfv7+++2NpdTe7QhMlvK3P/AUVF1tK8BYt13z3NfD5m5VerlTxhot5/e5Ia\nNqjv/U7iY48q8bFHvdtP9H1aj3R9kNWbCIh/pm3SC0P7KfLGmtq/75Ae7tFJqSvTCx1zfURVTf/0\nTXVulajcnLPq++zj+nzJV5KkXTu+1/+2v0+b/rlVQUEu3Rt/t7Zt/XdpDAWXYZXVmw6jqDmzX5gz\nZ44OHTqkunXrevd17tz5so3nZf9srnc2k5a+TpPfnar8/HzVuqGmRo8YpoOHDmn4yGTN//SjIo+5\n5powtX+gq86cydH14eHe9hr94Ra9+tILhX4Goff7Nb7lgdLugqU1
ve+Pem5wHwW7g3Ug85CGPj9a\nN0TW0IixL6pruyclSd0Tu6j7413kdDi05V/bNea1SbpwIU/XVA7TkKTnVL/hTbro8WhD+haljHxX\nBQUXS3lU1rQt82u/tX1g6fIr/m6tDu1KsCfF8yn0EhISVKdOHYWFhV36ksOhgQMHXrZxQg9XA0IP\nVwu/ht6yz6/4u7Xaty3BnhTPp+lNt9ut119/3d99AQBYlFWmN30KvRo1amjatGlq0KCBd2BNmzb1\na8cAAChpPoVeQUGB9u3bp3379nn3EXoAAC9rFHq+hd6YMWMKbWdlZRVxJAAAZZdPoTd58mTNmTNH\n+fn5On/+vG688UYtW7bM330DAFiEVZ6c7tPF6atXr1ZaWpo6duyo5cuXK+IXt8ICAEAOx5W/Asin\nSi88PFxut1u5ubmKiopSfn6+v/sFALAQq6ze9KnSq1atmubPn6+QkBClpKTozJkz/u4XAAAlrthK\nb/HixZKkW2+9VS6XSzExMTIMQ127dg1I5wAAFmGRc3rFht4PP/zgfb9s2TJ16NBBhmFYpowFAASG\nVXKh2NAbNGiQ931GRoZPtx4DAKCs8mkhi2SdFAcAlAKLRITPoQcAQFH8VRh5PB6NGDFCu3fvltvt\n1siRIxUVFeX9fNasWd7rxps3b65nnnmm2PaKDb2BAwfK4XDIMAzt2bOn0HRnSkqKmXEAAHBZq1at\nUl5enubOnauMjAwlJydrypQpkqQDBw5oyZIl+uyzz+R0OvXII4+oVatWio2NLbK9YkOve/fuv/ke\nAIBC/LR6c/PmzWrW7NLDshs1aqQdO3Z4P6tWrZo++OADuVwuSZfuE12uXLli2ys29O644w6z/QUA\n2IC/pjdzcnIUGhrq3Xa5XCooKFBQUJCCg4NVpUoVGYahcePGqUGDBqpdu3ax7XFODwBgnp9CLzQ0\nVLm5ud5tj8ejoKD/RNeFCxf0yiuvqGLFiho+fPhl2/PpjiwAAJSGuLg4paWlSbp06VxMTIz3M8Mw\n1K9fP9WrV09JSUneac7iUOkBAEzz1/RmfHy80tPT1b17dxmGodGjR2vmzJmKjIyUx+PRxo0blZeX\np7Vr10q6tADz1ltvLbI9Qg8AUGY5nU4lJSUV2hcdHe19v3379t/VHqEHADDvarj3JgAAvrDKXbsI\nPQCAeYQeAMAuHBaZ3uSSBQCAbRB6AADbYHoTAGAe5/QAAHbB6k0AgH0QegAAu2D1JgAAZQyhBwCw\nDaY3AQDmcU4PAGAbhB4AwC64ZAEAYB+s3gQAoGyh0gMAmOZwWKOGskYvAQAoAVR6AADzWMgCALAL\nVm8CAOyD1ZsAAJQtVHoAANOY3gQA2IdFQo/pTQCAbVDpAQDMs8jF6YQeAMA0npwOAEAZQ6UHADDP\nIgtZCD0AgGlcsgAAsA+LLGSxRi8BACgBVHoAANNYvQkAQBlDpQcAMI+FLAAAu2D1JgDAPiyyepPQ\nAwCYx0IWAADKFkIPAGAbTG8CAExjIQsAwD5YyAIAsAsqPQCAfVik0rNGLwEAKAGEHgDANpjeBACY\nZpWnLBB6AADzWMgCALALh0UWshB6AADzLFLpOQzDMEq7EwAABII16lEAAEoAoQcAsA1CDwBgG4Qe\nAMA2CD0AgG0QegAA2yD0yqD3339fPXv21GOPPaaEhATt2LGjtLsEXLENGzbo+eefL7RvwoQJWrhw\nYSn1CHbGxellzJ49e7R69WrNmTNHDodDu3bt0ksvvaQlS5aUdtcAwPIIvTKmUqVKOnz4sObPn697\n7rlH9evX1/z585WQkKDatWtr7969MgxDEydOVJUqVTRs2DAdPXpUWVlZatGihZ5//nm9/PLLCgoK\n0uHDh5WXl6d27dopNTVVR44c0XvvvafIyMjSHiagixcvaujQofz+IqCY3ixjIiIiNGXKFG3ZskXd\nunVTmzZtlJqaKkmKi4vT7Nmz
1bZtW02bNk1HjhxRo0aN9OGHH2r+/Pn661//6m2nZs2amjFjhurU\nqaODBw9q+vTpat26tVavXl1aQ4ONrV+/XgkJCd7X0qVL5XK5+P1FwFHplTGZmZkKDQ3VmDFjJEnb\nt29X794dyjCVAAABJklEQVS9FR4eriZNmki6FH6rV69W5cqVtX37dq1fv16hoaHKy8vzttOgQQNJ\nUlhYmOrUqeN9/8tjgEBp0qSJJk6c6N2eMGGCcnJytGfPHn5/EVBUemXM7t27lZSU5P2fu3bt2goL\nC5PL5fIuaNmyZYvq1q2rhQsXqlKlSkpJSVGvXr10/vx5/f+tVB0Wufkr7I3fXwQalV4Z07p1a/3w\nww966KGHVKFCBRmGocGDB+ujjz7SokWLNGvWLIWEhGjcuHH66aefNGjQIGVkZMjtdisqKkpZWVml\nPQTAJy6XS2vXruX3FwHFUxYsIiEhQSNGjFB0dHRpdwUALIvpTQCAbVDpAQBsg0oPAGAbhB4AwDYI\nPQCAbRB6AADbIPQAALZB6AEAbOP/AEVQACDza3ttAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot heatmap\n", "lstm_cm_avg = np.zeros((2,2))\n", "for cm in lstm_cms:\n", " # turn cm into percentages\n", " totals = np.repeat(np.sum(cm, axis=1), 2, axis=0).reshape(2,2)\n", " cm_ = cm / totals / 3\n", " lstm_cm_avg = np.sum([lstm_cm_avg, cm_], axis=0)\n", " \n", "sns.heatmap(lstm_cm_avg, annot=True, xticklabels=labels, yticklabels=labels)\n", "plt.title('Heatmap of lstm')" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "From the heatmaps we can see that ham gets classified perfectly using both models, however our GRU model scores much better than the LSTM when classifying spam instances." ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "## NLTK tokenize vs keras tokenizer" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "We thought it could be interesting to compare the generalized NLTK tokenizer to the keras tokenizer. We decided to compare them using basic LSTM networks." 
] }, { "cell_type": "code", "execution_count": 27, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "from nltk.tokenize import word_tokenize\n", "X_nltk = [word_tokenize(x) for x in X]" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "encoder = {}\n", "counter = 0\n", "def encode_sentence(seq):\n", " global encoder, counter\n", " fseq = []\n", " for x in seq:\n", " if x not in encoder:\n", " encoder[x] = counter\n", " counter+=1\n", " fseq.append(encoder[x])\n", " return fseq\n", "\n", "X_nltk = [encode_sentence(x) for x in X]\n", "X_nltk = pad_sequences(X_nltk, maxlen=None)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "embedding_layer = Embedding(len(word_index) + 1,\n", " EMBED_SIZE,\n", " weights=[embedding_matrix],\n", " input_length=len(X_nltk[0]),\n", " trainable=False)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "collapsed": false, "deletable": true, "editable": true, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "embedding_2 (Embedding) (None, 910, 100) 900800 \n", "_________________________________________________________________\n", "lstm_13 (LSTM) (None, 100) 80400 \n", "_________________________________________________________________\n", "dense_25 (Dense) (None, 2) 202 \n", "=================================================================\n", "Total params: 981,402\n", "Trainable params: 80,602\n", "Non-trainable params: 900,800\n", "_________________________________________________________________\n", "None\n" ] } ], "source": [ "rnn = Sequential()\n", 
"rnn.add(embedding_layer)\n", "rnn.add(LSTM(100,dropout=0.2, recurrent_dropout=0.2))\n", "rnn.add(Dense(NUM_CLASSES, activation='sigmoid'))\n", "rnn.compile(loss='categorical_crossentropy', \n", " optimizer='rmsprop', \n", " metrics=metrics)\n", "print(rnn.summary())" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "X_train, X_test, y_train_ohe, y_test_ohe = train_test_split(X_nltk, y_ohe, test_size=0.2,\n", " stratify=y_ohe, \n", " random_state=42)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 95s - loss: 0.3282 - precision: 0.8947 - acc: 0.8767 - val_loss: 0.1758 - val_precision: 0.9685 - val_acc: 0.9408\n", "Epoch 2/3\n", "4459/4459 [==============================] - 93s - loss: 0.2239 - precision: 0.9454 - acc: 0.9206 - val_loss: 0.2723 - val_precision: 0.9344 - val_acc: 0.9076\n", "Epoch 3/3\n", "4459/4459 [==============================] - 93s - loss: 0.1768 - precision: 0.9608 - acc: 0.9477 - val_loss: 0.2153 - val_precision: 0.9479 - val_acc: 0.9471\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rnn.fit(X_train, y_train_ohe, validation_data=(X_test, y_test_ohe), epochs=3, batch_size=64)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true, "deletable": true, "editable": true }, "source": [ "# KerasGlove Published to PyPi\n", "I really liked being able to easily use glove embeddings with keras so I published a package to PyPi for it. It's available under kerasglove and removes the need for a lot of the code in the notebook. 
Here is a sample usage:" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from kerasglove import GloveEmbedding\n", "EMBED_SIZE=100\n", "metrics = ['accuracy',precision]\n", "\n", "embed_layer = GloveEmbedding(\n", " EMBED_SIZE,\n", " MAX_TEXT_LEN,\n", " word_index)\n", "embed_layer" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "embedding_4 (Embedding) (None, 189, 100) 900800 \n", "_________________________________________________________________\n", "gru_12 (GRU) (None, 100) 60300 \n", "_________________________________________________________________\n", "dense_26 (Dense) (None, 2) 202 \n", "=================================================================\n", "Total params: 961,302\n", "Trainable params: 60,502\n", "Non-trainable params: 900,800\n", "_________________________________________________________________\n", "None\n" ] } ], "source": [ "from keras.models import Sequential\n", "from keras.layers import Dense\n", "from keras.layers import GRU\n", "from kerasglove import GloveEmbedding\n", "\n", "rnn = Sequential()\n", "rnn.add(GloveEmbedding(EMBED_SIZE,\n", " MAX_TEXT_LEN,\n", " word_index))\n", "rnn.add(GRU(100,dropout=0.2, recurrent_dropout=0.2))\n", "rnn.add(Dense(NUM_CLASSES, activation='sigmoid'))\n", "rnn.compile(loss='categorical_crossentropy', \n", " optimizer='rmsprop', \n", " metrics=metrics)\n", "print(rnn.summary())" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "collapsed": true, 
"deletable": true, "editable": true }, "outputs": [], "source": [ "X_train, X_test, y_train_ohe, y_test_ohe = train_test_split(sequences, y_ohe, test_size=0.2,\n", " stratify=y_ohe, \n", " random_state=42)" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 4459 samples, validate on 1115 samples\n", "Epoch 1/3\n", "4459/4459 [==============================] - 21s - loss: 0.3050 - acc: 0.8872 - precision: 0.8751 - val_loss: 0.2898 - val_acc: 0.8897 - val_precision: 0.9084\n", "Epoch 2/3\n", "4459/4459 [==============================] - 19s - loss: 0.2419 - acc: 0.8962 - precision: 0.8936 - val_loss: 0.2526 - val_acc: 0.8933 - val_precision: 0.8888\n", "Epoch 3/3\n", "4459/4459 [==============================] - 19s - loss: 0.2360 - acc: 0.9002 - precision: 0.8948 - val_loss: 0.2538 - val_acc: 0.9013 - val_precision: 0.9122\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rnn.fit(X_train, y_train_ohe, validation_data=(X_test, y_test_ohe), epochs=3, batch_size=64)" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "As we can see, this is far easier to construct a network with a pre trained GloVe emebedding than doing it manually.\n", "\n", "The full source is here:\n", "https://github.com/LukeWoodSMU/KerasGlove" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.1" } }, "nbformat": 4, "nbformat_minor": 2 }