{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import ktrain\n", "from ktrain import text" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 1: Load and Preprocess Data\n", "\n", "\n", "The CoNLL2003 NER dataset can be downloaded from [here](https://github.com/amaiya/ktrain/tree/master/ktrain/tests/conll2003)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "detected encoding: utf-8 (if wrong, set manually)\n", "Number of sentences: 14041\n", "Number of words in the dataset: 23623\n", "Tags: ['B-PER', 'O', 'I-MISC', 'B-ORG', 'I-LOC', 'I-ORG', 'I-PER', 'B-MISC', 'B-LOC']\n", "Number of Labels: 9\n", "Longest sentence: 113 words\n" ] } ], "source": [ "TDATA = 'data/conll2003/train.txt'\n", "VDATA = 'data/conll2003/valid.txt'\n", "(trn, val, preproc) = text.entities_from_conll2003(TDATA, val_filepath=VDATA)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 2: Define a Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example notebook, we will build a Bidirectional LSTM model that employs the use of [pretrained BERT word embeddings](https://arxiv.org/abs/1810.04805). By default, the `sequence_tagger` will use a pretrained multilingual model (i.e., `bert-base-multilingual-cased`) that supports 157 different languages. However, since we are training a English-language model on an English-only dataset, it is better to select the English pretrained BERT model: `bert-base-cased`. Notice that we selected the **cased** model, as case is important for English NER, as entities are often capitalized. A full list of available pretrained models is [listed here](https://huggingface.co/transformers/pretrained_models.html). *ktrain* currently supports any `bert-*` model in addition to any `distilbert-*` model. One can also employ the use of BERT-based [community-uploaded moels](https://huggingface.co/models) that focus on specific domains such as the biomedical or scientific domains (e.g, BioBERT, SciBERT). To use SciBERT, for example, set `bert_model` to `allenai/scibert_scivocab_uncased`. 
" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "bilstm: Bidirectional LSTM (https://arxiv.org/abs/1603.01360)\n", "bilstm-bert: Bidirectional LSTM w/ BERT embeddings\n", "bilstm-crf: Bidirectional LSTM-CRF (https://arxiv.org/abs/1603.01360)\n", "bilstm-elmo: Bidirectional LSTM w/ Elmo embeddings [English only]\n", "bilstm-crf-elmo: Bidirectional LSTM-CRF w/ Elmo embeddings [English only]\n" ] } ], "source": [ "text.print_sequence_taggers()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Embedding schemes employed (combined with concatenation):\n", "\tword embeddings initialized randomly\n", "\tBERT embeddings with bert-base-cased\n", "\n" ] } ], "source": [ "model = text.sequence_tagger('bilstm-bert', preproc, bert_model='bert-base-cased')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From the output above, we see that the model is configured to use both BERT pretrained word embeddings and randomly-initialized word emeddings. Instead of randomly-initialized word vectors, one can also select pretrained fasttext word vectors from [Facebook's fasttext site](https://fasttext.cc/docs/en/crawl-vectors.html) and supply the URL via the `wv_path_or_url` parameter:\n", "```python\n", "wv_path_or_url='https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.vec.gz')\n", "```\n", "We have not used fasttext word embeddings in this example - only BERT word embeddings." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=128)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 3: Train and Evaluate Model" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "preparing train data ...done.\n", "preparing valid data ...done.\n", "Train for 110 steps, validate for 26 steps\n", "Epoch 1/10\n", "110/110 [==============================] - 61s 551ms/step - loss: 0.1072 - val_loss: 0.0301\n", "Epoch 2/10\n", "110/110 [==============================] - 54s 489ms/step - loss: 0.0280 - val_loss: 0.0214\n", "Epoch 3/10\n", "110/110 [==============================] - 53s 486ms/step - loss: 0.0159 - val_loss: 0.0179\n", "Epoch 4/10\n", "110/110 [==============================] - 54s 487ms/step - loss: 0.0104 - val_loss: 0.0165\n", "Epoch 5/10\n", "110/110 [==============================] - 53s 484ms/step - loss: 0.0087 - val_loss: 0.0161\n", "Epoch 6/10\n", "110/110 [==============================] - 53s 481ms/step - loss: 0.0129 - val_loss: 0.0176\n", "Epoch 7/10\n", "110/110 [==============================] - 53s 480ms/step - loss: 0.0094 - val_loss: 0.0167\n", "Epoch 8/10\n", "110/110 [==============================] - 53s 481ms/step - loss: 0.0060 - val_loss: 0.0164\n", "Epoch 9/10\n", "110/110 [==============================] - 53s 485ms/step - loss: 0.0041 - val_loss: 0.0155\n", "Epoch 10/10\n", "110/110 [==============================] - 53s 486ms/step - loss: 0.0033 - val_loss: 0.0157\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.fit(0.01, 2, cycle_len=5)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " F1: 92.62\n", " 
 { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " F1: 92.62\n", "           precision    recall  f1-score   support\n", "\n", "      MISC       0.84      0.84      0.84       922\n", "       PER       0.96      0.96      0.96      1842\n", "       LOC       0.95      0.96      0.96      1837\n", "       ORG       0.88      0.92      0.90      1341\n", "\n", " micro avg       0.92      0.93      0.93      5942\n", " macro avg       0.92      0.93      0.93      5942\n", "\n" ] }, { "data": { "text/plain": [ "0.9262014208106979" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.validate()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use the `view_top_losses` method to inspect the sentences we're getting the most wrong. Here, we can see that our model has trouble with titles, which is understandable since titles are mixed into the catch-all miscellaneous (MISC) category." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total incorrect: 11\n", "Word True : (Pred)\n", "==============================\n", "The :O (O)\n", "titles :O (O)\n", "of :O (O)\n", "his :O (O)\n", "other :O (O)\n", "novels :O (O)\n", "translate :O (O)\n", "as :O (O)\n", "\" :O (O)\n", "In :B-MISC (O)\n", "the :I-MISC (O)\n", "Year :I-MISC (O)\n", "of :I-MISC (I-MISC)\n", "January :I-MISC (O)\n", "\" :O (O)\n", "( :O (O)\n", "1963 :O (O)\n", ") :O (O)\n", ", :O (O)\n", "\" :O (O)\n", "The :B-MISC (O)\n", "Collapse :I-MISC (O)\n", "\" :O (O)\n", "( :O (O)\n", "1964 :O (O)\n", ") :O (O)\n", ", :O (O)\n", "\" :O (O)\n", "Sleeping :B-MISC (B-MISC)\n", "Bread :I-MISC (I-MISC)\n", "\" :O (O)\n", "( :O (O)\n", "1975 :O (O)\n", ") :O (O)\n", ", :O (O)\n", "\" :O (O)\n", "The :B-MISC (O)\n", "Decaying :I-MISC (B-MISC)\n", "Mansion :I-MISC (I-MISC)\n", "\" :O (O)\n", "( :O (O)\n", "1977 :O (O)\n", ") :O (O)\n", "and :O (O)\n", "\" :O (O)\n", "A :B-MISC (B-MISC)\n", "World :I-MISC (I-MISC)\n", "of :I-MISC (I-MISC)\n", "Things :I-MISC (O)\n", "\" :O (O)\n", "( :O (O)\n", "1982 :O (O)\n", ") :O (O)\n", ", :O (O)\n", "followed :O (O)\n", "by :O (O)\n", "\" :O (O)\n", "The :B-MISC (O)\n", "Knot :I-MISC (O)\n", ", :O (O)\n", "\" :O (O)\n", "\" :O (O)\n", "Soul :B-MISC (B-MISC)\n", "Alone :I-MISC (I-MISC)\n", "\" :O (O)\n", "and :O (O)\n", ", :O (O)\n", "most :O (O)\n", "recently :O (O)\n", ", :O (O)\n", "\" :O (O)\n", "A :B-MISC (B-MISC)\n", "Woman :I-MISC (I-MISC)\n", ". :O (O)\n", "\" :O (O)\n", "\n", "\n" ] } ], "source": [ "learner.view_top_losses(n=1)" ] },
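 { "cell_type": "markdown", "metadata": {}, "source": [ "To see how the training and validation losses evolved over the ten epochs, the training history can also be plotted. A minimal sketch, assuming the standard *ktrain* `plot` method is available on the learner:\n", "```python\n", "learner.plot('loss')  # plot training and validation loss by epoch\n", "```" ] },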
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Make Predictions on New Data" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('As', 'O'),\n", " ('of', 'O'),\n", " ('2019', 'O'),\n", " (',', 'O'),\n", " ('Donald', 'B-PER'),\n", " ('Trump', 'I-PER'),\n", " ('is', 'O'),\n", " ('still', 'O'),\n", " ('the', 'O'),\n", " ('President', 'O'),\n", " ('of', 'O'),\n", " ('the', 'O'),\n", " ('United', 'B-LOC'),\n", " ('States', 'I-LOC'),\n", " ('.', 'O')]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict('As of 2019, Donald Trump is still the President of the United States.')" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "predictor.save('/tmp/mypred')" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "reloaded_predictor = ktrain.load_predictor('/tmp/mypred')" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('Paul', 'B-PER'),\n", " ('Newman', 'I-PER'),\n", " ('is', 'O'),\n", " ('my', 'O'),\n", " ('favorite', 'O'),\n", " ('actor', 'O'),\n", " ('.', 'O')]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reloaded_predictor.predict('Paul Newman is my favorite actor.')" ] },
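 { "cell_type": "markdown", "metadata": {}, "source": [ "The predictor returns `(token, tag)` pairs with IOB-style tags. As a minimal sketch that relies only on the `predict` output format shown above, the entity tokens can be pulled out by filtering away the `O` tags:\n", "```python\n", "preds = reloaded_predictor.predict('Paul Newman is my favorite actor.')\n", "# keep only tokens tagged as part of a named entity\n", "entities = [(token, tag) for token, tag in preds if tag != 'O']\n", "print(entities)  # [('Paul', 'B-PER'), ('Newman', 'I-PER')]\n", "```" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }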