{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"  # make only GPU 0 visible" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Sentence Pair Classification with *ktrain*\n", "\n", "This notebook demonstrates sentence pair classification with *ktrain*: each input is a pair of sentences, and the model assigns a single label to the pair.\n", "\n", "## Download a Sentence Pair Classification Dataset\n", "\n", "In this notebook, we will use the Microsoft Research Paraphrase Corpus (MRPC) to build a model that detects whether two sentences are paraphrases of one another. The MRPC train and test datasets can be downloaded here:\n", "- [MRPC train dataset](https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt)\n", "- [MRPC test dataset](https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt)\n", "\n", "Once the files are downloaded to `data/mrpc/`, we will load them and prepare each dataset as a list of sentence pairs."
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import csv\n", "import pandas as pd\n", "\n", "TRAIN = 'data/mrpc/msr_paraphrase_train.txt'\n", "TEST = 'data/mrpc/msr_paraphrase_test.txt'\n", "\n", "# the files are tab-separated; QUOTE_NONE stops pandas from treating the literal double quotes in the sentences as field quoting\n", "train_df = pd.read_csv(TRAIN, delimiter='\\t', quoting=csv.QUOTE_NONE)\n", "test_df = pd.read_csv(TEST, delimiter='\\t', quoting=csv.QUOTE_NONE)\n", "\n", "# each row holds two sentences ('#1 String', '#2 String') and a binary 'Quality' label (1 = paraphrase, 0 = not a paraphrase)\n", "x_train = train_df[['#1 String', '#2 String']].values\n", "y_train = train_df['Quality'].values\n", "x_test = test_df[['#1 String', '#2 String']].values\n", "y_test = test_df['Quality'].values\n", "\n", "# IMPORTANT: for sentence pair classification, the input must be a list of tuples of the form (str, str)\n", "x_train = list(map(tuple, x_train))\n", "x_test = list(map(tuple, x_test))" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('Amrozi accused his brother , whom he called \" the witness \" , of deliberately distorting his evidence .', 'Referring to him as only \" the witness \" , Amrozi accused his brother of deliberately distorting his evidence .')\n", "1\n" ] } ], "source": [ "print(x_train[0])\n", "print(y_train[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build and Train a `BERT` Model\n", "\n", "We fine-tune `bert-base-uncased` and use the MRPC test set as validation data. For demonstration purposes, we train for only 3 epochs."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "preprocessing train...\n", "language: en\n" ] }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "preprocessing test...\n", "language: en\n" ] }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using onecycle policy with max lr of 5e-05...\n", "Train for 128 steps, validate for 54 steps\n", "Epoch 1/3\n", "128/128 [==============================] - 66s 518ms/step - loss: 0.5913 - accuracy: 0.6796 - val_loss: 0.5731 - val_accuracy: 0.7328\n", "Epoch 2/3\n", "128/128 [==============================] - 50s 390ms/step - loss: 0.3982 - accuracy: 0.8182 - val_loss: 0.4072 - val_accuracy: 0.8354\n", "Epoch 3/3\n", "128/128 [==============================] - 50s 390ms/step - loss: 0.1550 - accuracy: 0.9495 - val_loss: 0.4492 - val_accuracy: 0.8504\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import ktrain\n", "from ktrain import text\n", "\n", "MODEL_NAME = 'bert-base-uncased'\n", "# class_names maps label 0 -> 'not paraphrase' and label 1 -> 'paraphrase'\n", "t = text.Transformer(MODEL_NAME, maxlen=128, class_names=['not paraphrase', 'paraphrase'])\n", "trn = t.preprocess_train(x_train, y_train)\n", "val = t.preprocess_test(x_test, y_test)\n", "model = t.get_classifier()\n", "learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=32)  # lower batch_size if you run out of GPU memory\n", "learner.fit_onecycle(5e-5, 3)  # 1cycle policy: max learning rate 5e-5, 3 epochs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Make Predictions\n", "\n", "We wrap the trained model and its preprocessor `t` in a predictor, which accepts raw sentence pairs as input." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, t)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"Let's select a positive and negative example from `x_test`." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 1, 1, 0, 0])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_test[:5]" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "positive = x_test[0]\n", "negative = x_test[4]" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Valid Paraphrase:\n", "(\"PCCW 's chief operating officer , Mike Butcher , and Alex Arena , the chief financial officer , will report directly to Mr So .\", 'Current Chief Operating Officer Mike Butcher and Group Chief Financial Officer Alex Arena will report to So .')\n" ] } ], "source": [ "print('Valid Paraphrase:\\n%s' %(positive,))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Invalid Paraphrase:\n", "(\"The company didn 't detail the costs of the replacement and repairs .\", 'But company officials expect the costs of the replacement work to run into the millions of dollars .')\n" ] } ], "source": [ "print('Invalid Paraphrase:\\n%s' %(negative,))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'paraphrase'" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict(positive)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'not paraphrase'" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"predictor.predict(negative)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "['paraphrase', 'not paraphrase']" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# predict on a batch of sentence pairs\n", "predictor.predict([positive, negative])" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# save the predictor to disk so it can be reloaded later\n", "predictor.save('/tmp/mrpc_model')" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# reload the saved predictor\n", "p = ktrain.load_predictor('/tmp/mrpc_model')" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'paraphrase'" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the reloaded predictor should return the same prediction as before\n", "p.predict(positive)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }