{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "import ktrain\n", "from ktrain import text" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 1: Load and Preprocess Data\n", "\n", "\n", "The CoNLL2003 NER dataset can be downloaded from [here](https://github.com/amaiya/ktrain/tree/master/ktrain/tests/conll2003)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of sentences: 14041\n", "Number of words in the dataset: 23623\n", "Tags: ['B-MISC', 'I-MISC', 'I-PER', 'B-ORG', 'I-ORG', 'I-LOC', 'O', 'B-PER', 'B-LOC']\n", "Number of Labels: 9\n", "Longest sentence: 113 words\n" ] } ], "source": [ "TDATA = 'data/conll2003/train.txt'\n", "VDATA = 'data/conll2003/valid.txt'\n", "(trn, val, preproc) = text.entities_from_conll2003(TDATA, val_filepath=VDATA)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 2: Define a Model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "model = text.sequence_tagger('bilstm-crf', preproc)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "learner = ktrain.get_learner(model, train_data=trn, val_data=val)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 3: Train and Evaluate Model" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n", "439/439 [==============================] - 123s 281ms/step - 
loss: 9.4288 - val_loss: 9.3419\n", "Epoch 2/5\n", "439/439 [==============================] - 120s 274ms/step - loss: 9.0161 - val_loss: 9.2529\n", "Epoch 3/5\n", "439/439 [==============================] - 119s 271ms/step - loss: 8.9478 - val_loss: 9.2348\n", "Epoch 4/5\n", "439/439 [==============================] - 120s 273ms/step - loss: 8.9266 - val_loss: 9.2303\n", "Epoch 5/5\n", "439/439 [==============================] - 120s 273ms/step - loss: 8.9159 - val_loss: 9.2310\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.fit(0.001, 5)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " F1: 87.20\n", " precision recall f1-score support\n", "\n", " LOC 0.87 0.94 0.91 1837\n", " ORG 0.82 0.80 0.81 1341\n", " MISC 0.88 0.78 0.83 922\n", " PER 0.89 0.91 0.90 1842\n", "\n", "micro avg 0.87 0.88 0.87 5942\n", "macro avg 0.87 0.88 0.87 5942\n", "\n" ] } ], "source": [ "learner.validate()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use the `view_top_losses` method to inspect the sentences the model gets most wrong. Here, we can see that the model has trouble with movie titles, which is understandable since they are lumped into the catch-all miscellaneous (`MISC`) category." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total incorrect: 12\n", "Word True : (Pred)\n", "==============================\n", "Best :O (B-ORG)\n", "known :O (O)\n", "for :O (O)\n", "appearances :O (O)\n", "in :O (O)\n", "\" :O (O)\n", "Ice :B-MISC (O)\n", "Cold :I-MISC (O)\n", "in :I-MISC (O)\n", "Alex :I-MISC (B-PER)\n", "\" :O (O)\n", ", :O (O)\n", "\" :O (O)\n", "Lawrence :B-MISC (B-PER)\n", "of :I-MISC (I-MISC)\n", "Arabia :I-MISC (I-LOC)\n", "\" :O (O)\n", "and :O (O)\n", ", :O (O)\n", "as :O (O)\n", "Cardinal :O (B-PER)\n", "Wolsey :B-PER (I-PER)\n", ", :O (O)\n", "in :O (O)\n", "\" :O (O)\n", "Anne :B-MISC (B-PER)\n", "of :I-MISC (O)\n", "a :I-MISC (O)\n", "Thousand :I-MISC (I-MISC)\n", "Days :I-MISC (I-MISC)\n", "\" :O (O)\n", ". :O (O)\n", "\n", "\n" ] } ], "source": [ "learner.view_top_losses(n=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 4: Make Predictions on New Data" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('As', 'O'),\n", " ('of', 'O'),\n", " ('2019', 'O'),\n", " (',', 'O'),\n", " ('Donald', 'B-PER'),\n", " ('Trump', 'I-PER'),\n", " ('is', 'O'),\n", " ('still', 'O'),\n", " ('the', 'O'),\n", " ('President', 'O'),\n", " ('of', 'O'),\n", " ('the', 'O'),\n", " ('United', 'B-LOC'),\n", " ('States', 'I-LOC'),\n", " ('.', 'O')]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict('As of 2019, Donald Trump is still the President of the United States.')" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "predictor.save('/tmp/mypred')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ 
"reloaded_predictor = ktrain.load_predictor('/tmp/mypred')" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('Paul', 'B-PER'),\n", " ('Newman', 'I-PER'),\n", " ('is', 'O'),\n", " ('my', 'O'),\n", " ('favorite', 'O'),\n", " ('actor', 'O'),\n", " ('.', 'O')]" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reloaded_predictor.predict('Paul Newman is my favorite actor.')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }