{ "cells": [ { "cell_type": "markdown", "id": "KLqW6FOnEvov", "metadata": { "id": "KLqW6FOnEvov" }, "source": [ "# Training Multilabel Classification Models with Legal NLP" ] }, { "cell_type": "markdown", "id": "wxZDXLDCXkk_", "metadata": { "id": "wxZDXLDCXkk_" }, "source": [ "\n", "![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)\n" ] }, { "cell_type": "markdown", "id": "pZ6sKi8ZX1z4", "metadata": { "id": "pZ6sKi8ZX1z4" }, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/legal-nlp/04.3.Training_Legal_Multilabel_Classifier.ipynb)" ] }, { "cell_type": "markdown", "id": "245d53bb", "metadata": { "id": "245d53bb" }, "source": [ "In this notebook, you will learn how to use Spark NLP and Legal NLP to train multilabel classification models.\n", "\n", "Let's dive in!" ] }, { "cell_type": "markdown", "id": "RPprYrOB-O7L", "metadata": { "id": "RPprYrOB-O7L" }, "source": [ "# Colab Setup" ] }, { "cell_type": "markdown", "id": "gk3kZHmNj51v", "metadata": { "collapsed": false, "id": "gk3kZHmNj51v" }, "source": [ "# Installation" ] }, { "cell_type": "code", "execution_count": null, "id": "_914itZsj51v", "metadata": { "id": "_914itZsj51v", "pycharm": { "is_executing": true } }, "outputs": [], "source": [ "! 
pip install -q johnsnowlabs" ] }, { "cell_type": "markdown", "id": "YPsbAnNoPt0Z", "metadata": { "id": "YPsbAnNoPt0Z" }, "source": [ "## Automatic Installation\n", "Using my.johnsnowlabs.com SSO" ] }, { "cell_type": "code", "execution_count": null, "id": "fY0lcShkj51w", "metadata": { "id": "fY0lcShkj51w", "pycharm": { "is_executing": true } }, "outputs": [], "source": [ "from johnsnowlabs import nlp, legal\n", "\n", "# nlp.install(force_browser=True)" ] }, { "cell_type": "markdown", "id": "hsJvn_WWM2GL", "metadata": { "id": "hsJvn_WWM2GL" }, "source": [ "## Manual downloading\n", "If you are not registered at my.johnsnowlabs.com, received your license via email, or are using Safari, you may need to upload the license manually.\n", "\n", "- Go to my.johnsnowlabs.com\n", "- Download your license\n", "- Upload it using the following command" ] }, { "cell_type": "code", "execution_count": null, "id": "i57QV3-_P2sQ", "metadata": { "id": "i57QV3-_P2sQ" }, "outputs": [], "source": [ "from google.colab import files\n", "print('Please Upload your John Snow Labs License using the button below')\n", "license_keys = files.upload()" ] }, { "cell_type": "markdown", "id": "xGgNdFzZP_hQ", "metadata": { "id": "xGgNdFzZP_hQ" }, "source": [ "- Install it" ] }, { "cell_type": "code", "execution_count": null, "id": "OfmmPqknP4rR", "metadata": { "id": "OfmmPqknP4rR" }, "outputs": [], "source": [ "nlp.install()" ] }, { "cell_type": "markdown", "id": "DCl5ErZkNNLk", "metadata": { "id": "DCl5ErZkNNLk" }, "source": [ "# Starting" ] }, { "cell_type": "code", "execution_count": null, "id": "wRXTnNl3j51w", "metadata": { "id": "wRXTnNl3j51w" }, "outputs": [], "source": [ "spark = nlp.start()" ] }, { "cell_type": "markdown", "id": "e9d3bac3", "metadata": { "id": "e9d3bac3" }, "source": [ "## Introduction" ] }, { "cell_type": "markdown", "id": "fd9e23a6", "metadata": { "id": "fd9e23a6" }, "source": [ "Spark NLP provides two annotators for text classification tasks; this notebook trains the first one:\n", "\n", "- 
`MultiClassifierDL`: `Multilabel Classification` (can predict more than one class for each text) using a Bidirectional GRU with Convolution architecture built with TensorFlow that supports up to 100 classes. The inputs are Sentence Embeddings such as the state-of-the-art UniversalSentenceEncoder, BertSentenceEmbeddings or SentenceEmbeddings.\n", "- `ClassifierDL`: uses the state-of-the-art Universal Sentence Encoder as input to a deep learning model (DNN) built with TensorFlow that supports `Binary Classification` and `Multiclass Classification` (up to 100 classes)." ] }, { "cell_type": "markdown", "id": "BhKlxMBWH9-Z", "metadata": { "id": "BhKlxMBWH9-Z" }, "source": [ "## Loading the data" ] }, { "cell_type": "code", "execution_count": null, "id": "2Ab4pyK6MyCZ", "metadata": { "id": "2Ab4pyK6MyCZ" }, "outputs": [], "source": [ "! wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/legal-nlp/data/finance_data.csv" ] }, { "cell_type": "code", "execution_count": null, "id": "gaRulasEKZK3", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gaRulasEKZK3", "outputId": "cbb0c070-44c9-4c41-aa66-2f8ec08cea89" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape of the full dataset: (27527, 2)\n" ] } ], "source": [ "import ast\n", "\n", "import pandas as pd\n", "\n", "df = pd.read_csv('./finance_data.csv')\n", "# The label column holds stringified lists; parse them without executing arbitrary code\n", "df['label'] = df['label'].apply(ast.literal_eval)\n", "print(f\"Shape of the full dataset: {df.shape}\")" ] }, { "cell_type": "markdown", "id": "jMRVLpGDq3zV", "metadata": { "id": "jMRVLpGDq3zV" }, "source": [ "> We will use a sample from this dataset to make the training process faster, since the goal here is only to illustrate the workflow. Use the full dataset if you want to experiment with it and achieve more realistic results. \n", ">\n", "> The sample contains only 1000 observations; please keep in mind that this will impact the accuracy and generalization capabilities of the model. 
Since the sample is small, we use about 70% of it to train the model and the remaining 30% for testing." ] }, { "cell_type": "code", "execution_count": null, "id": "bUFO8usqKt2Z", "metadata": { "id": "bUFO8usqKt2Z" }, "outputs": [], "source": [ "data = spark.createDataFrame(df)\n", "\n", "# With a single dataset, split it into train and test sets; if you already have a separate test set, load it the same way as the training data.\n", "# Keep the first 1000 rows, then split them roughly 70/30 (randomSplit weights are approximate).\n", "train, test = data.limit(1000).randomSplit([0.7, 0.3], seed=42)" ] }, { "cell_type": "code", "execution_count": null, "id": "bowhz89ZbN1C", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bowhz89ZbN1C", "outputId": "dd14b7f6-290d-4f43-e45d-2f528f0dcb54" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------------------------------------+-----------------------------------+\n", "| provision| label|\n", "+--------------------------------------------------+-----------------------------------+\n", "|(a) No failure or delay of the Administrative A...| [waivers, amendments]|\n", "|(a) Seller, the Agent, each Managing Agent, eac...| [assignments]|\n", "|(a) To induce the other parties hereto to enter...| [representations, warranties]|\n", "|(a)  The provisions of this Agreement shall be ...| [assigns, successors]|\n", "|(a) All of the representations and warranties m...| [representations, warranties]|\n", "|(a) THIS AGREEMENT AND ANY CLAIM, CONTROVERSY, ...|[governing laws, entire agreements]|\n", "|All Bank Expenses (including reasonable attorne...| [expenses]|\n", "|All agreements, covenants, representations, war...| [terminations]|\n", "|All agreements, representations and warranties ...| [survival]|\n", "|All communications hereunder will be in writing...| [notices]|\n", "|All costs and expenses incurred in connection w...| [expenses]|\n", "|All covenants of the Company contained in this ...| [survival]|\n", "|All covenants, agreements, representations and ...| [survival]|\n", "|All covenants, 
agreements, representations and ...| [survival]|\n", "|All covenants, agreements, representations and ...| [survival]|\n", "|All demands, notices and communications hereund...| [notices]|\n", "|All demands, notices and communications hereund...| [notices]|\n", "|All indemnities set forth herein including, wit...| [survival]|\n", "|All non-competition, non-solicitation, non-disc...| [survival]|\n", "|All notices and other communications given or m...| [notices]|\n", "+--------------------------------------------------+-----------------------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "train.show(truncate=50)" ] }, { "cell_type": "code", "execution_count": null, "id": "RZ0xy9hpebLe", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "RZ0xy9hpebLe", "outputId": "d0642fce-1726-42ae-a4c3-257d1e7b00d1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----+\n", "| label|count|\n", "+--------------------+-----+\n", "| [governing laws]| 35|\n", "| [notices]| 31|\n", "| [severability]| 27|\n", "| [entire agreements]| 27|\n", "| [counterparts]| 24|\n", "| [survival]| 19|\n", "|[assigns, success...| 14|\n", "| [terminations]| 14|\n", "| [amendments]| 13|\n", "| [expenses]| 11|\n", "| [assignments]| 10|\n", "|[waivers, amendme...| 8|\n", "| [waivers]| 7|\n", "|[amendments, enti...| 3|\n", "| [representations]| 3|\n", "| [successors]| 2|\n", "|[amendments, term...| 2|\n", "|[representations,...| 2|\n", "| [warranties]| 1|\n", "|[severability, su...| 1|\n", "+--------------------+-----+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "from pyspark.sql.functions import col\n", "\n", "test.groupBy(\"label\").count().orderBy(col(\"count\").desc()).show()\n" ] }, { "cell_type": "markdown", "id": "78sjKvMKcHdn", "metadata": { "id": "78sjKvMKcHdn" }, "source": [ "## With Universal Sentence Encoder" ] }, { "cell_type": "code", "execution_count": null, "id": "oFl9orR5LrUu", "metadata": { 
"colab": { "base_uri": "https://localhost:8080/" }, "id": "oFl9orR5LrUu", "outputId": "eee1ba98-31db-4985-a658-c5cd5c695859" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfhub_use download started this may take some time.\n", "Approximate size to download 923.7 MB\n", "[OK!]\n" ] } ], "source": [ "# Turn the raw provision text into Spark NLP documents\n", "document_assembler = (\n", "    nlp.DocumentAssembler()\n", "    .setInputCol(\"provision\")\n", "    .setOutputCol(\"document\")\n", "    .setCleanupMode(\"shrink\")\n", ")\n", "\n", "# Encode each document as a sentence embedding\n", "embeddings = (\n", "    nlp.UniversalSentenceEncoder.pretrained()\n", "    .setInputCols(\"document\")\n", "    .setOutputCol(\"sentence_embeddings\")\n", ")\n", "\n", "# Trainable multilabel classifier on top of the embeddings\n", "classifierdl = (\n", "    nlp.MultiClassifierDLApproach()\n", "    .setInputCols([\"sentence_embeddings\"])\n", "    .setOutputCol(\"class\")\n", "    .setLabelColumn(\"label\")\n", "    .setMaxEpochs(20)\n", "    .setLr(0.001)\n", "    .setRandomSeed(42)\n", "    .setEnableOutputLogs(True)\n", "    .setOutputLogsPath(\"multilabel_use_logs\")\n", "    .setBatchSize(8)\n", ")\n", "\n", "clf_pipeline = nlp.Pipeline(stages=[document_assembler, embeddings, classifierdl])" ] }, { "cell_type": "markdown", "id": "19d62bec", "metadata": { "id": "19d62bec" }, "source": [ "Since this model can take a long time to train, we limit the size of the training data to avoid having it train for hours. 
\n", "\n", "> Please note that this reduction can greatly impact the performance of the model" ] }, { "cell_type": "code", "execution_count": null, "id": "8kcgVr07MUss", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8kcgVr07MUss", "outputId": "a2bab13d-d887-4241-a281-2cf7f9f6b8d1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 494 ms, sys: 67.1 ms, total: 561 ms\n", "Wall time: 1min 26s\n" ] } ], "source": [ "%%time\n", "clf_pipelineModel = clf_pipeline.fit(train)" ] }, { "cell_type": "code", "execution_count": null, "id": "p88hMraHOac9", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "p88hMraHOac9", "outputId": "4712855b-091b-442f-b3b3-1b796463fa71" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training started - epochs: 20 - learning_rate: 0.001 - batch_size: 8 - training_examples: 744 - classes: 15\n", "Epoch 0/20 - 5.90s - loss: 0.31367278 - acc: 0.91523325 - batches: 93\n", "Epoch 1/20 - 2.32s - loss: 0.20648386 - acc: 0.93324363 - batches: 93\n", "Epoch 2/20 - 1.74s - loss: 0.15775694 - acc: 0.9456988 - batches: 93\n", "Epoch 3/20 - 1.76s - loss: 0.13085833 - acc: 0.9548385 - batches: 93\n", "Epoch 4/20 - 1.72s - loss: 0.11435161 - acc: 0.9614694 - batches: 93\n", "Epoch 5/20 - 1.71s - loss: 0.1033926 - acc: 0.965412 - batches: 93\n", "Epoch 6/20 - 1.67s - loss: 0.09538201 - acc: 0.96827936 - batches: 93\n", "Epoch 7/20 - 1.69s - loss: 0.08915223 - acc: 0.9700714 - batches: 93\n", "Epoch 8/20 - 1.72s - loss: 0.08416093 - acc: 0.9717739 - batches: 93\n", "Epoch 9/20 - 1.67s - loss: 0.08005884 - acc: 0.9731181 - batches: 93\n", "Epoch 10/20 - 1.68s - loss: 0.07660815 - acc: 0.9741037 - batches: 93\n", "Epoch 11/20 - 1.66s - loss: 0.07365137 - acc: 0.9750894 - batches: 93\n", "Epoch 12/20 - 1.67s - loss: 0.071067244 - acc: 0.9752686 - batches: 93\n", "Epoch 13/20 - 1.71s - loss: 0.0687695 - acc: 0.97598547 - batches: 93\n", "Epoch 14/20 - 1.70s 
- loss: 0.066702016 - acc: 0.97697115 - batches: 93\n", "Epoch 15/20 - 1.69s - loss: 0.06482201 - acc: 0.97732955 - batches: 93\n", "Epoch 16/20 - 1.66s - loss: 0.063097924 - acc: 0.9780463 - batches: 93\n", "Epoch 17/20 - 2.00s - loss: 0.061507963 - acc: 0.9785841 - batches: 93\n", "Epoch 18/20 - 1.69s - loss: 0.060031768 - acc: 0.978853 - batches: 93\n", "Epoch 19/20 - 1.73s - loss: 0.05865668 - acc: 0.9790321 - batches: 93\n", "\n" ] } ], "source": [ "import os\n", "log_file_name = os.listdir(\"multilabel_use_logs\")[0]\n", "\n", "with open(\"multilabel_use_logs/\"+log_file_name, \"r\") as log_file :\n", " print(log_file.read())" ] }, { "cell_type": "code", "execution_count": null, "id": "cwf0OM2Oy6Hd", "metadata": { "id": "cwf0OM2Oy6Hd" }, "outputs": [], "source": [ "preds = clf_pipelineModel.transform(test)" ] }, { "cell_type": "code", "execution_count": null, "id": "jUKamU-50kZi", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "jUKamU-50kZi", "outputId": "365767a2-02b9-4b4f-ace2-bc2e14f44562" }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
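<table>
<thead>
<tr><th></th><th>label</th><th>provision</th><th>result</th></tr>
</thead>
<tbody>
<tr><th>0</th><td>[assigns, successors]</td><td>(a) The provisions of this Agreement shall be ...</td><td>[successors]</td></tr>
<tr><th>1</th><td>[waivers]</td><td>(a) Any provision of this Agreement may be wai...</td><td>[waivers, amendments]</td></tr>
<tr><th>2</th><td>[waivers, amendments]</td><td>(a) This Agreement may be amended, supplemente...</td><td>[waivers]</td></tr>
<tr><th>3</th><td>[counterparts]</td><td>(a) This Agreement may be executed by one or m...</td><td>[counterparts]</td></tr>
<tr><th>4</th><td>[survival]</td><td>All agreements, representations and warranties...</td><td>[survival]</td></tr>
</tbody>
</table>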
\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ], "text/plain": [ " label provision \\\n", "0 [assigns, successors] (a) The provisions of this Agreement shall be ... \n", "1 [waivers] (a) Any provision of this Agreement may be wai... \n", "2 [waivers, amendments] (a) This Agreement may be amended, supplemente... \n", "3 [counterparts] (a) This Agreement may be executed by one or m... \n", "4 [survival] All agreements, representations and warranties... \n", "\n", " result \n", "0 [successors] \n", "1 [waivers, amendments] \n", "2 [waivers] \n", "3 [counterparts] \n", "4 [survival] " ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds_df = preds.select('label','provision',\"class.result\").toPandas()\n", "preds_df.head()" ] }, { "cell_type": "code", "execution_count": null, "id": "GGIRz-xtaY82", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "GGIRz-xtaY82", "outputId": "950594fb-5da9-4923-e3e6-1fb7b8f9c4b8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Classification report: \n", " precision recall f1-score support\n", "\n", " 0 0.85 0.42 0.56 26\n", " 1 0.00 0.00 0.00 10\n", " 2 0.88 0.50 0.64 14\n", " 3 1.00 1.00 1.00 24\n", " 4 0.97 0.97 0.97 30\n", " 5 0.86 0.55 0.67 11\n", " 6 0.97 0.92 0.94 36\n", " 7 0.93 0.81 0.86 31\n", " 8 0.60 0.60 0.60 5\n", " 9 0.93 0.93 0.93 30\n", " 10 0.91 0.62 0.74 16\n", " 11 0.79 0.55 0.65 20\n", " 12 0.88 0.44 0.58 16\n", " 13 1.00 0.69 0.81 16\n", " 14 0.40 0.67 0.50 3\n", "\n", " micro avg 0.91 0.72 0.80 288\n", " macro avg 0.80 0.64 0.70 288\n", "weighted avg 0.88 0.72 0.78 288\n", " samples avg 0.73 0.74 0.73 288\n", "\n", "F1 micro averaging: 0.8038834951456311\n", "ROC: 0.8565596846846847\n" ] } ], "source": [ "from sklearn.preprocessing import MultiLabelBinarizer\n", "from sklearn.metrics import classification_report\n", "from sklearn.metrics import f1_score\n", "from sklearn.metrics import roc_auc_score\n", "\n", "mlb = MultiLabelBinarizer()\n", "\n", "y_true = 
mlb.fit_transform(preds_df['label'])\n", "y_pred = mlb.transform(preds_df['result'])\n", "\n", "\n", "print(\"Classification report: \\n\", (classification_report(y_true, y_pred)))\n", "print(\"F1 micro averaging:\",(f1_score(y_true, y_pred, average='micro')))\n", "print(\"ROC: \",(roc_auc_score(y_true, y_pred, average=\"micro\")))\n" ] } ], "metadata": { "colab": { "provenance": [] }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.9 (tags/v3.10.9:1dd9be6, Dec 6 2022, 20:01:21) [MSC v.1934 64 bit (AMD64)]" }, "vscode": { "interpreter": { "hash": "eaaf1efa59a88a5a39769af8d2cb09ad2646375b6f26a94367c2f840c23a5e60" } } }, "nbformat": 4, "nbformat_minor": 5 }