{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Question Answering with DeepMatcher\n", "\n", "Note: you can run **[this notebook live in Google Colab](https://colab.research.google.com/github/anhaidgroup/deepmatcher/blob/master/examples/question_answering.ipynb)**.\n", "\n", "DeepMatcher can easily be used for text matching tasks such as Question Answering, Text Entailment, etc. In this tutorial we will see how to use DeepMatcher for Answer Selection, a major sub-task of Question Answering. Specifically, we will look at [WikiQA](https://aclweb.org/anthology/D15-1237), a benchmark dataset for Answer Selection. There are three main steps in this tutorial:\n", "\n", "1. Get data and transform it into DeepMatcher input format\n", "2. Set up and train DeepMatcher model\n", "3. Evaluate model using QA eval metrics\n", "\n", "Before we begin, if you are running this notebook in Colab, you will first need to install the necessary packages by running the code below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "try:\n", "    import deepmatcher\n", "except ImportError:\n", "    !pip install -qqq deepmatcher" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1: Get data and transform it into DeepMatcher input format\n", "\n", "First, let's import the relevant packages and download the dataset:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import deepmatcher as dm\n", "import pandas as pd\n", "import os\n", "\n", "!wget -qnc https://download.microsoft.com/download/E/5/F/E5FCFCEE-7005-4814-853D-DAA7C66507E0/WikiQACorpus.zip\n", "!unzip -qn WikiQACorpus.zip" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see what this dataset looks like:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr>\n", "      <th></th>\n", "      <th>0</th>\n", "      <th>1</th>\n", "      <th>2</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>how are glacier caves formed ?</td>\n", "      <td>A partly submerged glacier cave on Perito More...</td>\n", "      <td>0</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>how are glacier caves formed ?</td>\n", "      <td>The ice facade is approximately 60 m high</td>\n", "      <td>0</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>how are glacier caves formed ?</td>\n", "      <td>Ice formations in the Titlis glacier cave</td>\n", "      <td>0</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>how are glacier caves formed ?</td>\n", "      <td>A glacier cave is a cave formed within the ice...</td>\n", "      <td>1</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>how are glacier caves formed ?</td>\n", "      <td>Glacier caves are often called ice caves , but...</td>\n", "      <td>0</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " 0 \\\n", "0 how are glacier caves formed ? \n", "1 how are glacier caves formed ? \n", "2 how are glacier caves formed ? \n", "3 how are glacier caves formed ? \n", "4 how are glacier caves formed ? \n", "\n", " 1 2 \n", "0 A partly submerged glacier cave on Perito More... 0 \n", "1 The ice facade is approximately 60 m high 0 \n", "2 Ice formations in the Titlis glacier cave 0 \n", "3 A glacier cave is a cave formed within the ice... 1 \n", "4 Glacier caves are often called ice caves , but... 0 " ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_train = pd.read_csv(os.path.join('WikiQACorpus', 'WikiQA-train.txt'), sep='\\t', header=None)\n", "raw_train.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Clearly, this is not the format `deepmatcher` expects: the file has no column names, no ID column, and is tab-separated rather than comma-separated. Let's fix that:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr>\n", "      <th></th>\n", "      <th>left_value</th>\n", "      <th>right_value</th>\n", "      <th>label</th>\n", "    </tr>\n",
"    <tr>\n", "      <th>id</th>\n", "      <th></th>\n", "      <th></th>\n", "      <th></th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>how are glacier caves formed ?</td>\n", "      <td>A partly submerged glacier cave on Perito More...</td>\n", "      <td>0</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>how are glacier caves formed ?</td>\n", "      <td>The ice facade is approximately 60 m high</td>\n", "      <td>0</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>how are glacier caves formed ?</td>\n", "      <td>Ice formations in the Titlis glacier cave</td>\n", "      <td>0</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>how are glacier caves formed ?</td>\n", "      <td>A glacier cave is a cave formed within the ice...</td>\n", "      <td>1</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>how are glacier caves formed ?</td>\n", "      <td>Glacier caves are often called ice caves , but...</td>\n", "      <td>0</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " left_value \\\n", "id \n", "0 how are glacier caves formed ? \n", "1 how are glacier caves formed ? \n", "2 how are glacier caves formed ? \n", "3 how are glacier caves formed ? \n", "4 how are glacier caves formed ? \n", "\n", " right_value label \n", "id \n", "0 A partly submerged glacier cave on Perito More... 0 \n", "1 The ice facade is approximately 60 m high 0 \n", "2 Ice formations in the Titlis glacier cave 0 \n", "3 A glacier cave is a cave formed within the ice... 1 \n", "4 Glacier caves are often called ice caves , but... 0 " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_train.columns = ['left_value', 'right_value', 'label']\n", "raw_train.index.name = 'id'\n", "raw_train.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looks good. Now let's save this to disk and transform the validation and test data in the same way:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "raw_train.to_csv(os.path.join('WikiQACorpus', 'dm_train.csv'))\n", "\n", "# Apply the same column names and ID index to the validation and test splits.\n", "raw_files = ['WikiQA-dev.txt', 'WikiQA-test.txt']\n", "csv_files = ['dm_valid.csv', 'dm_test.csv']\n", "for raw_file, csv_file in zip(raw_files, csv_files):\n", "    raw_data = pd.read_csv(os.path.join('WikiQACorpus', raw_file), sep='\\t', header=None)\n", "    raw_data.columns = ['left_value', 'right_value', 'label']\n", "    raw_data.index.name = 'id'\n", "    raw_data.to_csv(os.path.join('WikiQACorpus', csv_file))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2: Set up and train DeepMatcher model\n", "\n", "Now we are ready to load and process the data for `deepmatcher`:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Rebuilding data cache because: {'One or more data files have been modified.'}\n", "Load time: 6.962715303525329\n", "Vocab time: 14.411666898056865\n", "Metadata time: 4.01532360445708\n", 
"Cache time: 7.646213295869529\n" ] } ], "source": [ "train, validation, test = dm.data.process(\n", " path='WikiQACorpus',\n", " train='dm_train.csv',\n", " validation='dm_valid.csv',\n", " test='dm_test.csv')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we create a `deepmatcher` model and train it. Note that since this is a demo, we do not perform hyperparameter tuning - we simply use the default settings for everything except the `pos_neg_ratio` param. This must be set since there are very few \"positive matches\" (candidates that correctly answer the question) in this dataset. In a real application setting you must tune other model hyperparameters as well to get optimal performance." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "* Number of trainable parameters: 2798703\n", "===> TRAIN Epoch 1 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [██████████████████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:25\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 1 || Run Time: 18.7 | Load Time: 6.6 || F1: 12.55 | Prec: 18.40 | Rec: 9.52 || Ex/s: 803.13\n", "\n", "===> EVAL Epoch 1 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [█████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 1 || Run Time: 1.1 | Load Time: 0.9 || F1: 29.49 | Prec: 21.77 | Rec: 45.71 || Ex/s: 1378.35\n", "\n", "* Best F1: 29.493087557603683\n", "Saving best model...\n", "===> TRAIN Epoch 2 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [██████████████████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:25\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 2 || Run Time: 18.5 | Load Time: 6.7 || F1: 30.13 | Prec: 24.97 | Rec: 37.98 
|| Ex/s: 808.23\n", "\n", "===> EVAL Epoch 2 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [█████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 2 || Run Time: 1.1 | Load Time: 0.9 || F1: 34.00 | Prec: 24.60 | Rec: 55.00 || Ex/s: 1389.86\n", "\n", "* Best F1: 33.99558498896247\n", "Saving best model...\n", "===> TRAIN Epoch 3 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [██████████████████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:25\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 3 || Run Time: 18.6 | Load Time: 6.7 || F1: 40.77 | Prec: 31.73 | Rec: 57.02 || Ex/s: 804.25\n", "\n", "===> EVAL Epoch 3 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [█████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 3 || Run Time: 1.0 | Load Time: 0.9 || F1: 31.07 | Prec: 22.40 | Rec: 50.71 || Ex/s: 1429.85\n", "\n", "===> TRAIN Epoch 4 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [██████████████████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:25\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 4 || Run Time: 18.6 | Load Time: 6.7 || F1: 50.94 | Prec: 39.58 | Rec: 71.44 || Ex/s: 803.71\n", "\n", "===> EVAL Epoch 4 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [█████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 4 || Run Time: 1.0 | Load Time: 0.9 || F1: 34.23 | Prec: 26.02 | Rec: 50.00 || Ex/s: 1427.62\n", "\n", "* Best F1: 34.22982885085574\n", "Saving best model...\n", "===> TRAIN Epoch 5 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% 
[██████████████████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:25\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 5 || Run Time: 18.6 | Load Time: 6.7 || F1: 63.19 | Prec: 50.67 | Rec: 83.94 || Ex/s: 805.32\n", "\n", "===> EVAL Epoch 5 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [█████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 5 || Run Time: 1.0 | Load Time: 0.9 || F1: 35.14 | Prec: 33.33 | Rec: 37.14 || Ex/s: 1430.46\n", "\n", "* Best F1: 35.13513513513514\n", "Saving best model...\n", "===> TRAIN Epoch 6 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [██████████████████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:25\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 6 || Run Time: 18.6 | Load Time: 6.7 || F1: 74.33 | Prec: 62.93 | Rec: 90.77 || Ex/s: 804.23\n", "\n", "===> EVAL Epoch 6 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [█████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 6 || Run Time: 1.1 | Load Time: 0.9 || F1: 34.93 | Prec: 33.55 | Rec: 36.43 || Ex/s: 1388.46\n", "\n", "===> TRAIN Epoch 7 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [██████████████████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:25\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 7 || Run Time: 18.6 | Load Time: 6.7 || F1: 82.96 | Prec: 73.62 | Rec: 95.00 || Ex/s: 805.43\n", "\n", "===> EVAL Epoch 7 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [█████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 7 || Run Time: 1.1 
| Load Time: 0.9 || F1: 32.07 | Prec: 27.09 | Rec: 39.29 || Ex/s: 1386.02\n", "\n", "===> TRAIN Epoch 8 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [██████████████████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:25\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 8 || Run Time: 18.6 | Load Time: 6.7 || F1: 86.39 | Prec: 79.07 | Rec: 95.19 || Ex/s: 804.31\n", "\n", "===> EVAL Epoch 8 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [█████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 8 || Run Time: 1.0 | Load Time: 0.9 || F1: 33.33 | Prec: 30.72 | Rec: 36.43 || Ex/s: 1426.20\n", "\n", "===> TRAIN Epoch 9 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [██████████████████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:25\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 9 || Run Time: 18.7 | Load Time: 6.7 || F1: 91.81 | Prec: 86.68 | Rec: 97.60 || Ex/s: 802.35\n", "\n", "===> EVAL Epoch 9 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [█████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 9 || Run Time: 1.1 | Load Time: 0.9 || F1: 34.88 | Prec: 38.14 | Rec: 32.14 || Ex/s: 1385.16\n", "\n", "===> TRAIN Epoch 10 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [██████████████████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:25\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 10 || Run Time: 18.7 | Load Time: 6.7 || F1: 94.18 | Prec: 90.59 | Rec: 98.08 || Ex/s: 803.29\n", "\n", "===> EVAL Epoch 10 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [█████████████████] 100% | ETA: 00:00:00\n", "Total time 
elapsed: 00:00:01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 10 || Run Time: 1.0 | Load Time: 0.9 || F1: 32.86 | Prec: 32.86 | Rec: 32.86 || Ex/s: 1426.35\n", "\n", "Loading best model...\n" ] } ], "source": [ "model = dm.MatchingModel()\n", "model.run_train(\n", " train,\n", " validation,\n", " epochs=10,\n", " best_save_path='hybrid_model.pth',\n", " pos_neg_ratio=7)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have a trained model, we obtain the predictions for the test data. Note that `deepmatcher` computes F1, precision and recall by default but these may not be optimal evaluation metrics for your end task. For instance, in Question Answering, the more relevant metrics are MAP and MRR which we will compute in the next step." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "===> PREDICT Epoch 5 :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0% [██████████████████████████████] 100% | ETA: 00:00:00\n", "Total time elapsed: 00:00:04\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Finished Epoch 5 || Run Time: 2.5 | Load Time: 1.9 || F1: 28.88 | Prec: 26.50 | Rec: 31.74 || Ex/s: 1381.48\n", "\n" ] } ], "source": [ "predictions = model.run_prediction(test, output_attributes=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3: Evaluate model using QA eval metrics " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we compute the Mean Average Precision (MAP) and Mean Reciprocal Rank (MRR) using the model's predictions on the test set. Following the approach of the [paper that introduced this dataset](https://aclweb.org/anthology/D15-1237), questions in the test set without answers are ignored when computing these metrics." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MAP: 0.6951284386872146\n", "MRR: 0.7099084137865672\n" ] } ], "source": [ "MAP, MRR = 0, 0\n", "\n", "# Each group holds all candidate answers for one question.\n", "grouped = predictions.groupby('left_value')\n", "num_questions = 0\n", "for question, answers in grouped:\n", "    # Rank candidates from most to least likely match.\n", "    sorted_answers = answers.sort_values('match_score', ascending=False)\n", "    \n", "    p, ap = 0, 0\n", "    top_answer_found = False\n", "    for idx, answer in enumerate(sorted_answers.itertuples()):\n", "        if answer.label == 1:\n", "            if not top_answer_found:\n", "                # Reciprocal rank of the first correct answer.\n", "                MRR += 1 / (idx + 1)\n", "                top_answer_found = True\n", "            p += 1\n", "            ap += p / (idx + 1)\n", "    \n", "    # Questions without any correct answer are ignored.\n", "    if p > 0:\n", "        ap /= p\n", "        num_questions += 1\n", "    MAP += ap\n", "    \n", "MAP /= num_questions\n", "MRR /= num_questions\n", "\n", "print('MAP:', MAP)\n", "print('MRR:', MRR)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "default_view": {}, "name": "question_answering.ipynb", "provenance": [], "version": "0.3.2", "views": {} }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }