{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import ktrain\n", "from ktrain import text as txt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 1: Load and Preprocess the Dataset\n", "\n", "A Dutch NER dataset can be downloaded from [here](https://www.clips.uantwerpen.be/conll2002/ner/).\n", "\n", "We use the `entities_from_conll2003` function to load and preprocess the data, as the dataset is in a standard **CoNLL** format. (Download the data from the link above to see what the format looks like.)\n", "\n", "See the *ktrain* [sequence-tagging tutorial](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/tutorials/tutorial-06-sequence-tagging.ipynb) for more information on how to load data in different ways." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "detected encoding: ISO-8859-1 (if wrong, set manually)\n", "Number of sentences: 15806\n", "Number of words in the dataset: 27803\n", "Tags: ['I-PER', 'B-MISC', 'O', 'B-PER', 'I-ORG', 'I-LOC', 'B-LOC', 'I-MISC', 'B-ORG']\n", "Number of Labels: 9\n", "Longest sentence: 859 words\n" ] } ], "source": [ "TDATA = 'data/dutch_ner/ned.train'\n", "VDATA = 'data/dutch_ner/ned.testb'\n", "(trn, val, preproc) = txt.entities_from_conll2003(TDATA, val_filepath=VDATA)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 2: Build the Model\n", "\n", "Next, we will build a Bidirectional LSTM model that uses transformer embeddings such as [BERT word embeddings](https://arxiv.org/abs/1810.04805). 
By default, the `bilstm-transformer` model will use a pretrained multilingual model (i.e., `bert-base-multilingual-cased`). However, since we are training a Dutch-language model, it is better to select the Dutch pretrained BERT model: `bert-base-dutch-cased`. The available pretrained models are [listed here](https://huggingface.co/transformers/pretrained_models.html). One can also use [community-uploaded models](https://huggingface.co/models) that focus on specific domains such as the biomedical or scientific domains (e.g., BioBERT, SciBERT). To use SciBERT, for example, set `transformer_model` to `allenai/scibert_scivocab_uncased`. " ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Embedding schemes employed (combined with concatenation):\n", "\tword embeddings initialized with fasttext word vectors (cc.nl.300.vec.gz)\n", "\ttransformer embeddings with bert-base-dutch-cased\n", "\n", "pretrained word embeddings will be loaded from:\n", "\thttps://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nl.300.vec.gz\n", "loading pretrained word vectors...this may take a few moments...\n" ] }, { "data": { "text/html": [ "done." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "WV_URL='https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nl.300.vec.gz'\n", "model = txt.sequence_tagger('bilstm-transformer', preproc, \n", " transformer_model='wietsedv/bert-base-dutch-cased', wv_path_or_url=WV_URL)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the cell above, notice that we supplied the `wv_path_or_url` argument. This directs *ktrain* to initialize word embeddings with one of the pretrained fasttext (word2vec) word vector sets from [Facebook's fasttext site](https://fasttext.cc/docs/en/crawl-vectors.html). 
When supplied with a valid URL to a `.vec.gz`, the word vectors will be automatically downloaded, extracted, and loaded in STEP 2 (download location is `/ktrain_data`). To disable pretrained word embeddings, set `wv_path_or_url=None` and randomly initialized word embeddings will be employed. Use of pretrained embeddings will typically boost final accuracy. When used in combination with a model that uses an embedding scheme like BERT (e.g., `bilstm-transformer`), the different word embeddings are stacked together using concatenation.\n", "\n", "Finally, we will wrap our selected model and datasets in a `Learner` object to facilitate training." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=128)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 3: Train the Model\n", "\n", "We will train for 5 epochs and decay the learning rate using cosine annealing. This is equivalent to one cycle with a length of 5 epochs. We will save the weights for each epoch in a checkpoint folder. We will train with a learning rate of `0.01`, previously identified using our [learning-rate finder](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/tutorials/tutorial-02-tuning-learning-rates.ipynb)." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "preparing train data ...done.\n", "preparing valid data ...done.\n", "Train for 124 steps, validate for 41 steps\n", "Epoch 1/5\n", "124/124 [==============================] - 83s 666ms/step - loss: 0.0699 - val_loss: 0.0212\n", "Epoch 2/5\n", "124/124 [==============================] - 73s 587ms/step - loss: 0.0167 - val_loss: 0.0135\n", "Epoch 3/5\n", "124/124 [==============================] - 73s 585ms/step - loss: 0.0083 - val_loss: 0.0131\n", "Epoch 4/5\n", "124/124 [==============================] - 73s 592ms/step - loss: 0.0053 - val_loss: 0.0123\n", "Epoch 5/5\n", "124/124 [==============================] - 73s 591ms/step - loss: 0.0040 - val_loss: 0.0123\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.fit(0.01, 1, cycle_len=5, checkpoint_folder='/tmp/saved_weights')" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAY4AAAEWCAYAAABxMXBSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd5xU5d338c9vO3XpbQFpa1lB2tIEsUbBKFhQQVRUlOANRlPu3PrkMY8xuY3GXrAQxa6oWEKMXdQoSlmqtIWlKCC9LiBs+z1/zMFs1qUs7OzM7Hzfr9e8mDnnOmd+lw5855S5LnN3REREDldCpAsQEZHYouAQEZEKUXCIiEiFKDhERKRCFBwiIlIhCg4REakQBYdIFTKzNmbmZpZUSftbZWZnVXZbkYNRcEhcOdA/nmZ2mpmVmNkuM8s3s1wzu+YQ+xppZkuC9hvM7F0zqxO+6kWiQ6V86xGpJr5395ZmZsBAYLKZfeXuuWUbmtmpwJ3AAHefY2YNgPOruF6RiNARh0gZHvIusBU46QDNegBfu/ucYJut7v6cu+cDmFkNM7vPzL41sx1m9qWZ1Si1/XAz+87MNpvZ7/cvNLMEM7vFzJab2RYzey0Ipf3rrwz2uaX0dsG6Z83sz6Ven2Zma8or/lDvI3IwCg6RMoJ/VAcBjYC8AzSbDpxjZn80s75mllpm/b1Ad+BkoAHwO6Ck1Pp+wHHAmcAfzOyEYPmNwAXAqUALYBswLqgrC3gcuDJY1xBoeYTdPOD7iByKgkPk31qY2XbgB+At4Nf7jyjKcvcvgIuAbsA/gS1mdr+ZJZpZAnAtcJO7r3X3Ynf/yt33ldrFH939B3efB8wDOgfLRwO/d/c1QfvbgSHBxfQhwDvu/q9g3W38ZxhVxMHeR+Sg9CER+bf91zhSgbuAM4AHD9TY3d8D3guC4nTgdSCXUOikAcsP8l7rSz3fA9QOnh8DvGVmpQOhGGhK6Mhgdan3321mWw6zb2Ud7H3WHuE+JU7oiEOkjOAb+P8AnczsgsNoX+LunwBTgI7AZmAv0P4I3n41MNDd65V6pLn7WmAd0Gp/QzOrSeh01X67gZqlXjc7wvcROSgFh8SjZDNLK/X4yZG3uxcA9wF/KG8HZjbYzIaaWX0L6UnoesE0dy8BJgD3m1mL4PRVn3Kug5TnCeB/zeyY4H0am9ngYN0k4Dwz62dmKcAd/Off4bnAuWbWwMyaATcf4fuIHJSCQ+LRu4SuY+x/3H6AdhOA1mZW3m2224DrgWXATuBF4B53fylY/1vgG2Amobuz7ubw/r49BEwGPjSzfGAa0AvA3RcCY4CXCR19bANK3zX1AqHrJauAD4FXj+R9RA7FNJGTiIhUhI44RESkQhQcIiJSIQoOERGpEAWHiIhUSFz8ALBRo0bepk2bSJchIhIzZs2atdndG5e3Li6Co02bNuTk5ES6DBGRmGFm3x5onU5ViYhIhSg4RESkQhQcIiJSIQoOERGpEAWHiIhUSFiDw8wGmFmumeWZ2S3lrE81s1eD9dPNrE2wvKGZfWpmu8zs0TLbdDezb4JtHg7mhxYRkSoStuAws0RCU1EOBLKAYcHUl6WNBLa5ewfgAUIjiEJoLoPbCI0wWtbjhEYlzQweAyq/ehEROZBw/o6jJ5Dn7isAzGwiMBhYVKrNYP49pPUk4FEzM3ffDXxpZh1K79DMmgN13X1a8Pp5QvMmvxeODjz8yTLcITnJSElMIDkxgbTkBNJrJJNeI4V6NZOpVzOZRrVTSU7UWT8RiQ/hDI4MSk1zSWjegLLj/f/Yxt2LzGwHoRnNNh9kn6XnH1gTLPsJMxsFjAJo3bp1RWsH4InPl7OnoPiQ7RIMmtZNo0W9GjRPT6N1g5p0aFKbDk1q075xbWqlxsXvLEUkTlTbf9HcfTwwHiA7O/uIJh1ZdMcAikucwuKS4OHsKShixw+F7NhTyPYfCtm2p4ANO/aydvte1u34gW/
W7uD9BespKvn3W2bUq0GnjHQ6t6pH51bpdMpIp05acuV0VESkioUzONZSan5koGWwrLw2a4LpO9OBLYfYZ8tD7LNSJSYYiQmJpCUnAtCgVgot6x98m8LiEr7dsoe8jfnkbdzFkvX5oUBZuB4AM8hsUpve7RpycvtG9G7XgHo1U8LZDRGRShPO4JgJZJpZW0L/uA8FLi/TZjIwAvgaGAJM8YNMSeju68xsp5n1BqYDVwGPhKP4o5GcmPDjqarStu0uYN6a7cxbvYNZ323j9Zw1PP/1t5hBxxbpnNyhIT87oSldW9cnMUE3i4lIdArr1LFmdi7wIJAITHD3/zWzO4Acd59sZmmE5knuSmhe5qGlLqavAuoCKcB24Gx3X2Rm2cCzQA1CF8VvPFjYQOhUVTQOclhQVMK8NduZmreZr5ZvYfa32ygqcRrWSuGsE5rys6ym9Mts9OPRjohIVTGzWe6eXe66eJhzPFqDo6ydewv5LHcTHy5cz+e5m8jfV0StlETO6diMC7tmcHL7RjoSEZEqoeCIkeAoraCohGkrtvDP+et4d8E68vcW0bhOKoM6t+Cibhmc2CI90iWKSDWm4IjB4Chtb2Exny7ZyFtz1vJp7kYKi53OreoxvFdrzj+pBTVSdCpLRCqXgiPGg6O07XsKeHvOWl6a/h3LNu6ibloSF3dvyfBex/zkYryIyJFScFSj4NjP3Zm5ahsvTf+W975ZT0FxCWce34Tr+7ejV9sGaAgvETkaCo5qGBylbdm1jxenfcfzX69iy+4CTmqZzvWntGNgx2YkaSgUETkCCo5qHhz77S0s5s3Za3nqixWs2Lyb1g1qMvaMDlzYNUNjaYlIhSg44iQ49ispcT5avIFHpixjwdqdHNOwJmNPDwWIjkBE5HAoOOIsOPZzdz5ZvJEHP1n6HwFyUbeW+j2IiBzUwYJDXz+rMTPjrKym/GNsP/52VTa1U5P470nzGfjQv/h0yUbi4UuDiFQ+BUccMDN+ltWUd27sx2PDu7GvqIRrnp3J8Kems2DtjkiXJyIxRsERR8yMczs156Nfncrt52exeN1OznvkS26eOIe123+IdHkiEiMUHHEoJSmBq/u25fPfnc4Np7XnvQXrOfO+z3h0yjL2FR164ioRiW8KjjhWNy2Z/xlwPJ/85lROP64J9364lHMe+Bef5m6MdGkiEsUUHELL+jV5/IruPH9tTxISjGuemcn1z+eweuueSJcmIlFIwSE/6n9sY96/qT//M+B4puZt5mcPfM74fy2nqLgk0qWJSBRRcMh/SElK4IbT2vPJb07llMzG3PnuEi587CsWfq+7r0QkRMEh5WqeXoPxV3Zn3OXdWLfjBwY9OpW/vr+EvYW6eC4S7xQcckBmxs9Pas7Hvz6VC7tm8Nhnyzn3oS+YvmJLpEsTkQhScMgh1auZwr2XdObFkb0oLClh6N+m8ed3FunoQyROKTjksPXLbMQHN/dneK/WPPXlSgY9+qV+eS4ShxQcUiE1U5L48wWdePaaHmzfU8iFj01l3Kd5uvNKJI4oOOSInHZcEz64uT9nn9iMez7I5dInv2bV5t2RLktEqoCCQ45Y/VopjLu8Gw8N7ULexl38/OEveGvOmkiXJSJhpuCQoza4SwYf/Ko/J7ZI51evzuM3r81j976iSJclImGi4JBK0Ty9Bi9f34ubzszkzTlrOP+RL/WjQZFqSsEhlSYpMYFf/exYXr6uN7sLirjwsa947qtVmjBKpJpRcEil69O+Ie/+8hT6tm/I/5u8kBtenE3+3sJIlyUilUTBIWHRsHYqE67uwe/PPYGPFm9g0KNTyV2fH+myRKQSKDgkbMyM6/u34+XrerFrXxEXjJvK3+eujXRZInKUFBwSdr3aNeSfN/ajY0Zdbpo4l9snL6SgSD8YFIlVCg6pEk3qpvHy9b0Z2a8tz361iqHjv2b9jr2RLktEjkBYg8PMBphZrpnlmdkt5axPNbNXg/XTzaxNqXW3Bstzzey
cUst/ZWYLzWyBmb1iZmnh7INUnuTEBG47L4tHL+/KkvX5nPfIF8xctTXSZYlIBYUtOMwsERgHDASygGFmllWm2Uhgm7t3AB4A7g62zQKGAicCA4DHzCzRzDKAXwLZ7t4RSAzaSQw576QWTB7blzppyVz+t2m8OvO7SJckIhUQziOOnkCeu69w9wJgIjC4TJvBwHPB80nAmWZmwfKJ7r7P3VcCecH+AJKAGmaWBNQEvg9jHyRMOjSpw9v/1Zfe7RryP298w+2TF2qgRJEYEc7gyABWl3q9JlhWbht3LwJ2AA0PtK27rwXuBb4D1gE73P3D8t7czEaZWY6Z5WzatKkSuiOVLb1mMs9c3ePH6x5XPzOT7XsKIl2WiBxCTF0cN7P6hI5G2gItgFpmdkV5bd19vLtnu3t248aNq7JMqYCk4LrHX4ecxIyVW7lg3FTyNur3HiLRLJzBsRZoVep1y2BZuW2CU0/pwJaDbHsWsNLdN7l7IfAmcHJYqpcqdWl2K14Z1Ytd+4q5YNxXTFmyIdIlicgBhDM4ZgKZZtbWzFIIXcSeXKbNZGBE8HwIMMVDAxtNBoYGd121BTKBGYROUfU2s5rBtZAzgcVh7INUoe7HNGDy2L60aVST657L4dmpKyNdkoiUI2zBEVyzGAt8QOgf99fcfaGZ3WFmg4JmTwMNzSwP+DVwS7DtQuA1YBHwPjDG3YvdfTqhi+izgW+C+seHqw9S9VrUq8HrvziZs05oyu3/WMQf/7GQ4hINkigSTSweRi7Nzs72nJycSJchFVBc4tz57mKe/nIlZ53QlIeHdaFmSlKkyxKJG2Y2y92zy1sXUxfHJX4kJhi3nZfFHYNPZMqSDVz25DQ25uuX5iLRQMEhUe2qPm3421XZLN+0iwvHfaURdkWigIJDot6ZJzTltV/0obC4hCGPf8UXy/S7HJFIUnBITOiYkc7bY/qSUb8G1zwzk0mz1kS6JJG4peCQmNGiXg1eH92H3u0a8tvX5/H4Z8s1La1IBCg4JKbUSUtmwtU9GNS5BXe/v4Q73llEiW7XFalSur9RYk5KUgIPXtaFRrVTmTB1JZvy93HfpZ1JTUqMdGkicUHBITEpIcG47bwTaFI3lbveW8K2PQU8cUV36qQlR7o0kWpPp6okZpkZo09tz32XdGbaiq0MHT+NTfn7Il2WSLWn4JCYd3H3ljw1IpsVm3Zz8eNfsWrz7kiXJFKtKTikWjj9uCa8fH0v8vcWcvHjX7Fg7Y5IlyRSbSk4pNro2ro+k244mbTkRIb9bRo5ms9cJCwUHFKttG9cm9dG96Fx7VSueHo6ny/Vr8xFKpuCQ6qdjHo1ePUXfWjbqDbXPTeTd79ZF+mSRKoVBYdUS43rpDJxVG9OalmPsS/P5rWZqw+9kYgcFgWHVFvpNZJ5YWRP+nZoxO/emM/TX2pGQZHKoOCQaq1mShJPjchmYMdm/OmdRTz48VKNbyVylBQcUu2lJiXyyLCuDOnekgc/Xsaf3lms8a1EjoKGHJG4kJSYwF8vPok6aUlMmLqS3fuKuPOiTiQmWKRLE4k5Cg6JGwkJxh/Oy6JOahIPT8mjsLiEvw45iaREHXiLVISCQ+KKmfHrs48jJSmBez9cyr7iEh68rAvJCg+Rw6bgkLg09oxMUpISuPPdJRQWlfDI5V01LLvIYdLXLIlbo/q35/bzs/hw0QZGvzCLvYXFkS5JJCYoOCSuXd23LXde2InPlm7i+udz+KFA4SFyKAoOiXuX92rNPUM6MzVvM9c8O4Pd+4oiXZJIVFNwiABDurfkgcu6MHPVNkZMmEH+3sJIlyQStRQcIoHBXTJ4dFhX5q7ezhVPz2DHHoWHSHkUHCKlDOzUnCeu6M7i73dy+VPT2Lq7INIliUQdBYdIGWdlNeVvI7LJ27iL4U9NV3iIlKHgECnHqcc2DuYxV3iIlBXW4DCzAWaWa2Z5ZnZLOetTzezVYP10M2tTat2twfJcMzun1PJ6ZjbJzJaY2WI
z6xPOPkj8OiXz3+Fx+d902kpkv7AFh5klAuOAgUAWMMzMsso0Gwlsc/cOwAPA3cG2WcBQ4ERgAPBYsD+Ah4D33f14oDOwOFx9ENkfHis371Z4iATCecTRE8hz9xXuXgBMBAaXaTMYeC54Pgk408wsWD7R3fe5+0ogD+hpZulAf+BpAHcvcPftYeyDCKdkNubpET0UHiKBcAZHBlB6vs41wbJy27h7EbADaHiQbdsCm4BnzGyOmT1lZrXKe3MzG2VmOWaWs2nTpsroj8SxfpmNFB4igVi7OJ4EdAMed/euwG7gJ9dOANx9vLtnu3t248aNq7JGqaYUHiIh4QyOtUCrUq9bBsvKbWNmSUA6sOUg264B1rj79GD5JEJBIlIlFB4i4Q2OmUCmmbU1sxRCF7snl2kzGRgRPB8CTPHQhNCTgaHBXVdtgUxghruvB1ab2XHBNmcCi8LYB5GfKBseW3bti3RJIlUqbMERXLMYC3xA6M6n19x9oZndYWaDgmZPAw3NLA/4NcFpJ3dfCLxGKBTeB8a4+/5hS28EXjKz+UAX4M5w9UHkQPplNmLC1aHwGP7UdIWHxBULfcGv3rKzsz0nJyfSZUg1NDVvM9c+O5O2jWrx0nW9aFg7NdIliVQKM5vl7tnlrYu1i+MiUaVvBx15SPxRcIgcpbLhoQvmUt0pOEQqQd8OuttK4oeCQ6SS9Mts9OPwJMOfms42hYdUUwoOkUp0SmZj/nZVNsuDUXW371F4SPWj4BCpZP2PDYVHnsJDqikFh0gYnHpsY8Zf2Z1lG3ZxxdPTNQ2tVCsKDpEwOe24Jjx5ZXeWrg/C4weFh1QPCg6RMDr9+CY8cWU3ctfnc6XCQ6oJBYdImJ1xfFMev6Ibi9ft5Kqnp7Nzr8JDYpuCQ6QKnHlCUx4f3p1F63Zy5dMzFB4S0xQcIlXkrKymjLu8G4u+38FVT88gX+EhMUrBIVKFzj6xGY9e3o0Fa3cwYoLCQ2LTIYPDzBLNbElVFCMSD84JwmP+mh1c/cxMdu0rinRJIhVyyOAI5sHINbPWVVCPSFwY0LEZjwzrytzV27l6wgyFh8SUwz1VVR9YaGafmNnk/Y9wFiZS3Q3s1JxHhnVlzurtXPPMDHYrPCRGJB1mu9vCWoVInDq3U3Pc4ZcT53DNMzN55poe1Eo93L+WIpFxWJ9Qd/883IWIxKufn9Qcx7lp4lyueXYmz17Tg5opCg+JXgc9VWVm+Wa2s5xHvpntrKoiRaq7805qwQOXdSFn1VaueWYmewp02kqi10GDw93ruHvdch513L1uVRUpEg8GdQ6Fx8xVW7n22Zn8UFAc6ZJEyqXfcYhEkcFdMnjgsi7MWLmVkc8pPCQ6KThEoszgLhncd2lnvl6xheueV3hI9FFwiEShC7u25L5LOvPV8i1c/3wOewsVHhI9FBwiUeqibi25Z0hnpi7frPCQqKLgEIliQ7q35K8Xn8SXeZsZ9cIshYdEBQWHSJS7JLsVd190El8s28QvFB4SBRQcIjHg0h6tuOuiTny+dBOjX1R4SGQpOERixGU9WnPXRZ34LHcTN7w4i31FCg+JDAWHSAwZ2rM1d17YiU9zN3HDi7MVHhIRCg6RGHN5r9b874UdmbJkI2NeUnhI1VNwiMSg4b2O4U8XdOTjxRsZ89IcCopKIl2SxJGwBoeZDTCzXDPLM7NbylmfamavBuunm1mbUutuDZbnmtk5ZbZLNLM5ZvZOOOsXiWZX9j6GPw0+kY8Xb2DMy7MVHlJlwhYcZpYIjAMGAlnAMDPLKtNsJLDN3TsADwB3B9tmAUOBE4EBwGPB/va7CVgcrtpFYsWVfdrwx0En8tGiDdz4ymwKixUeEn7hPOLoCeS5+wp3LwAmAoPLtBkMPBc8nwScaWYWLJ/o7vvcfSWQF+wPM2sJ/Bx4Koy1i8SMESe34fbzs/hg4QZufHmOwkPCLpzBkQGsLvV6TbCs3DbuXgTsABoeYtsHgd8BB/3bYWajzCzHzHI
2bdp0pH0QiQlX923LH87L4v2F6/nlKwoPCa+YujhuZucBG9191qHauvt4d8929+zGjRtXQXUikXVtv7bcdl4W7y1Yz00TFR4SPuGcn3It0KrU65bBsvLarDGzJCAd2HKQbQcBg8zsXCANqGtmL7r7FeHpgkhsGdmvLe7On/+5GLO5PHRZF5ISY+r7ocSAcH6iZgKZZtbWzFIIXeyeXKbNZGBE8HwIMMXdPVg+NLjrqi2QCcxw91vdvaW7twn2N0WhIfKfrjulHb8/9wT+OX8dN786lyIdeUglC9sRh7sXmdlY4AMgEZjg7gvN7A4gx90nA08DL5hZHrCVUBgQtHsNWAQUAWPcXb9yEjlM1/dvR4k7f3lvCWbGA5d21pGHVBoLfcGv3rKzsz0nJyfSZYhUuSc+X85d7y1hcJcW3H9pFxITLNIlSYwws1nunl3eunBe4xCRCBt9antK3Pnr+7kYcJ/CQyqBgkOkmvuv0zrgDvd8kEuCGfdc0lnhIUdFwSESB8ac3gF3594Pl4LBPUMUHnLkFBwicWLsGZmUONz/0VIM469DTlJ4yBFRcIjEkV+emYk7PPDxUvYVFfPAZV1I1t1WUkEKDpE4c9NZmaQlJ/CX95awt7CYRy/vRlpy4qE3FAnoq4ZIHPrFqe2DIdk3cv3zOewpKIp0SRJDFBwicerKPm2495LOTM3bzIgJM9i5tzDSJUmMUHCIxLEh3VvyyLBuzPluO1c8NZ1tuwsiXZLEAAWHSJz7+UnNefLK7ixZn8/Q8dPYmL830iVJlFNwiAhnntCUZ67uwXdb93DZk9P4fvsPkS5JopiCQ0QA6NuhES+M7Mnm/H1c8sTXfLtld6RLkiil4BCRH2W3acAro3qzp6CIS574mqUb8iNdkkQhBYeI/IeOGem8+os+AFzyxNfM+nZbhCuSaKPgEJGfOLZpHd644WTq10zmiqem81nuxkiXJFFEwSEi5WrVoCavjz6Zdo1rcd1zOfx9btmZnyVeKThE5IAa10nllVG96X5MfW6aOJdnp66MdEkSBRQcInJQddOSee7anpyd1ZTb/7GI+z/MJR5mDpUDU3CIyCGlJSfy2PBuXJbdioen5PF/315AcYnCI15pdFwROSxJiQncdXEn6tdK4YnPl7N9TyH3X9aZ1CSNrBtvFBwictjMjFsGHk/DWin877uL2bangCeu7E7dtORIlyZVSKeqRKTCru/fjvsv7cyMlVu55PGvNURJnFFwiMgRuahbS569pidrt//ARY99xeJ1OyNdklQRBYeIHLF+mY14ffS/f2X+5bLNEa5IqoKCQ0SOygnN6/LWmJNpWb8GVz8zg0mz1kS6JAkzBYeIHLXm6TV4bXQferVrwG9fn8fDnyzTbz2qMQWHiFSKumnJPHN1Ty7qlsH9Hy3llje+obC4JNJlSRjodlwRqTQpSQncd0lnMurV4JEpeazetofHhnejXs2USJcmlUhHHCJSqcyM35x9HPdf2pmcVdu48LGvWLFpV6TLkkqk4BCRsLioW0tevr4XO38o5IJxU5mapzuuqouwBoeZDTCzXDPLM7NbylmfamavBuunm1mbUutuDZbnmtk5wbJWZvapmS0ys4VmdlM46xeRo5PdpgFvj+lLs/Q0rpowg5emfxvpkqQShC04zCwRGAcMBLKAYWaWVabZSGCbu3cAHgDuDrbNAoYCJwIDgMeC/RUBv3H3LKA3MKacfYpIFGnVoCZv3HAy/TMb8fu3FvDHfyykSBfNY1o4jzh6AnnuvsLdC4CJwOAybQYDzwXPJwFnmpkFyye6+z53XwnkAT3dfZ27zwZw93xgMZARxj6ISCWok5bMUyN6cG3ftjwzdRUjn8thx57CSJclRyicwZEBrC71eg0//Uf+xzbuXgTsABoezrbBaa2uwPTy3tzMRplZjpnlbNq06Yg7ISKVIzHB+MP5Wfzlok58tXwzg8Z9Se76/EiXJUcgJi+Om1lt4A3gZncvd4Acdx/v7tn
unt24ceOqLVBEDmhYz9ZMHNWbPQXFXDBuKu/M/z7SJUkFhTM41gKtSr1uGSwrt42ZJQHpwJaDbWtmyYRC4yV3fzMslYtIWHU/pgH/vLEfWS3qMvblOfzlvcWaGCqGhDM4ZgKZZtbWzFIIXeyeXKbNZGBE8HwIMMVD4xRMBoYGd121BTKBGcH1j6eBxe5+fxhrF5Ewa1I3jVeu780VvVvz5OcruPqZGWzbXRDpsuQwhC04gmsWY4EPCF3Efs3dF5rZHWY2KGj2NNDQzPKAXwO3BNsuBF4DFgHvA2PcvRjoC1wJnGFmc4PHueHqg4iEV0pSAn++oBN3X9yJ6Su2cv6jX7Jg7Y5IlyWHYPEwEFl2drbn5OREugwROYi5q7cz+oVZbN1TwO3nn8iwnq0InWSQSDCzWe6eXd66mLw4LiLVT5dW9Xjnl/3o1bYB/+etb7hp4lx27SuKdFlSDgWHiESNRrVTee6anvz27GN5Z/73DHrkS80sGIUUHCISVRISjLFnZPLy9b3Zta+IC8ZN5ZUZ32l+jyii4BCRqNS7XUPevekUerZtwK1vfsPNr84lf69+bR4NFBwiErX2n7r673OO4x/zvufch78gZ9XWSJcV9xQcIhLVEhKMMad34PXRJ2MYlz75Nfd9mKvZBSNIwSEiMaH7MfV596ZTuLhbSx6ZkseQJ75m5ebdkS4rLik4RCRm1E5N4p5LOvPY8G6s2rybcx/6QhfOI0DBISIx59xOzfng5v50O6Yet775DVdNmMHa7T9Euqy4oeAQkZjULD2NF67txZ8Gn8isb7dx9v2f8+K0bynRYIlhp+AQkZiVkGBc2acNH9zcn66t6/N/317A8Kem892WPZEurVpTcIhIzGvVoCYvjOzJXy7qxDdrd3DOg/9iwpcrNVR7mCg4RKRaMDOG9WzNh7/qT692DbjjnUVcMG4q89dsj3Rp1Y6CQ0SqlRb1avDM1T14eFhX1u/cy+BxU7nt7QXs+EG/Oq8sCg4RqXbMjEGdW/DJb05lRJ82vDT9W86873PenrNWt+5WAgWHiFRbddOSuX3QiUwe24+Memnc/OpcLntymk5fHSUFh4hUex0z0nnzv/py54WdWLF5F4MendjyWOkAAAt/SURBVMqvX5vL+h17I11aTFJwiEhcSEwwLu/Vmk9/exqjT23PO/PWcfq9n/Hgx0v5oaA40uXFFAWHiMSVOmnJ3DLweD75zamccXwTHvx4Gafe8ykvfL2KgiINnHg4FBwiEpdaNajJuOHdeH10H45pWJPb/r6QM+77jNdzVlOkkXcPSsEhInGtR5sGvPaLPjx7TQ/q10zhvyfN5+wH/8U/5n2v4UsOwOLh1rTs7GzPycmJdBkiEuXcnQ8WbuD+j3JZumEX7RrXYnT/9lzQNYOUpPj6nm1ms9w9u9x1Cg4Rkf9UXOK8+806Hv9sOYvW7aRZ3TSuO6Utw3q2plZqUqTLqxIKDgWHiBwBd+dfyzbz+Gd5TFuxlfQayQzv1ZrhvY8ho16NSJcXVgoOBYeIHKXZ323jic+W8/HiDQD8LKspV/Vpw8ntG2JmEa6u8ik4FBwiUknWbNvDS9O/Y+KM79i2p5AOTWoztEcrLuiaQaPaqZEur9IoOBQcIlLJ9hYW88/563h+2rfMW72dpATjtOOaMKR7Bmcc3zTmL6YrOBQcIhJGyzbkM2n2Gt6avZaN+fuoXzOZAR2bMbBjc/q0b0hyYuyFiIJDwSEiVaCouIQv8zbz5uy1fLJ4A7sLikmvkcxZJzRlYMdm9O3QiBopiZEu87AcLDji474yEZEqkJSYwGnHNeG045qwt7CYL5Zt5r0F6/ho0XremL2GlKQEerZpwCmZjTglszEnNK8TkxfWdcQhIhJmBUUlTF+5hc9zN/HFss3kbsgHoFHtVHq0qU/3Y+rTtXV9OmbUJTUpOo5IInbEYWYDgIeAROApd7+rzPpU4HmgO7AFuMzdVwXrbgVGAsXAL939g8PZp4hItEl
JSuCUzMacktkYgA079/LFss18uWwTs77bxnsL1ofaJSbQMaMuJzSvy3HN6nBc0zoc36wu6TWTI1n+T4TtiMPMEoGlwM+ANcBMYJi7LyrV5r+Ak9x9tJkNBS5098vMLAt4BegJtAA+Bo4NNjvoPsujIw4RiWYb8/cy+9vtzP5uG3O+28aS9fnk7y36cX2TOqm0blCTlvVr0LJ+TVo1qEGTOmnUr5VCw1op1K+VQq2UxEo97RWpI46eQJ67rwiKmAgMBkr/Iz8YuD14Pgl41EI9HwxMdPd9wEozywv2x2HsU0QkpjSpk8aAjs0Y0LEZEPrF+vqde1myPp+l6/NZtnEXa7btIefbbfxj/jqKyxl8MSnBSEtOJDUp4cc/G9VO5bXRfSq93nAGRwawutTrNUCvA7Vx9yIz2wE0DJZPK7NtRvD8UPsEwMxGAaMAWrdufWQ9EBGJADOjeXoNmqfX4PTjmvzHusLiEtbv2MvmXfvYuruArbsL2LangO17CtlXVMLewmL2Fpawt6iYOmEaV6va3lXl7uOB8RA6VRXhckREKkVyYgKtGtSkVYOaEashnL9KWQu0KvW6ZbCs3DZmlgSkE7pIfqBtD2efIiISRuEMjplAppm1NbMUYCgwuUybycCI4PkQYIqHrtZPBoaaWaqZtQUygRmHuU8REQmjsJ2qCq5ZjAU+IHTr7AR3X2hmdwA57j4ZeBp4Ibj4vZVQEBC0e43QRe8iYIy7FwOUt89w9UFERH5KPwAUEZGfONjtuLE38paIiESUgkNERCpEwSEiIhWi4BARkQqJi4vjZrYJ+PYIN28EbK7EciJBfYgO6kP0qA79CHcfjnH3xuWtiIvgOBpmlnOgOwtihfoQHdSH6FEd+hHJPuhUlYiIVIiCQ0REKkTBcWjjI11AJVAfooP6ED2qQz8i1gdd4xARkQrREYeIiFSIgkNERCpEwXEAZjbAzHLNLM/Mbol0PQdjZhPMbKOZLSi1rIGZfWRmy4I/6wfLzcweDvo138y6Ra7yH2ttZWafmtkiM1toZjcFy2OmDwBmlmZmM8xsXtCPPwbL25rZ9KDeV4MpAQimDXg1WD7dzNpEsv7SzCzRzOaY2TvB65jqg5mtMrNvzGyumeUEy2Lt81TPzCaZ2RIzW2xmfaKlDwqOcphZIjAOGAhkAcPMLCuyVR3Us8CAMstuAT5x90zgk+A1hPqUGTxGAY9XUY0HUwT8xt2zgN7AmOC/dyz1AWAfcIa7dwa6AAPMrDdwN/CAu3cAtgEjg/YjgW3B8geCdtHiJmBxqdex2IfT3b1Lqd86xNrn6SHgfXc/HuhM6P9HdPTB3fUo8wD6AB+Uen0rcGuk6zpEzW2ABaVe5wLNg+fNgdzg+ZPAsPLaRcsD+DvwsxjvQ01gNtCL0K97k8p+tgjNK9MneJ4UtLMoqL0loX+UzgDeASwG+7AKaFRmWcx8ngjNhrqy7H/LaOmDjjjKlwGsLvV6TbAsljR193XB8/VA0+B5VPctONXRFZhODPYhOMUzF9gIfAQsB7a7e1HQpHStP/YjWL8DaFi1FZfrQeB3QEnwuiGx1wcHPjSzWWY2KlgWS5+ntsAm4JnglOFTZlaLKOmDgiMOeOgrSNTfd21mtYE3gJvdfWfpdbHSB3cvdvcuhL619wSOj3BJFWJm5wEb3X1WpGs5Sv3cvRuhUzhjzKx/6ZUx8HlKAroBj7t7V2A3/z4tBUS2DwqO8q0FWpV63TJYFks2mFlzgODPjcHyqOybmSUTCo2X3P3NYHFM9aE0d98OfErotE49M9s/TXPpWn/sR7A+HdhSxaWW1RcYZGargImETlc9RGz1AXdfG/y5EXiLUIjH0udpDbDG3acHrycRCpKo6IOCo3wzgczgTpIUQnOhT45wTRU1GRgRPB9B6LrB/uVXBXdh9AZ2lDr0jQgzM0Lzzy929/tLrYqZPgCYWWMzqxc8r0HoOs1iQgEyJGhWth/7+zcEmBJ8i4wYd7/V3Vu6ext
Cn/sp7j6cGOqDmdUyszr7nwNnAwuIoc+Tu68HVpvZccGiM4FFREsfInkBKJofwLnAUkLnqH8f6XoOUesrwDqgkNA3lZGEzjN/AiwDPgYaBG2N0B1jy4FvgOwoqL8foUPu+cDc4HFuLPUhqOskYE7QjwXAH4Ll7YAZQB7wOpAaLE8LXucF69tFug9l+nMa8E6s9SGodV7wWLj/728Mfp66ADnB5+ltoH609EFDjoiISIXoVJWIiFSIgkNERCpEwSEiIhWi4BARkQpRcIiISIUoOEQOwcy+Cv5sY2aXV/K+/0957yUSzXQ7rshhMrPTgN+6+3kV2CbJ/z3GU3nrd7l77cqoT6Sq6IhD5BDMbFfw9C7glGCOh18FAxreY2YzgzkQfhG0P83MvjCzyYR+7YuZvR0MuLdw/6B7ZnYXUCPY30ul3yv4BfA9ZrYgmFfislL7/qzUPA0vBb+8x8zustCcJvPN7N6q/G8k8SXp0E1EJHALpY44ggDY4e49zCwVmGpmHwZtuwEd3X1l8Ppad98aDEUy08zecPdbzGyshwZFLOsiQr8c7gw0Crb5V7CuK3Ai8D0wFehrZouBC4Hj3d33D30iEg464hA5cmcTGh9oLqFh4BsSmkgHYEap0AD4pZnNA6YRGowuk4PrB7ziodF2NwCfAz1K7XuNu5cQGp6lDaHhzPcCT5vZRcCeo+6dyAEoOESOnAE3emiWuS7u3tbd9x9x7P6xUejayFmEJjzqTGg8q7SjeN99pZ4XE5pgqYjQCLCTgPOA949i/yIHpeAQOXz5QJ1Srz8AbgiGhMfMjg1GYy0rndD0qnvM7HhC0+PuV7h/+zK+AC4LrqM0BvoTGkSwXMFcJunu/i7wK0KnuETCQtc4RA7ffKA4OOX0LKF5KtoAs4ML1JuAC8rZ7n1gdHAdIpfQ6ar9xgPzzWy2h4Yv3+8tQnN5zCM0cvDv3H19EDzlqQP83czSCB0J/frIuihyaLodV0REKkSnqkREpEIUHCIiUiEKDhERqRAFh4iIVIiCQ0REKkTBISIiFaLgEBGRCvn/UzdvjOzZjZcAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.plot('lr')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As shown below, our model achieves an F1-Sccore of 83.04 with only a few minutes of training." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " F1: 83.04\n", " precision recall f1-score support\n", "\n", " PER 0.92 0.94 0.93 1097\n", " LOC 0.88 0.90 0.89 772\n", " MISC 0.74 0.76 0.75 1187\n", " ORG 0.72 0.84 0.77 882\n", "\n", "micro avg 0.81 0.85 0.83 3938\n", "macro avg 0.81 0.85 0.83 3938\n", "\n" ] }, { "data": { "text/plain": [ "0.8303891290920322" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.validate(class_names=preproc.get_classes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 4: Make Predictions" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('Marke', 'B-PER'),\n", " ('Rutte', 'I-PER'),\n", " ('is', 'O'),\n", " ('een', 'O'),\n", " ('Nederlandse', 'B-MISC'),\n", " ('politicus', 'O'),\n", " ('die', 'O'),\n", " ('momenteel', 'O'),\n", " ('premier', 'O'),\n", " ('van', 'O'),\n", " ('Nederland', 'B-LOC'),\n", " ('is', 'O'),\n", " ('.', 'O')]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dutch_text = \"\"\"Marke Rutte is een Nederlandse politicus die momenteel premier van Nederland is.\"\"\"\n", "predictor.predict(dutch_text)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "predictor.save('/tmp/my_dutch_nermodel')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `predictor` can be re-loaded from 
disk with `load_predictor`:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.load_predictor('/tmp/my_dutch_nermodel')" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('Marke', 'B-PER'),\n", " ('Rutte', 'I-PER'),\n", " ('is', 'O'),\n", " ('een', 'O'),\n", " ('Nederlandse', 'B-MISC'),\n", " ('politicus', 'O'),\n", " ('die', 'O'),\n", " ('momenteel', 'O'),\n", " ('premier', 'O'),\n", " ('van', 'O'),\n", " ('Nederland', 'B-LOC'),\n", " ('is', 'O'),\n", " ('.', 'O')]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict(dutch_text)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }