{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Sentiment Analysis - Text Classification with Universal Embeddings\n", "\n", "Textual data, despite being highly unstructured, can be classified into two major types of documents.\n", "- __Factual documents__, which typically state facts with no specific feelings or emotions attached to them. These are also known as objective documents.\n", "- __Subjective documents__, which contain text that expresses feelings, moods, emotions and opinions.\n", "\n", "Sentiment analysis is also popularly known as opinion analysis or opinion mining. The key idea is to use techniques from text analytics, NLP, machine learning and linguistics to extract important information or data points from unstructured text. This in turn can help us derive the sentiment expressed in text data.\n", "\n", "![](sentiment_cover.png)\n", "\n", "Here we will build supervised sentiment analysis classification models, taking advantage of labeled data. The dataset we will be working with is the IMDB Large Movie Review Dataset, which has 50,000 reviews labeled with positive or negative sentiment. 
I have provided a compressed version of the dataset in this repository for your convenience.\n", "\n", "Keep in mind that the focus here is not sentiment analysis as such, but text classification using universal sentence embeddings.\n", "\n", "![](sample_classification.png)\n", "\n", "For demonstration, we will use the following sentence encoders from [TensorFlow Hub](https://tfhub.dev/):\n", "\n", "- [__Neural-Net Language Model (nnlm-en-dim128)__](https://tfhub.dev/google/nnlm-en-dim128/1)\n", "- [__Universal Sentence Encoder (universal-sentence-encoder)__](https://tfhub.dev/google/universal-sentence-encoder/2)\n", "\n", "\n", "_Developed by [Dipanjan (DJ) Sarkar](https://www.linkedin.com/in/dipanzan/)_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Install TensorFlow Hub" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting tensorflow-hub\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/5f/22/64f246ef80e64b1a13b2f463cefa44f397a51c49a303294f5f3d04ac39ac/tensorflow_hub-0.1.1-py2.py3-none-any.whl (52kB)\n", "\u001b[K 100% |################################| 61kB 8.5MB/s ta 0:00:011\n", "\u001b[?25hRequirement already satisfied: numpy>=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub) (1.14.3)\n", "Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub) (1.11.0)\n", "Requirement already satisfied: protobuf>=3.4.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub) (3.5.2.post1)\n", "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.4.0->tensorflow-hub) (39.1.0)\n", "Installing collected packages: tensorflow-hub\n", "Successfully installed tensorflow-hub-0.1.1\n", "\u001b[33mYou are using pip version 10.0.1, however version 18.1 is available.\n", "You should consider upgrading via the 'pip install 
--upgrade pip' command.\u001b[0m\n" ] } ], "source": [ "!pip install tensorflow-hub" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Load up Dependencies" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow_hub as hub\n", "import numpy as np\n", "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Check if GPU is available for use!" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.test.is_gpu_available()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/device:GPU:0'" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.test.gpu_device_name()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Load and View Dataset" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<class 'pandas.core.frame.DataFrame'>\n", "RangeIndex: 50000 entries, 0 to 49999\n", "Data columns (total 2 columns):\n", "review 50000 non-null object\n", "sentiment 50000 non-null object\n", "dtypes: object(2)\n", "memory usage: 781.3+ KB\n" ] } ], "source": [ "dataset = pd.read_csv('movie_reviews.csv.bz2', compression='bz2')\n", "dataset.info()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table>\n", "<thead><tr><th></th><th>review</th><th>sentiment</th></tr></thead>\n", "<tbody>\n", "<tr><th>0</th><td>One of the other reviewers has mentioned that ...</td><td>1</td></tr>\n", "<tr><th>1</th><td>A wonderful little production. &lt;br /&gt;&lt;br /&gt;The...</td><td>1</td></tr>\n", "<tr><th>2</th><td>I thought this was a wonderful way to spend ti...</td><td>1</td></tr>\n", "<tr><th>3</th><td>Basically there's a family where a little boy ...</td><td>0</td></tr>\n", "<tr><th>4</th><td>Petter Mattei's \"Love in the Time of Money\" is...</td><td>1</td></tr>\n", "</tbody>\n", "</table>" ], "text/plain": [ " review sentiment\n", "0 One of the other reviewers has mentioned that ... 1\n", "1 A wonderful little production. <br /><br />The... 1\n", "2 I thought this was a wonderful way to spend ti... 1\n", "3 Basically there's a family where a little boy ... 0\n", "4 Petter Mattei's \"Love in the Time of Money\" is... 1" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset['sentiment'] = [1 if sentiment == 'positive' else 0 for sentiment in dataset['sentiment'].values]\n", "dataset.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Build train, validation and test datasets" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((30000,), (5000,), (15000,))" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reviews = dataset['review'].values\n", "sentiments = dataset['sentiment'].values\n", "\n", "train_reviews = reviews[:30000]\n", "train_sentiments = sentiments[:30000]\n", "\n", "val_reviews = reviews[30000:35000]\n", "val_sentiments = sentiments[30000:35000]\n", "\n", "test_reviews = reviews[35000:]\n", "test_sentiments = sentiments[35000:]\n", "train_reviews.shape, val_reviews.shape, test_reviews.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Basic Text Wrangling" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: contractions in /usr/local/lib/python3.6/dist-packages (0.0.17)\n", "\u001b[33mYou are using pip version 10.0.1, however version 18.1 is available.\n", "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n", "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.6/dist-packages (4.6.3)\n", "\u001b[33mYou are using pip version 10.0.1, however version 18.1 is available.\n", "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n" ] } ], "source": [ "!pip install contractions\n", "!pip install 
beautifulsoup4" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "import contractions\n", "from bs4 import BeautifulSoup\n", "import unicodedata\n", "import re\n", "\n", "\n", "def strip_html_tags(text):\n", "    soup = BeautifulSoup(text, \"html.parser\")\n", "    # drop embedded iframes and scripts before extracting the visible text\n", "    for s in soup(['iframe', 'script']):\n", "        s.extract()\n", "    stripped_text = soup.get_text()\n", "    # collapse runs of CR/LF characters into a single newline\n", "    stripped_text = re.sub(r'[\\r\\n]+', '\\n', stripped_text)\n", "    return stripped_text\n", "\n", "\n", "def remove_accented_chars(text):\n", "    # decompose accented characters, then drop the non-ASCII combining marks\n", "    text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('utf-8', 'ignore')\n", "    return text\n", "\n", "\n", "def expand_contractions(text):\n", "    return contractions.fix(text)\n", "\n", "\n", "def remove_special_characters(text, remove_digits=False):\n", "    pattern = r'[^a-zA-Z0-9\\s]' if not remove_digits else r'[^a-zA-Z\\s]'\n", "    text = re.sub(pattern, '', text)\n", "    return text\n", "\n", "\n", "def pre_process_document(document):\n", "\n", "    # strip HTML\n", "    document = strip_html_tags(document)\n", "\n", "    # lower case\n", "    document = document.lower()\n", "\n", "    # replace newlines, tabs and carriage returns with spaces\n", "    # (maketrans needs both strings to have the same length)\n", "    document = document.translate(document.maketrans(\"\\n\\t\\r\", \"   \"))\n", "\n", "    # remove accented characters\n", "    document = remove_accented_chars(document)\n", "\n", "    # expand contractions\n", "    document = expand_contractions(document)\n", "\n", "    # remove special characters and\\or digits\n", "    # insert spaces between special characters to isolate them\n", "    special_char_pattern = re.compile(r'([{.(-)!}])')\n", "    document = special_char_pattern.sub(\" \\\\1 \", document)\n", "    document = remove_special_characters(document, remove_digits=True)\n", "\n", "    # remove extra whitespace\n", "    document = re.sub(' +', ' ', document)\n", "    document = document.strip()\n", "\n", "    return document\n", "\n", "\n", "pre_process_corpus = 
np.vectorize(pre_process_document)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "train_reviews = pre_process_corpus(train_reviews)\n", "val_reviews = pre_process_corpus(val_reviews)\n", "test_reviews = pre_process_corpus(test_reviews)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Build Data Ingestion Functions" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Training input on the whole training set with no limit on training epochs.\n", "train_input_fn = tf.estimator.inputs.numpy_input_fn(\n", " {'sentence': train_reviews}, train_sentiments, \n", " batch_size=256, num_epochs=None, shuffle=True)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "# Prediction on the whole training set.\n", "predict_train_input_fn = tf.estimator.inputs.numpy_input_fn(\n", " {'sentence': train_reviews}, train_sentiments, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# Prediction on the whole validation set.\n", "predict_val_input_fn = tf.estimator.inputs.numpy_input_fn(\n", " {'sentence': val_reviews}, val_sentiments, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# Prediction on the test set.\n", "predict_test_input_fn = tf.estimator.inputs.numpy_input_fn(\n", " {'sentence': test_reviews}, test_sentiments, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Build Deep Learning Model with Universal Sentence Encoder" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using /tmp/tfhub_modules to cache modules.\n" ] } ], "source": [ "embedding_feature = hub.text_embedding_column(\n", " key='sentence', \n", " 
module_spec=\"https://tfhub.dev/google/universal-sentence-encoder/2\",\n", " trainable=False)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using default config.\n", "WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpn9bphscn\n", "INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpn9bphscn', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': , '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n" ] } ], "source": [ "dnn = tf.estimator.DNNClassifier(\n", " hidden_units=[512, 128],\n", " feature_columns=[embedding_feature],\n", " n_classes=2,\n", " activation_fn=tf.nn.relu,\n", " dropout=0.1,\n", " optimizer=tf.train.AdagradOptimizer(learning_rate=0.005))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train for approx 12 epochs" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "12.8" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "256*1500 / 30000" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Training" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 0\n", "Train Time (s): 78.62789511680603\n", "Eval Metrics (Train): {'accuracy': 0.84863335, 'accuracy_baseline': 0.5005, 'auc': 0.9279859, 'auc_precision_recall': 
0.92819566, 'average_loss': 0.34581015, 'label/mean': 0.5005, 'loss': 44.145977, 'precision': 0.86890674, 'prediction/mean': 0.47957155, 'recall': 0.8215118, 'global_step': 100}\n", "Eval Metrics (Validation): {'accuracy': 0.8454, 'accuracy_baseline': 0.505, 'auc': 0.92413086, 'auc_precision_recall': 0.9200026, 'average_loss': 0.35258815, 'label/mean': 0.495, 'loss': 44.073517, 'precision': 0.8522351, 'prediction/mean': 0.48447067, 'recall': 0.8319192, 'global_step': 100}\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 100\n", "Train Time (s): 76.1651611328125\n", "Eval Metrics (Train): {'accuracy': 0.85436666, 'accuracy_baseline': 0.5005, 'auc': 0.9321357, 'auc_precision_recall': 0.93224275, 'average_loss': 0.3330773, 'label/mean': 0.5005, 'loss': 42.520508, 'precision': 0.8501513, 'prediction/mean': 0.5098621, 'recall': 0.86073923, 'global_step': 200}\n", "Eval Metrics (Validation): {'accuracy': 0.8494, 'accuracy_baseline': 0.505, 'auc': 0.92772096, 'auc_precision_recall': 0.92323804, 'average_loss': 0.34418356, 'label/mean': 0.495, 'loss': 43.022945, 'precision': 0.83501947, 'prediction/mean': 0.5149463, 'recall': 0.86707073, 'global_step': 200}\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 200\n", "Train Time (s): 76.43350577354431\n", "Eval Metrics (Train): {'accuracy': 0.8561, 'accuracy_baseline': 0.5005, 'auc': 0.93400496, 'auc_precision_recall': 0.93414706, 'average_loss': 0.32816008, 'label/mean': 0.5005, 'loss': 41.892776, 'precision': 0.85856014, 'prediction/mean': 0.49918926, 'recall': 0.85301363, 'global_step': 300}\n", "Eval Metrics (Validation): {'accuracy': 0.8508, 'accuracy_baseline': 0.505, 'auc': 0.9290328, 'auc_precision_recall': 0.9250712, 'average_loss': 0.33986613, 'label/mean': 0.495, 'loss': 42.48327, 'precision': 0.84319174, 'prediction/mean': 0.50402075, 
'recall': 0.85818183, 'global_step': 300}\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 300\n", "Train Time (s): 78.08082580566406\n", "Eval Metrics (Train): {'accuracy': 0.8584, 'accuracy_baseline': 0.5005, 'auc': 0.93560475, 'auc_precision_recall': 0.9356024, 'average_loss': 0.3264787, 'label/mean': 0.5005, 'loss': 41.678127, 'precision': 0.84423554, 'prediction/mean': 0.522159, 'recall': 0.8793207, 'global_step': 400}\n", "Eval Metrics (Validation): {'accuracy': 0.8494, 'accuracy_baseline': 0.505, 'auc': 0.929911, 'auc_precision_recall': 0.9256854, 'average_loss': 0.34194976, 'label/mean': 0.495, 'loss': 42.74372, 'precision': 0.82564294, 'prediction/mean': 0.5267772, 'recall': 0.8820202, 'global_step': 400}\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 400\n", "Train Time (s): 78.66700315475464\n", "Eval Metrics (Train): {'accuracy': 0.8597, 'accuracy_baseline': 0.5005, 'auc': 0.9370333, 'auc_precision_recall': 0.93680507, 'average_loss': 0.32163766, 'label/mean': 0.5005, 'loss': 41.060127, 'precision': 0.8540629, 'prediction/mean': 0.51153994, 'recall': 0.86799866, 'global_step': 500}\n", "Eval Metrics (Validation): {'accuracy': 0.851, 'accuracy_baseline': 0.505, 'auc': 0.93050015, 'auc_precision_recall': 0.92593473, 'average_loss': 0.33805788, 'label/mean': 0.495, 'loss': 42.257233, 'precision': 0.83579195, 'prediction/mean': 0.5157132, 'recall': 0.869899, 'global_step': 500}\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 500\n", "Train Time (s): 81.03562021255493\n", "Eval Metrics (Train): {'accuracy': 0.8628, 'accuracy_baseline': 0.5005, 'auc': 0.9386675, 'auc_precision_recall': 0.93854874, 'average_loss': 0.31690273, 'label/mean': 0.5005, 'loss': 40.45567, 'precision': 0.86361516, 
'prediction/mean': 0.5004919, 'recall': 0.86200464, 'global_step': 600}\n", "Eval Metrics (Validation): {'accuracy': 0.8562, 'accuracy_baseline': 0.505, 'auc': 0.9315411, 'auc_precision_recall': 0.92733943, 'average_loss': 0.3339314, 'label/mean': 0.495, 'loss': 41.741425, 'precision': 0.8481364, 'prediction/mean': 0.5051916, 'recall': 0.86424243, 'global_step': 600}\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 600\n", "Train Time (s): 82.44215893745422\n", "Eval Metrics (Train): {'accuracy': 0.86523336, 'accuracy_baseline': 0.5005, 'auc': 0.94012344, 'auc_precision_recall': 0.93995917, 'average_loss': 0.31317216, 'label/mean': 0.5005, 'loss': 39.979427, 'precision': 0.86268675, 'prediction/mean': 0.50560397, 'recall': 0.8690643, 'global_step': 700}\n", "Eval Metrics (Validation): {'accuracy': 0.8572, 'accuracy_baseline': 0.505, 'auc': 0.9326249, 'auc_precision_recall': 0.9284968, 'average_loss': 0.33208466, 'label/mean': 0.495, 'loss': 41.510582, 'precision': 0.84651715, 'prediction/mean': 0.51005256, 'recall': 0.8690909, 'global_step': 700}\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 700\n", "Train Time (s): 81.71217703819275\n", "Eval Metrics (Train): {'accuracy': 0.8632333, 'accuracy_baseline': 0.5005, 'auc': 0.94086003, 'auc_precision_recall': 0.9406356, 'average_loss': 0.31374687, 'label/mean': 0.5005, 'loss': 40.05279, 'precision': 0.88191235, 'prediction/mean': 0.4772521, 'recall': 0.8390942, 'global_step': 800}\n", "Eval Metrics (Validation): {'accuracy': 0.854, 'accuracy_baseline': 0.505, 'auc': 0.9327511, 'auc_precision_recall': 0.9289065, 'average_loss': 0.33162636, 'label/mean': 0.495, 'loss': 41.453293, 'precision': 0.86128366, 'prediction/mean': 0.48188075, 'recall': 0.84040403, 'global_step': 800}\n", "\n", 
"----------------------------------------------------------------------------------------------------\n", "Training for step = 800\n", "Train Time (s): 83.28980422019958\n", "Eval Metrics (Train): {'accuracy': 0.8681667, 'accuracy_baseline': 0.5005, 'auc': 0.94270587, 'auc_precision_recall': 0.94251335, 'average_loss': 0.30656302, 'label/mean': 0.5005, 'loss': 39.135704, 'precision': 0.86419916, 'prediction/mean': 0.5073946, 'recall': 0.8739261, 'global_step': 900}\n", "Eval Metrics (Validation): {'accuracy': 0.8586, 'accuracy_baseline': 0.505, 'auc': 0.933738, 'auc_precision_recall': 0.9295292, 'average_loss': 0.32951483, 'label/mean': 0.495, 'loss': 41.189354, 'precision': 0.84612375, 'prediction/mean': 0.5115125, 'recall': 0.87313133, 'global_step': 900}\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 900\n", "Train Time (s): 82.47694706916809\n", "Eval Metrics (Train): {'accuracy': 0.86903334, 'accuracy_baseline': 0.5005, 'auc': 0.9443012, 'auc_precision_recall': 0.94409984, 'average_loss': 0.3057844, 'label/mean': 0.5005, 'loss': 39.03631, 'precision': 0.8514456, 'prediction/mean': 0.5249473, 'recall': 0.8943723, 'global_step': 1000}\n", "Eval Metrics (Validation): {'accuracy': 0.8566, 'accuracy_baseline': 0.505, 'auc': 0.9347619, 'auc_precision_recall': 0.93128407, 'average_loss': 0.33120963, 'label/mean': 0.495, 'loss': 41.401203, 'precision': 0.829955, 'prediction/mean': 0.5288649, 'recall': 0.8933333, 'global_step': 1000}\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 1000\n", "Train Time (s): 84.39996337890625\n", "Eval Metrics (Train): {'accuracy': 0.8716, 'accuracy_baseline': 0.5005, 'auc': 0.9453286, 'auc_precision_recall': 0.94506085, 'average_loss': 0.30098185, 'label/mean': 0.5005, 'loss': 38.423214, 'precision': 0.8595169, 'prediction/mean': 0.51834613, 'recall': 
0.8887113, 'global_step': 1100}\n", "Eval Metrics (Validation): {'accuracy': 0.859, 'accuracy_baseline': 0.505, 'auc': 0.9349775, 'auc_precision_recall': 0.93106484, 'average_loss': 0.3288155, 'label/mean': 0.495, 'loss': 41.101936, 'precision': 0.83727133, 'prediction/mean': 0.5222167, 'recall': 0.8876768, 'global_step': 1100}\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 1100\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train Time (s): 84.88116145133972\n", "Eval Metrics (Train): {'accuracy': 0.87226665, 'accuracy_baseline': 0.5005, 'auc': 0.94644105, 'auc_precision_recall': 0.94621253, 'average_loss': 0.2976978, 'label/mean': 0.5005, 'loss': 38.00398, 'precision': 0.88602, 'prediction/mean': 0.4845446, 'recall': 0.85474527, 'global_step': 1200}\n", "Eval Metrics (Validation): {'accuracy': 0.8612, 'accuracy_baseline': 0.505, 'auc': 0.9357711, 'auc_precision_recall': 0.93195754, 'average_loss': 0.32375482, 'label/mean': 0.495, 'loss': 40.469353, 'precision': 0.86272913, 'prediction/mean': 0.48821172, 'recall': 0.8557576, 'global_step': 1200}\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 1200\n", "Train Time (s): 85.49790549278259\n", "Eval Metrics (Train): {'accuracy': 0.8763, 'accuracy_baseline': 0.5005, 'auc': 0.94796634, 'auc_precision_recall': 0.9477355, 'average_loss': 0.29294312, 'label/mean': 0.5005, 'loss': 37.396996, 'precision': 0.87395793, 'prediction/mean': 0.5043318, 'recall': 0.8797203, 'global_step': 1300}\n", "Eval Metrics (Validation): {'accuracy': 0.861, 'accuracy_baseline': 0.505, 'auc': 0.9365514, 'auc_precision_recall': 0.9326272, 'average_loss': 0.32206511, 'label/mean': 0.495, 'loss': 40.25814, 'precision': 0.8526149, 'prediction/mean': 0.50739485, 'recall': 0.869495, 'global_step': 1300}\n", "\n", 
"----------------------------------------------------------------------------------------------------\n", "Training for step = 1300\n", "Train Time (s): 85.8064513206482\n", "Eval Metrics (Train): {'accuracy': 0.87743336, 'accuracy_baseline': 0.5005, 'auc': 0.9490614, 'auc_precision_recall': 0.94886625, 'average_loss': 0.29016578, 'label/mean': 0.5005, 'loss': 37.042442, 'precision': 0.87052286, 'prediction/mean': 0.5104429, 'recall': 0.8870463, 'global_step': 1400}\n", "Eval Metrics (Validation): {'accuracy': 0.8622, 'accuracy_baseline': 0.505, 'auc': 0.9367896, 'auc_precision_recall': 0.9333618, 'average_loss': 0.3225043, 'label/mean': 0.495, 'loss': 40.31304, 'precision': 0.8474708, 'prediction/mean': 0.514252, 'recall': 0.88, 'global_step': 1400}\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 1400\n", "Train Time (s): 85.99037742614746\n", "Eval Metrics (Train): {'accuracy': 0.8783, 'accuracy_baseline': 0.5005, 'auc': 0.9500882, 'auc_precision_recall': 0.94986326, 'average_loss': 0.28882334, 'label/mean': 0.5005, 'loss': 36.871063, 'precision': 0.865308, 'prediction/mean': 0.5196238, 'recall': 0.8963703, 'global_step': 1500}\n", "Eval Metrics (Validation): {'accuracy': 0.8626, 'accuracy_baseline': 0.505, 'auc': 0.93708724, 'auc_precision_recall': 0.9336051, 'average_loss': 0.32389137, 'label/mean': 0.495, 'loss': 40.486423, 'precision': 0.84044176, 'prediction/mean': 0.5226699, 'recall': 0.8917172, 'global_step': 1500}\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 1500\n", "Train Time (s): 86.91469407081604\n", "Eval Metrics (Train): {'accuracy': 0.8802, 'accuracy_baseline': 0.5005, 'auc': 0.95115364, 'auc_precision_recall': 0.950775, 'average_loss': 0.2844779, 'label/mean': 0.5005, 'loss': 36.316326, 'precision': 0.8735527, 'prediction/mean': 0.51057553, 'recall': 0.8893773, 
'global_step': 1600}\n", "Eval Metrics (Validation): {'accuracy': 0.8626, 'accuracy_baseline': 0.505, 'auc': 0.9373224, 'auc_precision_recall': 0.9336302, 'average_loss': 0.32108024, 'label/mean': 0.495, 'loss': 40.135033, 'precision': 0.8478599, 'prediction/mean': 0.5134171, 'recall': 0.88040406, 'global_step': 1600}\n" ] } ], "source": [ "import time\n", "\n", "tf.logging.set_verbosity(tf.logging.ERROR)\n", "\n", "TOTAL_STEPS = 1500\n", "STEP_SIZE = 100\n", "# range(0, TOTAL_STEPS+1, STEP_SIZE) yields 16 iterations of STEP_SIZE steps each,\n", "# so 1,600 steps (about 13.7 epochs at batch size 256) are run in total\n", "for step in range(0, TOTAL_STEPS+1, STEP_SIZE):\n", "    print()\n", "    print('-'*100)\n", "    print('Training for step =', step)\n", "    start_time = time.time()\n", "    dnn.train(input_fn=train_input_fn, steps=STEP_SIZE)\n", "    elapsed_time = time.time() - start_time\n", "    print('Train Time (s):', elapsed_time)\n", "    print('Eval Metrics (Train):', dnn.evaluate(input_fn=predict_train_input_fn))\n", "    print('Eval Metrics (Validation):', dnn.evaluate(input_fn=predict_val_input_fn))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Evaluation" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'accuracy': 0.8802,\n", " 'accuracy_baseline': 0.5005,\n", " 'auc': 0.95115364,\n", " 'auc_precision_recall': 0.950775,\n", " 'average_loss': 0.2844779,\n", " 'label/mean': 0.5005,\n", " 'loss': 36.316326,\n", " 'precision': 0.8735527,\n", " 'prediction/mean': 0.51057553,\n", " 'recall': 0.8893773,\n", " 'global_step': 1600}" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dnn.evaluate(input_fn=predict_train_input_fn)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'accuracy': 0.8663333,\n", " 'accuracy_baseline': 0.5006667,\n", " 'auc': 0.9406502,\n", " 'auc_precision_recall': 0.93988097,\n", " 'average_loss': 0.31214723,\n", " 'label/mean': 0.5006667,\n", " 'loss': 39.679733,\n", " 'precision': 0.8597569,\n", " 'prediction/mean': 0.5120608,\n", " 
'recall': 0.8758988,\n", " 'global_step': 1600}" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dnn.evaluate(input_fn=predict_test_input_fn)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Build a Generic Model Trainer on any Input Sentence Encoder" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "TOTAL_STEPS = 1500\n", "STEP_SIZE = 500\n", "\n", "my_checkpointing_config = tf.estimator.RunConfig(\n", "    keep_checkpoint_max = 2, # Retain the 2 most recent checkpoints.\n", ")\n", "\n", "def train_and_evaluate_with_sentence_encoder(hub_module, train_module=False, path=''):\n", "    embedding_feature = hub.text_embedding_column(\n", "        key='sentence', module_spec=hub_module, trainable=train_module)\n", "\n", "    print()\n", "    print('='*100)\n", "    print('Training with', hub_module)\n", "    print('Trainable is:', train_module)\n", "    print('='*100)\n", "\n", "    dnn = tf.estimator.DNNClassifier(\n", "        hidden_units=[512, 128],\n", "        feature_columns=[embedding_feature],\n", "        n_classes=2,\n", "        activation_fn=tf.nn.relu,\n", "        dropout=0.1,\n", "        optimizer=tf.train.AdagradOptimizer(learning_rate=0.005),\n", "        model_dir=path,\n", "        config=my_checkpointing_config)\n", "\n", "    # range(0, TOTAL_STEPS+1, STEP_SIZE) yields 4 iterations, i.e. 2,000 training steps per call\n", "    for step in range(0, TOTAL_STEPS+1, STEP_SIZE):\n", "        print('-'*100)\n", "        print('Training for step =', step)\n", "        start_time = time.time()\n", "        dnn.train(input_fn=train_input_fn, steps=STEP_SIZE)\n", "        elapsed_time = time.time() - start_time\n", "        print('Train Time (s):', elapsed_time)\n", "        print('Eval Metrics (Train):', dnn.evaluate(input_fn=predict_train_input_fn))\n", "        print('Eval Metrics (Validation):', dnn.evaluate(input_fn=predict_val_input_fn))\n", "\n", "    train_eval_result = dnn.evaluate(input_fn=predict_train_input_fn)\n", "    test_eval_result = dnn.evaluate(input_fn=predict_test_input_fn)\n", "\n", "    return {\n", "        \"Model Dir\": dnn.model_dir,\n", "        \"Training Accuracy\": 
train_eval_result[\"accuracy\"],\n", "        \"Test Accuracy\": test_eval_result[\"accuracy\"],\n", "        \"Training AUC\": train_eval_result[\"auc\"],\n", "        \"Test AUC\": test_eval_result[\"auc\"],\n", "        \"Training Precision\": train_eval_result[\"precision\"],\n", "        \"Test Precision\": test_eval_result[\"precision\"],\n", "        \"Training Recall\": train_eval_result[\"recall\"],\n", "        \"Test Recall\": test_eval_result[\"recall\"]\n", "    }" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Train Deep Learning Models on different Sentence Encoders\n", "- NNLM - pre-trained and fine-tuning\n", "- USE - pre-trained and fine-tuning" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "====================================================================================================\n", "Training with https://tfhub.dev/google/nnlm-en-dim128/1\n", "Trainable is: False\n", "====================================================================================================\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 0\n", "Train Time (s): 30.525171756744385\n", "Eval Metrics (Train): {'accuracy': 0.8480667, 'accuracy_baseline': 0.5005, 'auc': 0.9287864, 'auc_precision_recall': 0.9287345, 'average_loss': 0.34465897, 'label/mean': 0.5005, 'loss': 43.99902, 'precision': 0.8288572, 'prediction/mean': 0.5302467, 'recall': 0.8776557, 'global_step': 2500}\n", "Eval Metrics (Validation): {'accuracy': 0.8288, 'accuracy_baseline': 0.505, 'auc': 0.91452694, 'auc_precision_recall': 0.9113482, 'average_loss': 0.37722248, 'label/mean': 0.495, 'loss': 47.15281, 'precision': 0.7999259, 'prediction/mean': 0.53336626, 'recall': 0.8723232, 'global_step': 2500}\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 500\n", 
"Train Time (s): 27.883334159851074\n", "Eval Metrics (Train): {'accuracy': 0.8558, 'accuracy_baseline': 0.5005, 'auc': 0.93249583, 'auc_precision_recall': 0.93237436, 'average_loss': 0.33365154, 'label/mean': 0.5005, 'loss': 42.59381, 'precision': 0.87215376, 'prediction/mean': 0.48501322, 'recall': 0.8341658, 'global_step': 3000}\n", "Eval Metrics (Validation): {'accuracy': 0.8356, 'accuracy_baseline': 0.505, 'auc': 0.91525424, 'auc_precision_recall': 0.9126499, 'average_loss': 0.36975703, 'label/mean': 0.495, 'loss': 46.219627, 'precision': 0.8397041, 'prediction/mean': 0.48800987, 'recall': 0.82545453, 'global_step': 3000}\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 1000\n", "Train Time (s): 28.585613012313843\n", "Eval Metrics (Train): {'accuracy': 0.8552667, 'accuracy_baseline': 0.5005, 'auc': 0.9354528, 'auc_precision_recall': 0.9354019, 'average_loss': 0.33256087, 'label/mean': 0.5005, 'loss': 42.45458, 'precision': 0.89311236, 'prediction/mean': 0.4612543, 'recall': 0.80745924, 'global_step': 3500}\n", "Eval Metrics (Validation): {'accuracy': 0.8348, 'accuracy_baseline': 0.505, 'auc': 0.91587234, 'auc_precision_recall': 0.9131782, 'average_loss': 0.3731109, 'label/mean': 0.495, 'loss': 46.638863, 'precision': 0.8573905, 'prediction/mean': 0.4640141, 'recall': 0.7991919, 'global_step': 3500}\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 1500\n", "Train Time (s): 28.242169618606567\n", "Eval Metrics (Train): {'accuracy': 0.8616, 'accuracy_baseline': 0.5005, 'auc': 0.9385461, 'auc_precision_recall': 0.9383698, 'average_loss': 0.32105166, 'label/mean': 0.5005, 'loss': 40.985317, 'precision': 0.8443543, 'prediction/mean': 0.52729076, 'recall': 0.8869797, 'global_step': 4000}\n", "Eval Metrics (Validation): {'accuracy': 0.828, 'accuracy_baseline': 0.505, 'auc': 0.91572505, 
'auc_precision_recall': 0.91319984, 'average_loss': 0.37505153, 'label/mean': 0.495, 'loss': 46.88144, 'precision': 0.80322945, 'prediction/mean': 0.53075755, 'recall': 0.86424243, 'global_step': 4000}\n", "\n", "====================================================================================================\n", "Training with https://tfhub.dev/google/nnlm-en-dim128/1\n", "Trainable is: True\n", "====================================================================================================\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 0\n", "Train Time (s): 45.97756814956665\n", "Eval Metrics (Train): {'accuracy': 0.9997, 'accuracy_baseline': 0.5005, 'auc': 0.9998141, 'auc_precision_recall': 0.99985945, 'average_loss': 0.0038648716, 'label/mean': 0.5005, 'loss': 0.49338785, 'precision': 0.99980015, 'prediction/mean': 0.5008926, 'recall': 0.9996004, 'global_step': 2500}\n", "Eval Metrics (Validation): {'accuracy': 0.877, 'accuracy_baseline': 0.505, 'auc': 0.9225529, 'auc_precision_recall': 0.9297111, 'average_loss': 0.67985016, 'label/mean': 0.495, 'loss': 84.98127, 'precision': 0.86671925, 'prediction/mean': 0.50768346, 'recall': 0.88808084, 'global_step': 2500}\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 500\n", "Train Time (s): 44.74027681350708\n", "Eval Metrics (Train): {'accuracy': 0.99986666, 'accuracy_baseline': 0.5005, 'auc': 0.9999234, 'auc_precision_recall': 0.9999373, 'average_loss': 0.002121097, 'label/mean': 0.5005, 'loss': 0.27077833, 'precision': 0.9998668, 'prediction/mean': 0.5002479, 'recall': 0.9998668, 'global_step': 3000}\n", "Eval Metrics (Validation): {'accuracy': 0.8764, 'accuracy_baseline': 0.505, 'auc': 0.9195744, 'auc_precision_recall': 0.9288729, 'average_loss': 0.74133915, 'label/mean': 0.495, 'loss': 92.6674, 'precision': 0.8742443, 'prediction/mean': 
0.4980622, 'recall': 0.87636364, 'global_step': 3000}\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 1000\n", "Train Time (s): 45.068076372146606\n", "Eval Metrics (Train): {'accuracy': 0.9999667, 'accuracy_baseline': 0.5005, 'auc': 1.0, 'auc_precision_recall': 1.0, 'average_loss': 0.0009565035, 'label/mean': 0.5005, 'loss': 0.12210683, 'precision': 1.0, 'prediction/mean': 0.5007308, 'recall': 0.9999334, 'global_step': 3500}\n", "Eval Metrics (Validation): {'accuracy': 0.8748, 'accuracy_baseline': 0.505, 'auc': 0.9156478, 'auc_precision_recall': 0.9250907, 'average_loss': 0.8001606, 'label/mean': 0.495, 'loss': 100.02007, 'precision': 0.86584884, 'prediction/mean': 0.50656, 'recall': 0.8840404, 'global_step': 3500}\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 1500\n", "Train Time (s): 44.654765605926514\n", "Eval Metrics (Train): {'accuracy': 1.0, 'accuracy_baseline': 0.5005, 'auc': 1.0, 'auc_precision_recall': 1.0, 'average_loss': 0.00060436723, 'label/mean': 0.5005, 'loss': 0.077153265, 'precision': 1.0, 'prediction/mean': 0.50068456, 'recall': 1.0, 'global_step': 4000}\n", "Eval Metrics (Validation): {'accuracy': 0.875, 'accuracy_baseline': 0.505, 'auc': 0.91479605, 'auc_precision_recall': 0.9244194, 'average_loss': 0.8459253, 'label/mean': 0.495, 'loss': 105.74066, 'precision': 0.8661916, 'prediction/mean': 0.5066238, 'recall': 0.8840404, 'global_step': 4000}\n", "\n", "====================================================================================================\n", "Training with https://tfhub.dev/google/universal-sentence-encoder/2\n", "Trainable is: False\n", "====================================================================================================\n", "----------------------------------------------------------------------------------------------------\n", "Training for 
step = 0\n", "Train Time (s): 261.7671597003937\n", "Eval Metrics (Train): {'accuracy': 0.8591, 'accuracy_baseline': 0.5005, 'auc': 0.9373971, 'auc_precision_recall': 0.93691623, 'average_loss': 0.3231426, 'label/mean': 0.5005, 'loss': 41.252243, 'precision': 0.8820655, 'prediction/mean': 0.47581005, 'recall': 0.8293706, 'global_step': 501}\n", "Eval Metrics (Validation): {'accuracy': 0.8522, 'accuracy_baseline': 0.505, 'auc': 0.93081224, 'auc_precision_recall': 0.9264202, 'average_loss': 0.33680823, 'label/mean': 0.495, 'loss': 42.10103, 'precision': 0.8631799, 'prediction/mean': 0.47982788, 'recall': 0.8335354, 'global_step': 501}\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 500\n", "Train Time (s): 259.56947469711304\n", "Eval Metrics (Train): {'accuracy': 0.8696333, 'accuracy_baseline': 0.5005, 'auc': 0.9444929, 'auc_precision_recall': 0.94410896, 'average_loss': 0.30284137, 'label/mean': 0.5005, 'loss': 38.660603, 'precision': 0.8605663, 'prediction/mean': 0.5134329, 'recall': 0.88251746, 'global_step': 1001}\n", "Eval Metrics (Validation): {'accuracy': 0.8608, 'accuracy_baseline': 0.505, 'auc': 0.93478435, 'auc_precision_recall': 0.93100405, 'average_loss': 0.32802072, 'label/mean': 0.495, 'loss': 41.00259, 'precision': 0.8446339, 'prediction/mean': 0.5172887, 'recall': 0.88080806, 'global_step': 1001}\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 1000\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train Time (s): 258.70958161354065\n", "Eval Metrics (Train): {'accuracy': 0.8782333, 'accuracy_baseline': 0.5005, 'auc': 0.9505533, 'auc_precision_recall': 0.94987434, 'average_loss': 0.28575605, 'label/mean': 0.5005, 'loss': 36.479496, 'precision': 0.8756281, 'prediction/mean': 0.5043332, 'recall': 0.8819847, 'global_step': 1501}\n", "Eval Metrics (Validation): {'accuracy': 
0.8616, 'accuracy_baseline': 0.505, 'auc': 0.9369738, 'auc_precision_recall': 0.9333807, 'average_loss': 0.32096896, 'label/mean': 0.495, 'loss': 40.121117, 'precision': 0.8505702, 'prediction/mean': 0.5080113, 'recall': 0.8739394, 'global_step': 1501}\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 1500\n", "Train Time (s): 258.4421606063843\n", "Eval Metrics (Train): {'accuracy': 0.88733333, 'accuracy_baseline': 0.5005, 'auc': 0.9558296, 'auc_precision_recall': 0.95508415, 'average_loss': 0.2716801, 'label/mean': 0.5005, 'loss': 34.682564, 'precision': 0.8979955, 'prediction/mean': 0.4882649, 'recall': 0.8741925, 'global_step': 2001}\n", "Eval Metrics (Validation): {'accuracy': 0.864, 'accuracy_baseline': 0.505, 'auc': 0.938815, 'auc_precision_recall': 0.9357392, 'average_loss': 0.31562653, 'label/mean': 0.495, 'loss': 39.453316, 'precision': 0.864393, 'prediction/mean': 0.49126464, 'recall': 0.860202, 'global_step': 2001}\n", "\n", "====================================================================================================\n", "Training with https://tfhub.dev/google/universal-sentence-encoder/2\n", "Trainable is: True\n", "====================================================================================================\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 0\n", "Train Time (s): 313.1993100643158\n", "Eval Metrics (Train): {'accuracy': 0.99916667, 'accuracy_baseline': 0.5005, 'auc': 0.9996535, 'auc_precision_recall': 0.9996587, 'average_loss': 0.0054427227, 'label/mean': 0.5005, 'loss': 0.69481564, 'precision': 0.9989349, 'prediction/mean': 0.50010633, 'recall': 0.9994006, 'global_step': 500}\n", "Eval Metrics (Validation): {'accuracy': 0.9056, 'accuracy_baseline': 0.505, 'auc': 0.95068294, 'auc_precision_recall': 0.95441175, 'average_loss': 0.40755096, 'label/mean': 0.495, 
'loss': 50.94387, 'precision': 0.9020474, 'prediction/mean': 0.4965181, 'recall': 0.9078788, 'global_step': 500}\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 500\n", "Train Time (s): 306.61662435531616\n", "Eval Metrics (Train): {'accuracy': 0.9999667, 'accuracy_baseline': 0.5005, 'auc': 0.9999665, 'auc_precision_recall': 0.9999665, 'average_loss': 0.00048276514, 'label/mean': 0.5005, 'loss': 0.061629593, 'precision': 0.9999334, 'prediction/mean': 0.50048435, 'recall': 1.0, 'global_step': 1000}\n", "Eval Metrics (Validation): {'accuracy': 0.9024, 'accuracy_baseline': 0.505, 'auc': 0.93500155, 'auc_precision_recall': 0.94411886, 'average_loss': 0.5513662, 'label/mean': 0.495, 'loss': 68.92078, 'precision': 0.892843, 'prediction/mean': 0.50568086, 'recall': 0.91232324, 'global_step': 1000}\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 1000\n", "Train Time (s): 306.62883472442627\n", "Eval Metrics (Train): {'accuracy': 0.9999667, 'accuracy_baseline': 0.5005, 'auc': 0.9999666, 'auc_precision_recall': 0.9999666, 'average_loss': 0.00030994884, 'label/mean': 0.5005, 'loss': 0.039567936, 'precision': 0.9999334, 'prediction/mean': 0.50050503, 'recall': 1.0, 'global_step': 1500}\n", "Eval Metrics (Validation): {'accuracy': 0.9018, 'accuracy_baseline': 0.505, 'auc': 0.9302231, 'auc_precision_recall': 0.9409413, 'average_loss': 0.61150163, 'label/mean': 0.495, 'loss': 76.4377, 'precision': 0.8905512, 'prediction/mean': 0.50745887, 'recall': 0.9139394, 'global_step': 1500}\n", "----------------------------------------------------------------------------------------------------\n", "Training for step = 1500\n", "Train Time (s): 305.9913341999054\n", "Eval Metrics (Train): {'accuracy': 1.0, 'accuracy_baseline': 0.5005, 'auc': 1.0, 'auc_precision_recall': 1.0, 'average_loss': 5.714076e-05, 'label/mean': 0.5005, 
'loss': 0.0072945654, 'precision': 1.0, 'prediction/mean': 0.5004708, 'recall': 1.0, 'global_step': 2000}\n", "Eval Metrics (Validation): {'accuracy': 0.9032, 'accuracy_baseline': 0.505, 'auc': 0.929281, 'auc_precision_recall': 0.9409471, 'average_loss': 0.6407001, 'label/mean': 0.495, 'loss': 80.08751, 'precision': 0.8986784, 'prediction/mean': 0.4996146, 'recall': 0.9066667, 'global_step': 2000}\n" ] } ], "source": [ "tf.logging.set_verbosity(tf.logging.ERROR)\n", "\n", "results = {}\n", "\n", "results[\"nnlm-en-dim128\"] = train_and_evaluate_with_sentence_encoder(\n", " \"https://tfhub.dev/google/nnlm-en-dim128/1\", path='/storage/models/nnlm-en-dim128_f/')\n", "\n", "results[\"nnlm-en-dim128-with-training\"] = train_and_evaluate_with_sentence_encoder(\n", " \"https://tfhub.dev/google/nnlm-en-dim128/1\", train_module=True, path='/storage/models/nnlm-en-dim128_t/')\n", "\n", "results[\"use-512\"] = train_and_evaluate_with_sentence_encoder(\n", " \"https://tfhub.dev/google/universal-sentence-encoder/2\", path='/storage/models/use-512_f/')\n", "\n", "results[\"use-512-with-training\"] = train_and_evaluate_with_sentence_encoder(\n", " \"https://tfhub.dev/google/universal-sentence-encoder/2\", train_module=True, path='/storage/models/use-512_t/')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Evaluations" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table>\n", "<thead>\n",
"<tr><th></th><th>Model Dir</th><th>Training Accuracy</th><th>Test Accuracy</th><th>Training AUC</th><th>Test AUC</th><th>Training Precision</th><th>Test Precision</th><th>Training Recall</th><th>Test Recall</th></tr>\n",
"</thead>\n", "<tbody>\n",
"<tr><th>nnlm-en-dim128</th><td>/storage/models/nnlm-en-dim128_f/</td><td>0.861600</td><td>0.836133</td><td>0.938546</td><td>0.918221</td><td>0.844354</td><td>0.822770</td><td>0.886980</td><td>0.857390</td></tr>\n",
"<tr><th>nnlm-en-dim128-with-training</th><td>/storage/models/nnlm-en-dim128_t/</td><td>1.000000</td><td>0.878467</td><td>1.000000</td><td>0.919655</td><td>1.000000</td><td>0.875975</td><td>1.000000</td><td>0.882157</td></tr>\n",
"<tr><th>use-512</th><td>/storage/models/use-512_f/</td><td>0.887333</td><td>0.867067</td><td>0.955830</td><td>0.942319</td><td>0.897995</td><td>0.876776</td><td>0.874192</td><td>0.854594</td></tr>\n",
"<tr><th>use-512-with-training</th><td>/storage/models/use-512_t/</td><td>1.000000</td><td>0.904533</td><td>1.000000</td><td>0.930401</td><td>1.000000</td><td>0.904660</td><td>1.000000</td><td>0.904660</td></tr>\n",
"</tbody>\n", "</table>\n", "
" ], "text/plain": [ " Model Dir \\\n", "nnlm-en-dim128 /storage/models/nnlm-en-dim128_f/ \n", "nnlm-en-dim128-with-training /storage/models/nnlm-en-dim128_t/ \n", "use-512 /storage/models/use-512_f/ \n", "use-512-with-training /storage/models/use-512_t/ \n", "\n", " Training Accuracy Test Accuracy Training AUC \\\n", "nnlm-en-dim128 0.861600 0.836133 0.938546 \n", "nnlm-en-dim128-with-training 1.000000 0.878467 1.000000 \n", "use-512 0.887333 0.867067 0.955830 \n", "use-512-with-training 1.000000 0.904533 1.000000 \n", "\n", " Test AUC Training Precision Test Precision \\\n", "nnlm-en-dim128 0.918221 0.844354 0.822770 \n", "nnlm-en-dim128-with-training 0.919655 1.000000 0.875975 \n", "use-512 0.942319 0.897995 0.876776 \n", "use-512-with-training 0.930401 1.000000 0.904660 \n", "\n", " Training Recall Test Recall \n", "nnlm-en-dim128 0.886980 0.857390 \n", "nnlm-en-dim128-with-training 1.000000 0.882157 \n", "use-512 0.874192 0.854594 \n", "use-512-with-training 1.000000 0.904660 " ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results_df = pd.DataFrame.from_dict(results, orient=\"index\")\n", "results_df" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/storage/models/use-512_t/'" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_model_dir = results_df[results_df['Test Accuracy'] == results_df['Test Accuracy'].max()]['Model Dir'].values[0]\n", "best_model_dir" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "embedding_feature = hub.text_embedding_column(\n", " key='sentence', module_spec=\"https://tfhub.dev/google/universal-sentence-encoder/2\", trainable=True)\n", "\n", "dnn = tf.estimator.DNNClassifier(\n", " hidden_units=[512, 128],\n", " 
feature_columns=[embedding_feature],\n", " n_classes=2,\n", " activation_fn=tf.nn.relu,\n", " dropout=0.1,\n", " optimizer=tf.train.AdagradOptimizer(learning_rate=0.005),\n", " model_dir=best_model_dir)\n", "dnn" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "def get_predictions(estimator, input_fn):\n", " return [x[\"class_ids\"][0] for x in estimator.predict(input_fn=input_fn)]" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[0, 1, 0, 1, 1, 0, 1, 1, 1, 1]" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions = get_predictions(estimator=dnn, input_fn=predict_test_input_fn)\n", "predictions[:10]" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: seaborn in /usr/local/lib/python3.6/dist-packages (0.9.0)\n", "Requirement already satisfied: scipy>=0.14.0 in /usr/local/lib/python3.6/dist-packages (from seaborn) (1.0.1)\n", "Requirement already satisfied: pandas>=0.15.2 in /usr/local/lib/python3.6/dist-packages (from seaborn) (0.22.0)\n", "Requirement already satisfied: numpy>=1.9.3 in /usr/local/lib/python3.6/dist-packages (from seaborn) (1.14.3)\n", "Requirement already satisfied: matplotlib>=1.4.3 in /usr/local/lib/python3.6/dist-packages (from seaborn) (2.2.2)\n", "Requirement already satisfied: pytz>=2011k in /usr/local/lib/python3.6/dist-packages (from pandas>=0.15.2->seaborn) (2018.4)\n", "Requirement already satisfied: python-dateutil>=2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.15.2->seaborn) (2.7.2)\n", "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=1.4.3->seaborn) (1.11.0)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=1.4.3->seaborn) (1.0.1)\n", 
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=1.4.3->seaborn) (2.2.0)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=1.4.3->seaborn) (0.10.0)\n", "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from kiwisolver>=1.0.1->matplotlib>=1.4.3->seaborn) (39.1.0)\n", "\u001b[33mYou are using pip version 10.0.1, however version 18.1 is available.\n", "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n" ] } ], "source": [ "!pip install seaborn" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXUAAAEKCAYAAADticXcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3XucV1W9//HXmwFjEC8geMMLQnjwkmmi4iVFUcRbaJpimdDxF2V6MtPyksdL5kkt0zQjMU1QPGJ4VA4/RAnFTEUFRe4ioCioWQoqCjTMfM4fe4NfbjPfgdnf75fN+8ljPdh77fXda21m+MyatddeWxGBmZnlQ7NyN8DMzJqOg7qZWY44qJuZ5YiDuplZjjiom5nliIO6mVmOOKibmeWIg7qZWY44qJuZ5UjzcjdgXWr+OdePutoaqnf8armbYBVo+b8WaEPP0ZiY06Jdpw2uLyvuqZuZ5UjF9tTNzEqqrrbcLWgSDupmZgC1y8vdgibhoG5mBkTUlbsJTcJB3cwMoM5B3cwsP9xTNzPLEd8oNTPLEffUzczyIzz7xcwsR3yj1MwsRzz8YmaWI75RamaWI+6pm5nliG+UmpnliG+UmpnlR4TH1M3M8sNj6mZmOeLhFzOzHHFP3cwsR2pryt2CJuGgbmYGHn4xM8sVD7+YmeWIe+pmZjmSk6DerNwNMDOrBFFbU3RqiKStJQ2XNFPSDEkHS2oraYyk19O/26RlJelWSbMlTZb0lYLz9EvLvy6pXzHX4aBuZgbJmHqxqWG/BUZHRFfgy8AM4FJgbER0Acam+wDHAV3SNAAYCCCpLXAVcBBwIHDVih8E9XFQNzODZPil2FQPSVsBhwN3AUTEvyJiEdAHGJwWGwycnG73AYZEYjywtaQdgGOBMRHxYUQsBMYAvRu6DAd1MzNoyp76bsA/gD9JekXSHyVtDmwXEe+mZd4Dtku3OwBvF3x+fpq3rvx6OaibmUGjeuqSBkiaUJAGFJypOfAVYGBE7Ad8yudDLQBERACRxWV49ouZGTRqnnpEDAIGrePwfGB+RLyQ7g8nCep/l7RDRLybDq+8nx5fAOxc8Pmd0rwFQI/V8sc11Db31M3MAJYvLz7VIyLeA96W9G9pVk9gOjACWDGDpR/waLo9Ajg7nQXTHfgoHaZ5HOglqU16g7RXmlcv99TNzKCpnyj9D2CopM2AucB3SDrRD0o6B5gHnJ
6WHQUcD8wGPkvLEhEfSroWeCkt9/OI+LChih3UzcygSR8+iohJQLe1HOq5lrIBnLeO89wN3N2Yuh3UzczAa7+YmeVKTpYJcFA3MwP31M3McqWBWS0bCwd1MzOAyORZoJJzUDczA4+pm5nlioO6mVmO+EapmVmO1NaWuwVNwkHdzAw8/GJmlisO6mZmOeIxdTOz/Ig6z1M3M8sPD7+YmeWIZ7+YmeWIe+q2IT7+ZDFXXX8Ls+fOA4lrL7+Qe4c9wptvzQfgk8WL2aJ1ax4afDsjH3+SP93/0MrPzprzBn+++za67t55Zd75P72a+e+8xyP3/aHk12JNb/fdO3P/0IEr9zvttgtXX/NrFrzzHlf+54/Zo2sXDj7kBCa+PHllmS99aQ8G3n4DW2zZmrq6OroffALLli0rR/M3Tg7qtiGuv+UPHHpQN26+7gpqampYsnQZN1172crjv7rtTlpv3gqAE489ihOPPQpIAvoPL/35KgF9zLhnadWqurQXYJmaNWsO3Q7oBUCzZs14682JPPLoY7RqVc03Tv8uA2+/fpXyVVVVDL7nVvp/5wImT55O27ZtqKmpKUfTN145WdDLL54ug08Wf8rEV6dy6knHAtCiRQu23KL1yuMRwegn/8rxx/RY47OjxjzNcUcfsXL/s8+WMGTY//C9fn0zb7eVR8+jDmPu3Hm89dYCZs6czaxZc9Yo0+uYI5gyZQaTJ08H4MMPF1KXk55nydTVFZ8qWOZBXVJ1wVu1DVjwznu02XorrrjuN5zW/zyu/OUtfLZk6crjE1+dyjZt2rDrzh3W+OzosU+vEuxvu3MI/fp+nZYtW5ai6VYGp5/ehweGPVJvmS5dOhEBo0YO5cUXRnPxReeWqHU5UhfFpwqWaVCXdBIwCRid7u8raUSWdW4MltfWMmPWbM445QSG33M71dUtueveB1ceHzVmHMcfc8Qan5s8bSbVLVvSpVNHAGbOmsPbC97l6CMOLVXTrcRatGjBSSf2YvhDI+st17x5FYcecgDf7nc+R/Q4mZP7HMdRRx5WolbmRG1t8amCZd1Tvxo4EFgEK9+wvdu6CksaIGmCpAl/HPLfGTetfLbfth3btW/HPnt1BaBXj8OYPms2AMuX1/KXp5+jd8/D1/jcY39Zdehl0rQZTJv5Or1O7cfZ517Em28voP/5Py3NRVhJ9O59JK+8MoX33/9nveXmL3iXZ/72Ah98sJAlS5by2Ogn2W+/vUvUynyIurqiUyXLOqjXRMRHq+Wt83eXiBgUEd0iotv/O/vMjJtWPu22acv227bnjXnJTJfxEyfRueMuyfaEV+i0605sv237VT5TV1fH408+s0pQ73vKiTw1YihPPDSYIQNvouPOHbjndzeW7kIsc33POLnBoReAJ554mr337kp1dUuqqqo4/KvdmTHj9RK0MEc8/FKUaZK+CVRJ6iLpNuC5jOvcKFx+4blccs2NnHL2ubz2+ly+e/YZwIreeI81yk+YNJXtt23Hzh12KHFLrVxatarm6J6H8/Ajj63M69OnN2/OnUD37vsz4tEhjBo5FIBFiz7ilt8OYvzzo5g44QlemTSFUY+NLVfTN05RV3yqYIoMp/FIagX8DOiVZj0O/CIilq77U4maf86t7B+HVhbVO3613E2wCrT8Xwu0oef49OffKjrmbH7l0A2uLytZz1PvGhE/IwnsZmaVa3ll3wAtVtZB/SZJ2wPDgWERMTXj+szM1k+FD6sUK9Mx9Yg4EjgS+Adwh6Qpkq7Isk4zs/XiG6XFiYj3IuJW4Pskc9avzLpOM7PGysuUxkyHXyTtAZwBnAp8AAwDLsqyTjOz9VLhPfBiZT2mfjdJID82It7JuC4zs/XnoN6wiDg4y/ObmTWZCn/8v1iZBHVJD0bE6ZKmsOoTpAIiIvbJol4zs/Xld5TW74L07xMzOr+ZWdPKSVDPZPZLRLybbv4gIuYVJuAHWdRpZrZBvJ56UY5ZS95xGddpZtZ4OZmnntWY+r
kkPfJOkiYXHNoCeDaLOs3MNkiFB+tiZTWmfj/wGPBL4NKC/E8i4sOM6jQzW29RW9nDKsXKJKina6h/BJwJIGlboCXQWlLriHgri3rNzNabe+oNS19n9xtgR+B9YFdgBrBXlvWamTVWXqY0Zn2j9BdAd2BWROwG9ATGZ1ynmVnj5eRGaSleZ/cB0ExSs4h4CuiWcZ1mZo1X14hUwbJe+2WRpNbAX4Ghkt4HPs24TjOzRovlFR6ti5R1T70PsAS4EBgNzAFOyrhOM7PGc0+9YRFR2CsfnGVdZmYbwjdKiyDpE0kfr5belvSwpE5Z1m1m1ihN3FOXVCXpFUkj0/17JL0haVKa9k3zJelWSbMlTZb0lYJz9JP0epr6FVNv1mPqtwDzSR5GEtAX6Ay8TLLWeo+M6zczK0oGPfULSKZwb1mQ95OIGL5aueOALmk6CBgIHCSpLXAVyeSSACZKGhERC+urNOsx9a9FxB0R8UlEfBwRg0hemDEMaJNx3WZmxWvCnrqknYATgD8WUXMfYEgkxgNbS9oBOBYYExEfpoF8DNC7oZNlHdQ/k3S6pGZpOh1Ymh7LxwCWmeVCLC8+SRogaUJBGrDa6W4BfsqaPwKuS4dYbpb0hTSvA/B2QZn5ad668uuVdVD/FvBtkqdJ/55unyWpGjg/47rNzIoWdY1IEYMioltBGrTiPJJOBN6PiImrVXEZ0BU4AGgLXJLFdWQ9+2Uu657C+Lcs6zYza5Smm6p4KPA1SceTrHm1paT7IuKs9PgySX8CLk73FwA7F3x+pzRvAaved9wJGNdQ5VnPftld0lhJU9P9fSRdkWWdZmbrozE99XrPE3FZROwUER1JJoc8GRFnpePkSBJwMjA1/cgI4Ox0Fkx34KP0RUOPA70ktZHUBuiV5tUr6+GXO0l+5agBiIjJJBdpZlZRmiqo12No+t7mKUA7krWxAEYBc4HZJDHzBwDpMuXXAi+l6efFLF2e9ZTGVhHxYvKDaaXlGddpZtZoUauGCzX2nBHjSIdMIuKodZQJ4Lx1HLubZPp30bIO6v+U1Jl0pouk04B36/+ImVnpbUAPvKJkHdTPAwYBXSUtAN4gmRFjZlZRoq7pe+rlkHVQXwD8CXiKZArPx0A/4OcZ12tm1ijuqRfnUWARybIA72Rcl5nZeotwT70YO0VEg4+1mpmVW1566g1OaZS0uaRm6fbukr4mqUWR539O0pc2qIVmZiVQV6uiUyUrpqf+V+Cr6eT3J0jmS55BcTc8DwP6S3oDWEayUmNExD7r2V4zs0xsSjdKFRGfSToH+H1E3ChpUpHnP24D2mZmVjKbVFCXdDBJz/ycNK+qmJNHxLz1bZiZWSlFTtaNLSao/4jkUf+HI2Ja+saip7JtlplZaW0yPfWIeBp4umB/LvDDLBtlZlZquZ/SKOl/qedFFhHxtUxaZGZWBrUVPqulWPX11H9dslaYmZVZ7nvq6bCLmdkmYZMZU5fUBfglsCfJWzwAiIhOGbbLzKyk8jL7pZiXZPwJGEiyDvqRwBDgviwbZWZWalGnolMlKyaoV0fEWJKHkOZFxNXACdk2y8ystGrrmhWdKlkx89SXpWu/vC7pfJLldFtn2ywzs9LalIZfLgBakcxN3x/4Nsma6GZmuVEXKjpVsmIePnop3VwMfCfb5piZlUfupzSuIOkp1vIQ0rpeompmtjHKy/BLMWPqFxdstwROJZkJk6nqHb+adRW2EVryzjPlboLlVKUPqxSrmOGXiatlPSvpxYzaY2ZWFpU+q6VYxQy/tC3YbUZys3SrzFpkZlYGORl9KWr4ZSLJ9Ypk2OUNPl9X3cwsFzaZ4Rdgj4hYWpgh6QsZtcfMrCzyMvulmEGk59aS93xTN8TMrJzqGpEqWX3rqW8PdACqJe1HMvwCsCXJw0hmZrkR5KOnXt/wy7FAf2An4CY+D+ofA5dn2ywzs9JanpPhl/rWUx8MDJZ0akQ8VMI2mZmVXF566sWMqe8vaesVO5LaSP
pFhm0yMyu5vIypFxPUj4uIRSt2ImIhcHx2TTIzK71ARadKVsyUxipJX4iIZQCSqgFPaTSzXKn0HnixignqQ4Gxkv5EcrO0PzA4y0aZmZVabYX3wItVzNovN0h6FTia5MnSx4Fds26YmVkpVfhb6opWTE8d4O8kAf0bJMsEeDaMmeVKXd576pJ2B85M0z+BYSTvKT2yRG0zMyuZTWFBr5nAM8CJETEbQNKFJWmVmVmJ5eVGaX1TGr8OvAs8JelOST0hJ7+fmJmtpk4qOlWydQb1iHgkIvoCXYGngB8B20oaKKlXqRpoZlYKtY1IlazBh48i4tOIuD8iTiJZB+YV4JLMW2ZmVkJ1Kj5VsmJnvwArnyYdlCYzs9zI/ewXM7NNyaYw+8XMbJNR6cMqxcrH67PNzDZQU63SKKmlpBclvSppmqRr0vzdJL0gabakYZI2S/O/kO7PTo93LDjXZWn+a5KOLeY6HNTNzIBaFZ8asAw4KiK+DOwL9JbUHbgBuDkivggsBM5Jy58DLEzzb07LIWlPoC+wF9Ab+L2kqoYqd1A3M6PpeuqRWJzutkhTAEcBw9P8wcDJ6XYfPl8kcTjQU5LS/AciYllEvAHMBg5s6Doc1M3MaNqXZEiqkjQJeB8YA8wBFkXE8rTIfJJ3QJP+/TZAevwjYJvC/LV8Zp0c1M3MgFDxSdIASRMK0oBVzhVRGxH7kjzbcyDJQ5wl4dkvZmY0bu2XiCjqeZ2IWCTpKeBgYGtJzdPe+E7AgrTYAmBnYL6k5sBWwAcF+SsUfmad3FM3M6PplgmQ1H7Fe53TN8UdA8wgWW7ltLRYP+DRdHtEuk96/MmIiDS/bzo7ZjegC/BiQ9fhnrqZGU06T30HYHA6U6UZ8GBEjJQ0HXhA0i9Illu5Ky1/F3CvpNnAhyQzXoiIaZIeBKYDy4HzIqLBpWcc1M3MaLqldyNiMrDfWvLnspbZKxGxlOQFRGs713XAdY2p30HdzIz8rKfuoG5mhtd+MTPLlbys/eKgbmZG5b/8olgO6mZmQF1OBmAc1M3M8I1SM7NcyUc/3UHdzAxwT93MLFeWKx99dQd1MzM8/GJmlisefjEzyxFPaTQzy5F8hHQHdTMzwMMvZma5UpuTvrqDupkZ7qmbmeVKuKduZpYf7qlbk9h9987cP3Tgyv1Ou+3C1df8mgXvvMeV//lj9ujahYMPOYGJL09eWeZLX9qDgbffwBZbtqauro7uB5/AsmXLytF8a2Iff7KYq66/hdlz54HEtZdfyL3DHuHNt+YD8MnixWzRujUPDb6dmpoarrnxNqbNfB01E5de8H0O/Mo+q5zv/J9ezfx33uOR+/5QjsvZqHhKozWJWbPm0O2AXgA0a9aMt96cyCOPPkarVtV84/TvMvD261cpX1VVxeB7bqX/dy5g8uTptG3bhpqamnI03TJw/S1/4NCDunHzdVdQU1PDkqXLuOnay1Ye/9Vtd9J681YADB8xGoCH7x3IBwsXce5F/8kDf/wtzZo1A2DMuGdp1aq69BexkcpHSE/edG0VoudRhzF37jzeemsBM2fOZtasOWuU6XXMEUyZMoPJk6cD8OGHC6mry8svjpu2TxZ/ysRXp3LqSccC0KJFC7bcovXK4xHB6Cf/yvHH9ABgzptvceD+XwZgmzZbs0XrzZk283UAPvtsCUOG/Q/f69e3tBexEVtOFJ0qWaZBXYmzJF2Z7u8iaY23aVvi9NP78MCwR+ot06VLJyJg1MihvPjCaC6+6NwStc6ytuCd92iz9VZccd1vOK3/eVz5y1v4bMnSlccnvjqVbdq0YdedOwDwb1/cjXF/G8/y5bXMf+c9pr82m/f+/g8AbrtzCP36fp2WLVuW5Vo2RtGIP5Us657674GDgTPT/U+A29dVWNIASRMkTair+zTjplWWFi1acNKJvRj+0Mh6yzVvXsWhhxzAt/udzxE9TubkPsdx1JGHlaiVlqXltbXMmDWbM045geH33E51dUvuuvfBlc
dHjRnH8cccsXL/lBOOZbv27TjjnB9yw2/vYN+996BZVTNmzprD2wve5egjDi3HZWy06hqRKlnWQf2giDgPWAoQEQuBzdZVOCIGRUS3iOjWrNnmGTetsvTufSSvvDKF99//Z73l5i94l2f+9gIffLCQJUuW8tjoJ9lvv71L1ErL0vbbtmO79u3YZ6+uAPTqcRjTZ80GYPnyWv7y9HP07nn4yvLNm1dxyQXf46HBt3PbDVfx8eJP6bhzByZNm8G0ma/T69R+nH3uRbz59gL6n//TslzTxsQ99eLUSKoivQchqT2V/4OuLPqecXKDQy8ATzzxNHvv3ZXq6pZUVVVx+Fe7M2PG6yVooWWt3TZt2X7b9rwxL5npMn7iJDp33CXZnvAKnXbdie23bb+y/JKlS1cOzzz34ss0r6qi82670veUE3lqxFCeeGgwQwbeRMedO3DP724s/QVtZPLSU8969sutwMPAtpKuA04Drsi4zo1Oq1bVHN3zcM79wSUr8/r06c1vb/4F7du3ZcSjQ3j11Wkcf+K3WLToI2757SDGPz8quXE2+klGPTa2jK23pnT5hedyyTU3UrO8hp133IFrL78QgMf+8jTHHd1jlbIfLvyI7134M9SsGdu134ZfXnlxGVqcH7VR2T3wYikyvhBJXYGegICxETGjmM8136xDPv6FrUkteeeZcjfBKlCLdp20oef45q6nFB1z7p/38AbXl5VMe+qSbgUeiIh13hw1M6sElT5WXqysx9QnAldImiPp15K6ZVyfmdl6ycuYeqZBPSIGR8TxwAHAa8ANknxXz8wqTh1RdKpkpVom4ItAV2BXoKgxdTOzUsrL8EvWY+o3AqcAc4BhwLURsSjLOs3M1kdeZr9k3VOfAxwcEfU/UWNmVmaVPqxSrEyCuqSuETETeAnYRdIuhccj4uUs6jUzW1+VfgO0WFn11H8MDABuWsuxAI7KqF4zs/XiMfV6RMSAdPO4iFhaeEySl40zs4qTl+GXrOepP1dknplZWUVE0amSZTWmvj3QAaiWtB/JEgEAWwKtsqjTzGxD1Oakp57VmPqxQH9gJ+A3BfmfAJdnVKeZ2XrLy/BLVmPqg4HBkk6NiIeyqMPMrClV+rBKsbIafjkrIu4DOkr68erHI+I3a/mYmVnZuKdevxWvLWpdbykzswrhKY31iIg70r+vyeL8ZmZNLS/LBGQ6pVHSjZK2lNRC0lhJ/5B0VpZ1mpmtj7ys0pj1PPVeEfExcCLwJslqjT/JuE4zs0ZzUC/OiuGdE4A/R8RHGddnZrZemvLhI0l3S3pf0tSCvKslLZA0KU3HFxy7TNJsSa9JOrYgv3eaN1vSpcVcR9ZBfaSkmcD+wFhJ7YGlDXzGzKzkmrinfg/Qey35N0fEvmkaBSBpT6AvsFf6md9LqpJUBdwOHAfsCZyZlq1X1m8+uhQ4BOgWETXAp0CfLOs0M1sf0Yg/DZ4r4q/Ah0VW3YfkXc7LIuINYDZwYJpmR8TciPgX8ABFxM+sb5S2AM4ChkkaDpwDfJBlnWZm66M26opOG+B8SZPT4Zk2aV4H4O2CMvPTvHXl1yvr4ZeBJEMvv0/TV9I8M7OK0pgxdUkDJE0oSAMaroGBQGdgX+Bd1r40+QbL+s1HB0TElwv2n5T0asZ1mpk1WmNmtUTEIGBQY84fEX9fsS3pTmBkursA2Lmg6E5pHvXkr1PWPfVaSZ1X7EjqBNRmXKeZWaM15Zj62kjaoWD3FGDFzJgRQF9JX5C0G9AFeJHkzXFdJO0maTOSm6kjGqon6576T4CnJM1N9zsC38m4TjOzRqtrwidKJf030ANoJ2k+cBXQQ9K+JG9/exP4HkBETJP0IDAdWA6cFxG16XnOBx4HqoC7I2Jag3VnuTJZ+paji4CewCKSnzw3r/42pLVpvlmHyp7hb2Wx5J1nyt0Eq0At2nVSw6Xqt9d2BxUdc6b9/YUNri8rWffUhwAfA9em+98E7gW+kXG9ZmaNsoGzWipG1k
F974gonCz/lKTpGddpZtZoTTn8Uk5Z3yh9WVL3FTuSDgImZFynmVmjZX2jtFSy7qnvDzwn6a10fxfgNUlTgIiIfTKu38ysKHnpqWcd1Ne29oGZWcWp9B54sTIN6hExL8vzm5k1ldrIxyM0WffUzcw2Cn7xtJlZjlT6yy+K5aBuZoZ76mZmueLZL2ZmOeLZL2ZmOeJlAszMcsRj6mZmOeIxdTOzHHFP3cwsRzxP3cwsR9xTNzPLEc9+MTPLEd8oNTPLEQ+/mJnliJ8oNTPLEffUzcxyJC9j6srLT6c8kzQgIgaVux1WWfx9YWvTrNwNsKIMKHcDrCL5+8LW4KBuZpYjDupmZjnioL5x8LiprY2/L2wNvlFqZpYj7qmbmeWIg/pGRtLWkn5QsL+jpOHlbJOVlqTvSzo73e4vaceCY3+UtGf5Wmfl5uGXjYykjsDIiNi7zE2xCiBpHHBxREwod1usMrin3sQkdZQ0Q9KdkqZJekJStaTOkkZLmijpGUld0/KdJY2XNEXSLyQtTvNbSxor6eX0WJ+0iuuBzpImSfpVWt/U9DPjJe1V0JZxkrpJ2lzS3ZJelPRKwbmsxNKv10xJQ9Pvk+GSWknqmX5tpqRfqy+k5a+XNF3SZEm/TvOulnSxpNOAbsDQ9PuhuuBr/n1Jvyqot7+k36XbZ6XfC5Mk3SGpqhz/FpaRiHBqwgR0BJYD+6b7DwJnAWOBLmneQcCT6fZI4Mx0+/vA4nS7ObBlut0OmA0oPf/U1eqbmm5fCFyTbu8AvJZu/xdwVrq9NTAL2Lzc/1abYkq/XgEcmu7fDVwBvA3snuYNAX4EbAO8xue/UW+d/n01Se8cYBzQreD840gCfXtgdkH+Y8BhwB7A/wIt0vzfA2eX+9/FqemSe+rZeCMiJqXbE0n+Ix8C/FnSJOAOkqALcDDw53T7/oJzCPgvSZOBvwAdgO0aqPdB4LR0+3RgxVh7L+DStO5xQEtgl0ZflTWVtyPi2XT7PqAnyffMrDRvMHA48BGwFLhL0teBz4qtICL+AcyV1F3SNkBX4Nm0rv2Bl9Lvh55Apya4JqsQXtArG8sKtmtJgvGiiNi3Eef4Fklva/+IqJH0JkkwXqeIWCDpA0n7AGeQ9Pwh+QFxakS81oj6LTur38haRNIrX7VQxHJJB5IE3tOA84GjGlHPAyQ/3GcCD0dESBIwOCIuW6+WW8VzT700PgbekPQNACW+nB4bD5yabvct+MxWwPtpQD8S2DXN/wTYop66hgE/BbaKiMlp3uPAf6T/oZG034ZekG2QXSQdnG5/E5gAdJT0xTTv28DTklqTfB1HkQytfXnNU9X7/fAw0Ac4kyTAQzIMeJqkbQEktZW06zo+bxshB/XS+RZwjqRXgWkk/9kgGTv9cTrM8kWSX7kBhgLdJE0BzibpbRERHwDPSppaeCOswHCSHw4PFuRdC7QAJkualu5b+bwGnCdpBtAGuBn4Dsnw3BSgDvgDSbAemX5v/A348VrOdQ/whxU3SgsPRMRCYAawa0S8mOZNJxnDfyI97xg+Hwq0HPCUxjKT1ApYkv5q3Jfkpqlnp+SUp6Ra1jymXn77A79Lh0YWAf9e5vaY2UbMPXUzsxzxmLqZWY44qJuZ5YiDuplZjjioW5OTVJtOsZsq6c/pDJ/1PVcPSSPT7a9JurSesqusYNmIOq6WdPH6ttGskjioWxaWRMS+6bS9f/H5k63AyoevGv29FxEjIuL6eopsDTQ6qJvliYO6Ze0Z4Ivp6oSvSRoCTAV2ltRL0vPpSpR/Tp+gRFLvdCXDl4GvrzjRaisNbifpYUmvpukQVlvBMi33E0kvpascXlNwrp9JmiXpb8C/lexfwyxjnqdumZHUHDgOGJ1mdQH6RcR4Se3h1Y0ZAAABiElEQVRInmw8OiI+lXQJyZO1NwJ3kqxxMptk2YO1uRV4OiJOSZeObQ1cCuy9Yo0dSb3SOg8kWf9mhKTDgU9Jnrrdl+T/wMskC6+ZbfQc1C
0L1ekKgJD01O8CdgTmRcT4NL87sCfJkgcAmwHPk6wm+EZEvA4g6T5gwFrqOIpk+QQiohb4SFKb1cr0StMr6X5rkiC/BckCV5+ldYzYoKs1qyAO6paFJauvSJkG7k8Ls4AxEXHmauUas5JlQwT8MiLuWK2OHzVhHWYVxWPqVi7jgUNXrEyo5O1Mu5MsXNZRUue03Jnr+PxY4Nz0s1WStmLNFQsfB/69YKy+Q7o64V+Bk9M3BW0BnNTE12ZWNg7qVhbpSxz6A/+drhb4PNA1IpaSDLf8//RG6fvrOMUFwJHpqoYTgT1XX8EyIp4gefHI82m54cAWEfEyyVj9qyRvBHopsws1KzGv/WJmliPuqZuZ5YiDuplZjjiom5nliIO6mVmOOKibmeWIg7qZWY44qJuZ5YiDuplZjvwfMuY4Mhvw6AMAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "\n", "%matplotlib inline\n", "\n", "with tf.Session() as session:\n", " cm = tf.confusion_matrix(test_sentiments, predictions).eval()\n", "\n", "LABELS = ['negative', 'positive']\n", "sns.heatmap(cm, annot=True, xticklabels=LABELS, yticklabels=LABELS, fmt='g')\n", "xl = plt.xlabel(\"Predicted\")\n", "yl = plt.ylabel(\"Actuals\")" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " negative 0.90 0.90 0.90 7490\n", " positive 0.90 0.90 0.90 7510\n", "\n", "avg / total 0.90 0.90 0.90 15000\n", "\n" ] } ], "source": [ "from sklearn.metrics import classification_report\n", "\n", "print(classification_report(y_true=test_sentiments, y_pred=predictions, target_names=LABELS))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }