{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*ktrain* uses TensorFlow 2. To support sequence-tagging, *ktrain* also currently uses the CRF module from `keras_contrib`, which is not yet fully compatible with TensorFlow 2.\n", "To use the BiLSTM-CRF model (which currently requires `keras_contrib`) for sequence-tagging in *ktrain*, you must disable V2 behavior in TensorFlow 2\n", "by adding the following line to the top of your notebook or script **before** importing *ktrain*:\n", "```python\n", "import os\n", "os.environ['DISABLE_V2_BEHAVIOR'] = '1'\n", "```\n", "Since we are employing a CRF layer in this notebook, we will set this value here:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "os.environ['DISABLE_V2_BEHAVIOR'] = '1'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using DISABLE_V2_BEHAVIOR with TensorFlow\n" ] } ], "source": [ "import ktrain\n", "from ktrain import text" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Sequence Tagging\n", "\n", "Sequence tagging (or sequence labeling) involves classifying words or sequences of words as representing some category or concept of interest. One example of sequence tagging is Named Entity Recognition (NER), where we classify words or sequences of words that identify some entity such as a person, organization, or location. 
In this tutorial, we will show how to use *ktrain* to perform sequence tagging in three simple steps.\n", "\n", "## STEP 1: Load and Preprocess Data\n", "\n", "The `entities_from_txt` function can be used to load tagged sentences from a text file. The text file can be in one of two different formats: 1) the [CoNLL2003 format](https://www.aclweb.org/anthology/W03-0419) or 2) the [Groningen Meaning Bank (GMB) format](https://www.kaggle.com/abhinavwalia95/entity-annotated-corpus). In both formats, there is one word and its associated tag on each line (where the word and tag are delimited by a space, tab or comma). Words are ordered as they appear in the sentence. In the CoNLL2003 format, a blank line separates sentences. In the GMB format, there is a third column for Sentence ID that assigns a number to each row indicating the sentence to which the word belongs. If you are building a sequence tagger for your own use case with the `entities_from_txt` function, the training data should be formatted into one of these two formats. Alternatively, one can use the `entities_from_array` function, which simply expects arrays of the following form:\n", "```python\n", "x_train = [['Hello', 'world', '!'], ['Hello', 'Barack', 'Obama'], ['I', 'love', 'Chicago']]\n", "y_train = [['O', 'O', 'O'], ['O', 'B-PER', 'I-PER'], ['O', 'O', 'B-LOC']]\n", "```\n", "Note that the tags in this example follow the [IOB2 format](https://en.wikipedia.org/wiki/Inside%E2%80%93outside%E2%80%93beginning_(tagging)).\n", "\n", "In this notebook, we will use `entities_from_txt` to build a sequence tagger from the Groningen Meaning Bank NER dataset available on Kaggle [here](https://www.kaggle.com/abhinavwalia95/entity-annotated-corpus). The format essentially looks like this (with fields being delimited by comma):\n", "```\n", " SentenceID Word Tag \n", " 1 Paul B-PER\n", " 1 Newman I-PER\n", " 1 is O\n", " 1 a O\n", " 1 great O\n", " 1 actor O\n", " 1 . 
O\n", " ```\n", "\n", "We will be using the file `ner_dataset.csv` (which conforms to the format above) and will load and preprocess it using the `entities_from_txt` function. The output is similar to data-loading functions used in previous tutorials and includes the processed training set, processed validation set, and an instance of `NERPreprocessor`. \n", "\n", "In the Kaggle dataset `ner_dataset.csv`, the three columns of interest (mentioned above) are labeled 'Sentence #', 'Word', and 'Tag'. Thus, we specify these in the call to the function." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "detected encoding: WINDOWS-1250 (if wrong, set manually)\n", "Number of sentences: 47959\n", "Number of words in the dataset: 35178\n", "Tags: ['B-art', 'I-art', 'I-eve', 'B-geo', 'B-gpe', 'I-per', 'O', 'B-tim', 'I-gpe', 'B-nat', 'B-eve', 'B-org', 'I-nat', 'B-per', 'I-org', 'I-tim', 'I-geo']\n", "Number of Labels: 17\n", "Longest sentence: 104 words\n" ] } ], "source": [ "DATAFILE = '/home/amaiya/data/groningen_meaning_bank/ner_dataset.csv'\n", "(trn, val, preproc) = text.entities_from_txt(DATAFILE,\n", " sentence_column='Sentence #',\n", " word_column='Word',\n", " tag_column='Tag', \n", " data_format='gmb',\n", " use_char=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When loading the dataset above, we specify `use_char=True` to instruct *ktrain* to extract the character vocabulary to be used in a character embedding layer of a model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 2: Define a Model\n", "\n", "The `print_sequence_taggers` function shows that, as of this writing, *ktrain* supports both Bidirectional LSTM-CRF and Bidirectional LSTM as base models for sequence tagging. 
These base models can be used with different embedding schemes.\n", "\n", "For instance, the `bilstm-bert` model employs [BERT word embeddings](https://arxiv.org/abs/1810.04805) as features for a Bidirectional LSTM. See [this notebook](https://github.com/amaiya/ktrain/blob/master/examples/text/CoNLL2002_Dutch-BiLSTM.ipynb) for an example of `bilstm-bert`. In this tutorial, we will use a Bidirectional LSTM model with a CRF layer. " ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "bilstm: Bidirectional LSTM (https://arxiv.org/abs/1603.01360)\n", "bilstm-bert: Bidirectional LSTM w/ BERT embeddings\n", "bilstm-crf: Bidirectional LSTM-CRF (https://arxiv.org/abs/1603.01360)\n", "bilstm-elmo: Bidirectional LSTM w/ Elmo embeddings [English only]\n", "bilstm-crf-elmo: Bidirectional LSTM-CRF w/ Elmo embeddings [English only]\n" ] } ], "source": [ "text.print_sequence_taggers()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Embedding schemes employed (combined with concatenation):\n", "\tword embeddings initialized with fasttext word vectors (cc.en.300.vec.gz)\n", "\tcharacter embeddings\n", "\n", "pretrained word embeddings will be loaded from:\n", "\thttps://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.vec.gz\n", "loading pretrained word vectors...this may take a few moments...\n" ] }, { "data": { "text/html": [ "done." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "WV_URL = 'https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.vec.gz'\n", "model = text.sequence_tagger('bilstm-crf', preproc, wv_path_or_url=WV_URL)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the cell above, notice that we supplied the `wv_path_or_url` argument. 
This directs *ktrain* to initialize word embeddings with one of the pretrained fastText (word2vec) word vector sets from [Facebook's fastText site](https://fasttext.cc/docs/en/crawl-vectors.html). When supplied with a valid URL to a `.vec.gz` file, the word vectors will be automatically downloaded, extracted, and loaded in STEP 2 (download location is `/ktrain_data`). To disable pretrained word embeddings, set `wv_path_or_url=None` and randomly initialized word embeddings will be employed. Use of pretrained embeddings will typically boost final accuracy. When used in combination with a model that uses an embedding scheme like BERT (e.g., `bilstm-bert`), the different word embeddings are stacked together using concatenation.\n", "\n", "Finally, we will wrap our selected model and datasets in a `Learner` object to facilitate training." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=128)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 3: Train and Evaluate the Model\n", "\n", "Here, we will train for a single epoch using an initial learning rate of 0.01 with gradual decay using cosine annealing (via the `cycle_len=1` parameter) and see how well we do. The learning rate of `0.01` was determined with the learning-rate finder (i.e., `lr_find`)." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... 
this may take a few moments...\n", "Train for 337 steps\n", "Epoch 1/1024\n", "337/337 [==============================] - 144s 426ms/step - loss: 1.2752\n", "Epoch 2/1024\n", "337/337 [==============================] - 138s 408ms/step - loss: 0.6956\n", "Epoch 3/1024\n", "337/337 [==============================] - 137s 407ms/step - loss: 0.2069\n", "Epoch 4/1024\n", "337/337 [==============================] - 136s 405ms/step - loss: 0.0684\n", "Epoch 5/1024\n", "160/337 [=============>................] - ETA: 1:12 - loss: 0.1804\n", "\n", "done.\n", "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n" ] } ], "source": [ "learner.lr_find()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3hUdd7+8fdnUgkJNaGDoQpIUYlIURHbYllxFbuuKPa66u5v3fVx12fX3XXV3X10rdh1bVjWBooNbNQA0ltAhCCShJoQ0r+/P2bAISQhwZycTOZ+XddcnDYzd8Y4d86cOd9jzjlERCR6BfwOICIi/lIRiIhEORWBiEiUUxGIiEQ5FYGISJRTEYiIRLlYvwPUVWpqqktPT/c7hohIRJk3b16ecy6tqnURVwTp6elkZmb6HUNEJKKY2XfVrdNHQyIiUU5FICIS5VQEIiJRTkUgIhLlVAQiIlFORSAiEuWivgg2bt/NloJiv2OIiPgm6otg5L2fMexvn1a5rqC4jG27Sur0eG/Nz6bfXR/y1eq8+ognIuK5qC6CotJyAErLHc/PWMe877btXffW/GwG/HEqx903bZ/7bCkoZtn3O9lSUMz9U1eQm1/Mnov7PP75Gm6btJDdpeVc8vRsRvztU37YUcQ732xkUuaGhvvBRETqIOLOLK5P909duXf6j+8uBWDunSeRlpKwd11+cRnLN+2kuKyC3u2SOfqvn1JW4bhmVA+e+Hwtj0xbs89jdmqZyD/PP5wLJs7i+x1F++xtHNmtFVOXbuacI7vQoWUiACVlFbw+bwNHd2/LSf/8HICh3dvwylXDiAmYpz+/iAiARdqlKjMyMlx9DDExedEmbnh5/n7L42MD3P3zw/j9fxdzfkZXXqvjX/Jzfn8i7Vok8uXqXG6btJDc/P2PPyTGBXji0gz+8M4SvttSWOXjvHX9CF6evZ5hPdoybkiXOmUQEanMzOY55zKqWhc1Hw1tLywh/Y7JTFuZQ25+8d4S6J7anI6hv84h+Bf67/+7GIAbRvfi2N6p1T7mdcf3ZGSvtjx/xVAObZ/C2MM70a5F8LGO7Z3G1789gSGHtKZvhxTiYn78676otILLnpmzXwn07ZDCL4cfAsDZj8
7gjXnZ/Pr1hVRURFZZi0hkiZqPhtbkFgBw+bNzObFvu73Ln7h0CKnJCXydlcdNryzYu3z8iHS6tU3ihSuGcu8HK1j+Qz5frMrljWuH8+XqPN5b+D1XHduDNs3jARjVZ/9B/eJjA7x53QgAnHOs21JIanI8L89ez2crcjjlsA6s3pzPiF6pfJu7i18OP4RWSXG8MHPfsaF6/H4KH996HL3bp9T76yIiEjUfDWWu28q4x2fut3zdvafvMz9tRQ6XPzeXL//faLq2STronD9FYUkZFz45m2uP68F1LwX3XG4+oRe3nXKoL3lEJPLV9NFQ1OwR5BeV7bfsjWuH77dsdN92+5VDQ0uKj+WdG0YC8D+n9+Oeyct5de4GFmbvYFH2dt654Ri6tfWnpESk6YmaYwQ7i0r3mR8/Ip2M9DY+pam9K4/tQXxMgJz8Yj5flcu2wlKOu38a6XdM5u0FG/duV67jCCJykKLmo6HisnI2bC2ktNwREzC6pzYnLiYyejBz3VZum7SQbYUltGwWR/a23XvX/fmsASz7fgevzNnAuUO6cP+5g/euW5S9neSEWD5bkcNlI9Ij5ucVkfpX00dDUVMETYFzjh27Szn/iVk0T4hh/vrt+22TFB9D66R4Nm7fvd+6t64fweAurXR+gkgUUhE0QQXFZfxn1nes31rIko07uOrYHvt862mP+JgAJeUVe+fHDenCA2F7DSISHXSwuAlKTojl2lE9986XlFVw8wm92LBtNwEzjuuTypmDO2EW/Ot/2F8/5YedRbwxL5vOrZpx68l9/IouIo2M9giiRHmFY/rKHCY8H3zt/P5mlIg0LO0RCDEB48R+7enUMpG4WB00FpEf6R0hyvxsQAfywkZMFRHxrAjM7BkzyzGzJdWsv9jMFpnZYjObYWY6gtkAurVJYldJOf+Z9d2BNxaRqODlHsFzwJga1n8LjHLODQT+DEz0MIuEjD28MzEB4653lu69HoOIRDfPisA59wWwtYb1M5xze64EMwvQWMsNoE3zeO7+eX8A+t71IbuK9x96Q0SiS2M5RjAB+MDvENHikmGH7D2p7PNVufsNvyEi0cX3IjCz0QSL4Lc1bHO1mWWaWWZubm7DhWuizIwFfzgZgOtfms+guz/i27xdPqcSEb/4WgRmNgh4ChjrnNtS3XbOuYnOuQznXEZa2v7j/kvdtUiMY3iPtnvnRz8wnXUqA5Go5FsRmFk34C3gUufcKr9yRLNnLz+KV68expBDWgNw/APTufzZObwyZ73PyUSkIXl2ZrGZvQIcD6QCm4E/AnEAzrnHzewp4Bxgz/cYy6o76y2cziz2xoOfrOZfn/zYx786qTc3ndBbA9SJNBEadE5qZd5323hs+ho+Wb4ZgPMzuvL3cYN8TiUi9UEXr5daGXJIa566LIP7Q2/+r2VuYPXmfJ9TiYjXVASyn3MzuvKv84Mnei/euMPnNCLiNRWBVOnUAR1JiA2Q+d22A28sIhFNRSBVSoyL4YxBnXh59nrmr1cZiDRlKgKp1tlHdgZg3GMzdPaxSBOmIpBqjeyVypXHdKfCwaC7P2L9lkK/I4mIB1QEUqPfndZv7/T4Z+ewvbDExzQi4gUVgdQoJmDMufNExh7eibV5uzj9oa9Yom8SiTQpKgI5oHYpidw3bhDnZ3Rl4/bdnPHvr8jXMQORJkNFILWSEBvD38cN4pJh3QB455vvfU4kIvVFRSB18oczDqN9iwTe+Waj31FEpJ6oCKRO4mMDnH9UN+au20b6HZNZk1vgdyQR+YlUBFJnN4zuycDOLQF4ceZ3B9haRBo7FYHUWUJsDO/ddAyjD03jsxU5RNoItiKyLxWBHLRTB3Rk/dZCPluR43cUEfkJVARy0E4f1JHm8TH8/cMV+jqpSARTEchBa54Qy/Wje7FqcwGXPDXb7zgicpBUBPKT3DC6Fyf1a8/C7B1k5egbRCKRSEUgP9m95wwkITawzzWPRSRyqAjkJ0tNTmD8iHQmL9rE8k
07/Y4jInWkIpB6ccUx3UmMC/Dvz1b7HUVE6khFIPWifYtEfjk8nSmLf+CzFZv9jiMidaAikHoz4ZjuAFzxXCZFpeU+pxGR2vKsCMzsGTPLMbMl1aw3M3vIzLLMbJGZHelVFmkY7Vsk8vPBnQA48+GvdBEbkQjh5R7Bc8CYGtafCvQO3a4GHvMwizSQhy44nN+O6cuqzQXcPmmh33FEpBY8KwLn3BfA1ho2GQu84IJmAa3MrKNXeaRhmBnXHd+Tvh1S+HJ1HsVl+ohIpLHz8xhBZ2BD2Hx2aNl+zOxqM8s0s8zc3NwGCSc/za9O6k1JeQVLv9fXSUUau4g4WOycm+icy3DOZaSlpfkdR2ohI70NMQHjwyU/+B1FRA7AzyLYCHQNm+8SWiZNQGpyAmMGdOCV2evZUlDsdxwRqYGfRfAu8MvQt4eGATucc5t8zCP17NrjepJfXMbHy3RegUhj5uXXR18BZgKHmlm2mU0ws2vN7NrQJlOAtUAW8CRwvVdZxB8DOregc6tmKgKRRi7Wqwd2zl14gPUOuMGr5xf/mRljBnTg6a++ZV3eLtJTm/sdSUSqEBEHiyVyXTrsEAIGT3/1rd9RRKQaKgLxVHpqc04d2JHJizdp2AmRRkpFIJ47L6MrW3eV8OlyXdtYpDFSEYjnRvZsS6ukOD5droPGIo2RikA8FxsTYPSh7Zi2MofyCud3HBGpREUgDeLY3qlsKyxldU6+31FEpBIVgTSII7q1BiBz3Tafk4hIZSoCaRDpbZPo1S6ZSZkbCJ5CIiKNhYpAGoSZcdnwQ1iUvYNF2Tv8jiMiYVQE0mDOPLwzsQHjA41IKtKoqAikwbRsFsewHm35eJmKQKQxURFIgzq5f3vW5O4iK6fA7ygiEqIikAZ1cv/2ALy/6Hufk4jIHioCaVCdWjVjWI82unKZSCOiIpAGd0yvVFb8kE/OziK/o4gIKgLxwWkDOwIwKXODz0lEBFQE4oMeackc3b0Nby3YqJPLRBoBFYH44hdHdGZt7i4Wb9TJZSJ+UxGIL04d2JH4mACTF23yO4pI1FMRiC9aNotjUJeWzFm31e8oIlFPRSC+yUhvw5KNO9hdoktYivhJRSC+GdajDaXljllrt/gdRSSqeVoEZjbGzFaaWZaZ3VHF+m5mNs3MFpjZIjM7zcs80rgM79mW5IRYPtLYQyK+8qwIzCwGeAQ4FegPXGhm/Stt9j/AJOfcEcAFwKNe5ZHGJyE2huMPTeOjpZspLa/wO45I1PJyj2AokOWcW+ucKwFeBcZW2sYBLULTLQENQBNlzhjUkS27SnTlMhEfeVkEnYHwU0ezQ8vC3Q1cYmbZwBTgpqoeyMyuNrNMM8vMzc31Iqv4ZFiPtgDM+07fHhLxi98Hiy8EnnPOdQFOA140s/0yOecmOucynHMZaWlpDR5SvNMqKZ4+7ZOZtVZFIOIXL4tgI9A1bL5LaFm4CcAkAOfcTCARSPUwkzRCJ/Rtz4w1eewsKvU7ikhU8rII5gK9zay7mcUTPBj8bqVt1gMnAphZP4JFoM9+osyoPmlUOJijvQIRX3hWBM65MuBGYCqwnOC3g5aa2Z/M7MzQZrcDV5nZQuAVYLzTKGRR54hurUiIDTBjjc4nEPFDrJcP7pybQvAgcPiyP4RNLwNGeplBGr/EuBgy0lszY02e31FEopLfB4tFABjRM3ixmi0FxX5HEYk6KgJpFIb3DH6NVN8eEml4KgJpFAZ1bklyQqw+HhLxgYpAGoXYmABHd2+jA8YiPlARSKNxQr92fJu3iw8W62I1Ig2pVkVgZreYWQsLetrM5pvZKV6Hk+hyfkZXeqY159mv1/kdRSSq1HaP4Arn3E7gFKA1cClwr2epJCrFxgQ4fWBHMr/byvbCEr/jiESN2haBhf49DXjRObc0bJlIvTmhX3sqHExfqRPMRRpKbYtgnpl9RLAIpppZCqAB5KXeDercktTkBD5ZvtnvKCJRo7ZnFk8ADgfWOucKzawNcL
l3sSRaBQLGiX3bMWXxJorLykmIjfE7kkiTV9s9guHASufcdjO7hOCVxXZ4F0ui2RmDO5JfXMaHS3QJS5GGUNsieAwoNLPBBAeKWwO84FkqiWoje6aS3jaJF2d+53cUkahQ2yIoC40KOhZ42Dn3CJDiXSyJZoGAcfHRh5D53TZWbc73O45Ik1fbIsg3s98R/Nro5NBVxOK8iyXR7szDOwHwxSp9e0jEa7UtgvOBYoLnE/xA8Gpj93uWSqJe+xaJdG3TjJkackLEc7UqgtCb/0tASzM7AyhyzukYgXjq5H4d+HJ1HoUlZX5HEWnSajvExHnAHOBc4DxgtpmN8zKYyPGHplFSXsFsDU0t4qnankdwJ3CUcy4HwMzSgE+AN7wKJjK0exuaxcXw8fLNjO7bzu84Ik1WbY8RBPaUQMiWOtxX5KAkxsUwum8a01bkoEtZi3intm/mH5rZVDMbb2bjgclUuhaxiBeO653Gph1FLN+kr5GKeKW2B4t/A0wEBoVuE51zv/UymAjAyf3bExMw3l/0vd9RRJqs2h4jwDn3JvCmh1lE9tM2OYFhPdrw8bLN/OZnh2KmQW9F6luNewRmlm9mO6u45ZvZzgM9uJmNMbOVZpZlZndUs815ZrbMzJaa2csH+4NI0zVmQEdW5xTowvYiHqmxCJxzKc65FlXcUpxzLWq6r5nFAI8ApwL9gQvNrH+lbXoDvwNGOucOA371k34aaZLOHdKFtJQEHvx0ld9RRJokL7/5MxTIcs6tdc6VAK8SHKso3FXAI865bQCVvpkkAgS/PXTdqJ7MWruVWWt1prFIffOyCDoDG8Lms0PLwvUB+pjZ12Y2y8zGeJhHIthFR3cL7hV8strvKCJNjt/nAsQCvYHjgQuBJ82sVeWNzOxqM8s0s8zcXA1CFo0S42K4dlRPZq7dwmztFYjUKy+LYCPQNWy+S2hZuGzgXedcqXPuW2AVwWLYh3NuonMuwzmXkZaW5llgadwuProbqckJPPip9gpE6pOXRTAX6G1m3c0sHrgAeLfSNm8T3BvAzFIJflS01sNMEsGCewU9mLFmC4uzdYE8kfriWRE458qAG4GpwHJgknNuqZn9yczODG02FdhiZsuAacBvnHPa75dqjRvShdiA8f5inWAmUl9qfULZwXDOTaHSUBTOuT+ETTvgttBN5IBaJcVzXJ80nvxiLedndKVHWrLfkUQint8Hi0Xq7Hen9qXCwV3vLPE7ikiToCKQiNO7fQrnDunCjDVbWJtb4HcckYinIpCI9P/G9CUuEODZr9f5HUUk4qkIJCKlpSQw9vBOvDEvm+2FJX7HEYloKgKJWBOO7c7u0nL+OmW531FEIpqKQCJW3w4tuHxkOpMys8nK0YVrRA6WikAi2o2je5EYF+DRaWv8jiISsVQEEtHaJicwfkR33lqwkSmLN/kdRyQiqQgk4t12ch+O6NaK3/93MSVlFX7HEYk4KgKJePGxAW4+oTfbC0u578MVlFc4vyOJRBQVgTQJx/VJ42eHteepr77l928tpqxcewYitaUikCYhJmA8cWkGN53Qi9cyN/DwtCy/I4lEDBWBNCm3n3IoJ/Vrz8OfZbHyB32lVKQ2VATS5PzqpN6UVTiueG4u32/f7XcckUZPRSBNzoDOLXn4oiPIKyjmV69+Q4UOHovUSEUgTdIZgzpxz1kDmLNuKxO/1EXvRGqiIpAma9yQLvTv2IKHP8vi9cwNfscRabRUBNJkmRl3nt6P+NgAv3ljEdNX5vgdSaRRUhFIkzayVyqTbz6G1klxjH92Luc9PpOi0nK/Y4k0KioCafI6tmzGx7eNYsxhHZizbiuD7v6I77bs8juWSKOhIpCokJqcwGOXHMm1o3pSUl7BaQ9+yXsLv/c7lkijoCKQqGFm3HFqXz67fRQdWzXjN28s5P1FKgMRFYFEnR5pyfxnwtEc2qEFN768gHveX6ZzDSSqeVoEZjbGzFaaWZaZ3VHDdueYmTOzDC/ziOzRoW
Uir18znHOO7MJTX33LlS9kaghriVqeFYGZxQCPAKcC/YELzax/FdulALcAs73KIlKV+NgAD5w7iF+f0ofPVuQw4I9TeeJzXelMoo+XewRDgSzn3FrnXAnwKjC2iu3+DPwdKPIwi0iVzIwbT+jNA+cOpqS8gr99sIJrXsxkR2Gp39FEGoyXRdAZCD+dMzu0bC8zOxLo6pybXNMDmdnVZpZpZpm5ubn1n1Si3rghXZh/18mcdXgnpi7dzFmPfq1vFUnU8O1gsZkFgH8Ctx9oW+fcROdchnMuIy0tzftwEpXaNI/n/y44gnvPHsjWXSXc9MoCHpi6kp1F2juQps3LItgIdA2b7xJatkcKMACYbmbrgGHAuzpgLH67YGg35t55EuOGdOHhaVkMuvsjxj78FTPW5PkdTcQTXhbBXKC3mXU3s3jgAuDdPSudczucc6nOuXTnXDowCzjTOZfpYSaRWomPDXD/uEHcfEIvEuMCbCss5aInZ3PxU7OYvXaL3/FE6lWsVw/snCszsxuBqUAM8IxzbqmZ/QnIdM69W/MjiPjLzLjtlEO57ZRD2VVcxj2Tl/Hq3A18nbWFsYd34s7T+tGuRaLfMUV+MnMusk6kycjIcJmZ2mkQfxSVlvPwZ1k89vka2qUkcNcZ/RnVJ43mCZ79TSVSL8xsnnOuyo/edWaxSB0kxsXw658dylvXjaCkrILrX5rPz//9FWtyC/yOJnLQVAQiB2Fw11Z8+dvR3HJib7K37ebn//6KR6dnUVymIa4l8qgIRA5SUnwst57chym3HMthnVpw34crufTpOeTr66YSYVQEIj9Rr3bJvH7tCB684HDmf7eNcx6bwYofdvodS6TWVAQi9WTs4Z15ZvxRbNpexPlPzGJdni5+I5FBRSBSj47rk8bT449iZ1EpJ/xjOk9+sVYfFUmjpyIQqWdDu7fh1auGMbhrK/4yZTmH/+lj7np7ia55II2WikDEA0f3aMt/rx/JW9eP4OR+7Xlx1nec+uCXZK7b6nc0kf2oCEQ8dGS31jx2yZE8cO5gtuwq4dwnZnLl83P5VscPpBFREYh4zMwYN6QLn//meC4a2o0vVufxs//7grcXbDzwnUUagIpApIE0T4jlL78YyKe3jWJwl5b86rVvePCT1UTaMC/S9KgIRBpY1zZJ/OfKozn7iM7865NVjHt8JlsKiv2OJVFMRSDig4TYGP5x3mD+dvZAlmzcwbjHZ7Jha6HfsSRKqQhEfGJmXDi0Gy9deTRbd5Uw9pGvWfa9zkiWhqciEPFZRnob3rxuODEB46oXMjWSqTQ4FYFII9CrXQrPjj+KotJyJjw3l5ydRX5HkiiiIhBpJAZ0bsljlwxh885iznl8BtsLS/yOJFFCRSDSiAzt3ob/XHk0P+wo4tbXvqFcw1JIA1ARiDQyQw5pzV1n9Gfaylye/mqt33EkCqgIRBqhS4cdwsn92/PXKSt4be56v+NIE6ciEGmEzIyHLzqC4/qkccdbi5myeJPfkaQJUxGINFIJsTE8dvGRHNapBbdPWsjqzfl+RxIfLVi/jcKSMk8e29MiMLMxZrbSzLLM7I4q1t9mZsvMbJGZfWpmh3iZRyTSNE+I5enLjqJ5QgzXvTSfHYW6yE00Kiot56InZ3PvBys8eXzPisDMYoBHgFOB/sCFZta/0mYLgAzn3CDgDeA+r/KIRKr2LRJ56IIjWL+lkAufnKUyiEJfZ+Wxu7Sck/q19+TxvdwjGApkOefWOudKgFeBseEbOOemOef2DLAyC+jiYR6RiDWiVypPXpZBVk4Bv3xmti5/GWU+WrqZ5IRYhvVo68nje1kEnYENYfPZoWXVmQB84GEekYg2qk8aj1x8JEu/38mE5zPZXVLudyRpAJ8u38ykeRs4fWBH4mO9ectuFAeLzewSIAO4v5r1V5tZppll5ubmNmw4kUbk5P7t+cd5g5m7bitXvZBJcZnKoCnbvLOIX7++kP4dW3D3mYd59jxeFsFGoGvYfJ
fQsn2Y2UnAncCZzrkqB2V3zk10zmU45zLS0tI8CSsSKcYe3pn7xw3mq6w8bnttIaXlFX5HEg9UVDhun7SQotIKHrrwCJrFx3j2XLGePTLMBXqbWXeCBXABcFH4BmZ2BPAEMMY5l+NhFpEmZdyQLmwvLOGeycvZWVTKIxcfSYvEOL9jST0pKavg168v5KusPP76i4H0TEv29Pk82yNwzpUBNwJTgeXAJOfcUjP7k5mdGdrsfiAZeN3MvjGzd73KI9LUXHlsD+47ZxAz12zhnEdnsGTjDr8jST2YtXYLI+79jHcXfs9Vx3bnwqFdD3ynn8gi7XqpGRkZLjMz0+8YIo3GjDV53PjyAnbuLuW3Y/oy4ZjuBALmdyw5CFk5BZz7+AxaJ8Vz5+n9OLEevy5qZvOccxlVrWsUB4tF5OCN6JnKtNuP58R+7fjLlOVc8fxcXQM5wjjneHHmOk765+cAPHlZRr2WwIGoCESagJZJcTx+yRD+PPYwZqzZwqkPfsmnyzf7HUtqobCkjNsnLeSud5ZySNsk3r3xGM+PCVSmIhBpIsyMS4en8/b1I2mdFM+E5zO56ZUF2jtoxHaXlHPNi/N4a8FGbj6xN9NuP56ubZIaPIeKQKSJ6d+pBe/ddAy3ntSHD5dsYvQD07nvwxWeDVgmB2frrhLGPzuHL1fncc9ZA7jt5D6+HdtREYg0QfGxAW45qTdvXTeSkb1SeXT6Gk7+5xe8t/B7Iu0LIk3R9sISxj02gwXrt/PnswZwyTB/x9tUEYg0YQO7BK+D/Pq1w2meEMNNryzgvCdmMn2lTtvxS05+ERdMnEX2tt28OGEol/pcAqAiEIkKR6W34cNbjuMvvxhA9rbdjH92Lje8PJ+N23f7HS2qbNhayLmPz2T91kKeGX8UR3s0iFxd6TwCkShTUlbB45+v4eFpWQBcfHQ3Lj66G73apficrGmbkZXH9S/Pxzl47vKjOKJb6wZ9/prOI1ARiESpjdt38/cPVjBl8SbKKhx9O6RwxcjunHVEZ89GuYxGzjnenL+R3/93MV1bN+PxS4bQu33Dl66KQESqlZNfxORFm3g9M5tlm3bStnk8Fw7txjWjepCi8Yt+ksKSMq5+YR5fZeWRcUhrnvxlBq2bx/uSRUUgIgfknOOL1Xm8NOs7Pl6+mbbN4xl9aDsGdW1F/44p9GqXQstmKobayly3lTv/u4TVOfn879gBXDy0m69Df9RUBF6OPioiEcTMGNUnjVF90licvYMHP13NJ8s38/q87NB6GNylFcf0SqVXu2QGd21FetskzDSuUbgdhaXc++FyXpmzgY4tE3nqsgxO6Ntww0UcDO0RiEi1nHNkb9vN6px8vtmwgy9W5bIwezt73jYCBi2axZGWnEC/ji3ISG/Nsb3TorIgCkvKePbrdTzx+Rp2FpVx1bHdufXkPiTFN46/t/XRkIjUm6LScr7N28XCDdvZsK2QnbvL+GFnEYuzd/DDziIAWiXFMbBzSzq0SKR183gqKhwxAaNrmyR6pDWnU8tmtGuR0GjeJOvKOUdBcRmbdxaxanMBs9duYfLiTeQVlHBi33b8+meH0q9jC79j7kMfDYlIvUmMi6FfxxZVvtFl5RSQuW4rC9ZvZ9mmnSzflM/OolJizCh3jpKyfa+m1jopjmZxMSTExRAfE6BdiwRSkxNolRRH2+bxJCfE0rp5PAmxMZRVVBAwIy0lgXYpCTRPiGXVD/nkF5eRm19McVkFvdol0zOtOfGxAeICAWJjjGZxMcTG1P5bUGXlFWzOL2bT9t3MXLOF1TkF5BUUk1dQzLbCUopKyykpq6A47GdJjAtwbO80rjmuBxnpbQ7+xfWJikBE6k2vdsn0apfMBUO77beuosKxOb+IrJwCcnYW88POIrK37aa0PPimWlRaTm5+Md/m7WJ7YSkFxfUzNlJswEhNTqB183jaNI8jLiZAYUk5xaXllJY7yiscO3YHn9AHfAEAAAq8SURBVC8uxthdWk5Raf
BN3gy6tUkiNTmB7qnNOTIpnsS4GOJjA6Qmx5OWkkDvdin0apdMYpx3l5L0mopARBpEIGB0bNmMji2b1Wr7otJy8ovK2FZYQklZBQmxAcoqHDn5xeTmF7O9sITe7VNoHh9D1zZJBMxYtTmf7G2FlJQ7yssrKKtwbCssITe/mK27StlWWMLO3WUkxcfQKimeuBgjNhAgOTGWFolxlJZXEBMwDu2QQseWiaF/a5c3kqkIRKRRSoyLITEuhrSUhH2W9+tY/X0qbyu1o9MHRUSinIpARCTKqQhERKKcikBEJMqpCEREopynRWBmY8xspZllmdkdVaxPMLPXQutnm1m6l3lERGR/nhWBmcUAjwCnAv2BC82sf6XNJgDbnHO9gH8Bf/cqj4iIVM3LPYKhQJZzbq1zrgR4FRhbaZuxwPOh6TeAEy3aRqoSEfGZlyeUdQY2hM1nA0dXt41zrszMdgBtgbzwjczsauDq0GyBma0MTbcEdtQwveff1MqPWQfhj1uX9ZWX1zRfU3Y4+Pz1lb02eaubbmzZq8vZkNlrm7OqZV7/zh9s9gNlrS5v+LT+f/X2/9dDqn1k55wnN2Ac8FTY/KXAw5W2WQJ0CZtfA6TW4Tkm1jQd9m/mT/g5Jh7M+srLa5qvKftPyV9f2WuTt4afo1Flr+Xr7Wn22ub8Kb83DZ39QFkbInt9/t401f9fq7t5+dHQRqBr2HyX0LIqtzGzWILNtaUOz/HeAabDlx2sAz1GdesrL69pvrFnr7ysrtMHw6vsleerer29zl7dNgebPXzar+xVLY/k3/lIzl552QGzeXY9gtAb+yrgRIJv+HOBi5xzS8O2uQEY6Jy71swuAM52zp3nQZZMV8043JEgkvMruz+U3T+RmN+zYwQu+Jn/jcBUIAZ4xjm31Mz+RHDX6V3gaeBFM8sCtgIXeBRnokeP21AiOb+y+0PZ/RNx+SPuCmUiIlK/dGaxiEiUUxGIiEQ5FYGISJSL+iIws2PN7HEze8rMZvidpy7MLGBmfzGzf5vZZX7nqSszO97Mvgy9/sf7naeuzKy5mWWa2Rl+Z6kLM+sXes3fMLPr/M5TF2Z2lpk9GRqj7BS/89SFmfUws6fN7A2/s1QW0UVgZs+YWY6ZLam0vMbB7sI55750zl0LvM+Pw114rj6yExyiowtQSvDM7QZTT/kdUAAk0oD56yk7wG+BSd6krFo9/c4vD/3OnweM9DJvuHrK/rZz7irgWuB8L/OGq6fsa51zE7xNepAO9gy+xnADjgOOBJaELYsheIZyDyAeWEhw0LuBBN/sw2/twu43CUiJpOzAHcA1ofu+EWmvPRAI3a898FKEZT+Z4NedxwNnRFL20H3OBD4geG5PRGUP3e8fwJERmr1B/1+tzS2iL17vnPuiiqGr9w52B2BmrwJjnXN/A6rchTezbsAO51y+h3H3UR/ZzSwbKAnNlnuXdn/19dqHbAMa7Krj9fTaHw80J/g//m4zm+Kcq/AyN9Tf6+6C5/G8a2aTgZe9S7zPc9bH627AvcAHzrn53ib+UT3/vjc6EV0E1ajNYHeVTQCe9SxR7dU1+1vAv83sWOALL4PVUp3ym9nZwM+AVsDD3kY7oDpld87dCWBm44G8hiiBGtT1dT8eOJtg+U7xNNmB1fV3/ibgJKClmfVyzj3uZbgDqOvr3hb4C3CEmf0uVBiNQlMsgjpzzv3R7wwHwzlXSLDEIpJz7i2CZRaxnHPP+Z2hrpxz04HpPsc4KM65h4CH/M5xMJxzWwge22h0IvpgcTVqM9hdYxXJ2SGy8yu7P5S9EWiKRTAX6G1m3c0snuABvXd9zlRbkZwdIju/svtD2RsDv49W/8Qj+a8Am/jx65MTQstPIzjy6RrgTr9zNrXskZ5f2ZU9mrLX5qZB50REolxT/GhIRETqQEUgIhLlVAQiIlFORSAiEuVUBCIiUU5FICIS5VQE4jkzK2iA5zizlkNH1+dzHm9mIw7ifkeY2dOh6fFm5vc4Sw
CYWXrlYZar2CbNzD5sqEzSMFQEEjHMLKa6dc65d51z93rwnDWNx3U8UOciAH5P5I6XkwtsMrMGu46BeE9FIA3KzH5jZnPNbJGZ/W/Y8rfNbJ6ZLTWzq8OWF5jZP8xsITDczNaZ2f+a2XwzW2xmfUPb7f3L2syeM7OHzGyGma01s3Gh5QEze9TMVpjZx2Y2Zc+6Shmnm9n/mVkmcIuZ/dzMZpvZAjP7xMzah4Ykvha41cy+seCV7tLM7M3Qzze3qjdLM0sBBjnnFlaxLt3MPgu9Np+GhkfHzHqa2azQz3tPVXtYFrxa2mQzW2hmS8zs/NDyo0Kvw0Izm2NmKaHn+TL0Gs6vaq/GzGLM7P6w/1bXhK1+G7i4yv/AEpn8PrVZt6Z/AwpC/54CTASM4B8h7wPHhda1Cf3bDFgCtA3NO+C8sMdaB9wUmr4eeCo0PR54ODT9HPB66Dn6ExwzHmAcwWGXA0AHgtdBGFdF3unAo2HzrWHvWfhXAv8ITd8N/Dpsu5eBY0LT3YDlVTz2aODNsPnw3O8Bl4WmrwDeDk2/D1wYmr52z+tZ6XHPAZ4Mm29J8GIpa4GjQstaEBxxOAlIDC3rDWSGptMJXXgFuBr4n9B0ApAJdA/NdwYW+/17pVv93TQMtTSkU0K3BaH5ZIJvRF8AN5vZL0LLu4aWbyF4wZ03Kz3OnqGr5xEcV78qb7vgNQKWmVn70LJjgNdDy38ws2k1ZH0tbLoL8JqZdST45vptNfc5CegfvHYKAC3MLNk5F/4XfEcgt5r7Dw/7eV4E7gtbflZo+mXggSruuxj4h5n9HXjfOfelmQ0ENjnn5gI453ZCcO8BeNjMDif4+vap4vFOAQaF7TG1JPjf5FsgB+hUzc8gEUhFIA3JgL85557YZ2HwQiknAcOdc4VmNp3gdYwBipxzla++Vhz6t5zqf4eLw6atmm1qsits+t/AP51z74ay3l3NfQLAMOdcUQ2Pu5sff7Z645xbZWZHEhwE7R4z+xT4bzWb3wpsBgYTzFxVXiO45zW1inWJBH8OaSJ0jEAa0lTgCjNLBjCzzmbWjuBfm9tCJdAXGObR838NnBM6VtCe4MHe2mjJj+PMXxa2PB9ICZv/iOAVtAAI/cVd2XKgVzXPM4PgUMYQ/Az+y9D0LIIf/RC2fh9m1gkodM79B7if4PV1VwIdzeyo0DYpoYPfLQnuKVQAlxK89m5lU4HrzCwudN8+oT0JCO5B1PjtIoksKgJpMM65jwh+tDHTzBYDbxB8I/0QiDWz5QSvRzvLowhvEhxCeBnwH2A+sKMW97sbeN3M5gF5YcvfA36x52AxcDOQETq4uowqrkblnFtB8DKLKZXXESyRy81sEcE36FtCy38F3BZa3quazAOBOWb2DfBH4B7nXAlwPsHLmS4EPib41/yjwGWhZX3Zd+9nj6cIvk7zQ18pfYIf975GA5OruI9EKA1DLVFlz2f2Frx+7BxgpHPuhwbOcCuQ75x7qpbbJwG7nXPOzC4geOB4rKcha87zBcGLtG/zK4PULx0jkGjzvpm1InjQ988NXQIhjwHn1mH7IQQP7hqwneA3inxhZmkEj5eoBJoQ7RGIiEQ5HSMQEYlyKgIRkSinIhARiXIqAhGRKKciEBGJcioCEZEo9/8BdBDx9Wd3YM4AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_plot()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "preparing train data ...done.\n", "preparing valid data ...done.\n", "338/338 [==============================] - 123s 365ms/step - loss: 4.6233 - val_loss: 4.5265\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.fit(1e-2, 1, cycle_len=1)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " F1: 84.19\n", " precision recall f1-score support\n", "\n", " tim 0.90 0.86 0.88 2078\n", " geo 0.84 0.90 0.87 3728\n", " org 0.75 0.69 0.72 1981\n", " per 0.81 0.78 0.79 1717\n", " gpe 0.97 0.93 0.95 1540\n", " eve 0.60 0.21 0.31 29\n", " art 0.00 0.00 0.00 47\n", " nat 0.57 0.19 0.29 21\n", "\n", "micro avg 0.85 0.84 0.84 11141\n", "macro avg 0.84 0.84 0.84 11141\n", "\n" ] }, { "data": { "text/plain": [ "0.8418623591692684" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.validate()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our F1-score is **84.19** after a single pass through the dataset. Not bad for a single epoch of training." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's invoke `view_top_losses` to see the sentence we got the most wrong. This single sentence about James Brown contains 10 words that are misclassified. We can see here that our model has trouble with titles of songs. In addition, some of the ground truth labels for this example are sketchy and incomplete, which also makes things difficult." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total incorrect: 10\n", "Word True : (Pred)\n", "==============================\n", "Mr. :B-per (B-per)\n", "Brown :I-per (I-per)\n", "is :O (O)\n", "known :O (O)\n", "by :O (O)\n", "millions :O (O)\n", "of :O (O)\n", "fans :O (O)\n", "as :O (O)\n", "\" :O (O)\n", "The :O (O)\n", "Godfather :B-per (B-org)\n", "of :O (O)\n", "Soul :B-per (B-per)\n", "\" :O (O)\n", "thanks :O (O)\n", "to :O (O)\n", "such :O (O)\n", "classic :O (O)\n", "songs :O (O)\n", "as :O (O)\n", "\" :O (O)\n", "Please :B-art (O)\n", ", :O (O)\n", "Please :O (B-geo)\n", ", :O (O)\n", "Please :O (O)\n", ", :O (O)\n", "\" :O (O)\n", "\" :O (O)\n", "It :O (O)\n", "'s :O (O)\n", "a :O (O)\n", "Man :O (O)\n", "'s :O (O)\n", "World :O (O)\n", ", :O (O)\n", "\" :O (O)\n", "and :O (O)\n", "\" :O (O)\n", "Papa :B-art (B-org)\n", "'s :I-art (O)\n", "Got :I-art (O)\n", "a :I-art (O)\n", "Brand :I-art (B-org)\n", "New :I-art (I-org)\n", "Bag :I-art (I-org)\n", ". :O (O)\n", "\" :O (O)\n", "\n", "\n" ] } ], "source": [ "learner.view_top_losses(n=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Making Predictions on New Sentences\n", "\n", "Let's use our model to extract entities from new sentences. We begin by instantiating a `Predictor` object." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('As', 'O'),\n", " ('of', 'O'),\n", " ('2019', 'B-tim'),\n", " (',', 'O'),\n", " ('Donald', 'B-per'),\n", " ('Trump', 'I-per'),\n", " ('is', 'O'),\n", " ('still', 'O'),\n", " ('the', 'O'),\n", " ('President', 'B-per'),\n", " ('of', 'O'),\n", " ('the', 'O'),\n", " ('United', 'B-geo'),\n", " ('States', 'I-geo'),\n", " ('.', 'O')]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict('As of 2019, Donald Trump is still the President of the United States.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can save the predictor for later deployment." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "predictor.save('/tmp/mypred')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "reloaded_predictor = ktrain.load_predictor('/tmp/mypred')" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('Paul', 'B-per'),\n", " ('Newman', 'I-per'),\n", " ('is', 'O'),\n", " ('my', 'O'),\n", " ('favorite', 'O'),\n", " ('American', 'B-gpe'),\n", " ('actor', 'O'),\n", " ('.', 'O')]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reloaded_predictor.predict('Paul Newman is my favorite American actor.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### A Note on Sentence Tokenization\n", "\n", "The `predict` method typically operates on individual sentences instead of entire paragraphs or documents. The model after all was trained on individual sentences. 
In production, you can use the `sent_tokenize` function to tokenize text into individual sentences.\n", "\n", "```python\n", "from ktrain import text\n", "text.textutils.sent_tokenize('This is the first sentence about Dr. Smith. This is the second sentence.')\n", "```\n", "\n", "The above will output:\n", "```\n", "['This is the first sentence about Dr . Smith .',\n", " 'This is the second sentence .']\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }