{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.16.1\n" ] } ], "source": [ "from __future__ import print_function  # later cells call print() with end=, and the kernel is Python 2\n", "from itertools import chain\n", "import nltk\n", "from sklearn.metrics import classification_report, confusion_matrix\n", "from sklearn.preprocessing import LabelBinarizer\n", "import sklearn\n", "import pycrfsuite\n", "\n", "print(sklearn.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Let's use CoNLL 2002 data to build an NER system\n", "\n", "The CoNLL 2002 corpus is available in NLTK. We use the Spanish data." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "[u'esp.testa',\n", " u'esp.testb',\n", " u'esp.train',\n", " u'ned.testa',\n", " u'ned.testb',\n", " u'ned.train']" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nltk.corpus.conll2002.fileids()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.42 s, sys: 70.4 ms, total: 2.49 s\n", "Wall time: 2.55 s\n" ] } ], "source": [ "%%time\n", "train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))\n", "test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Data format:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "[(u'Melbourne', u'NP', u'B-LOC'),\n", " (u'(', u'Fpa', u'O'),\n", " (u'Australia', u'NP', u'B-LOC'),\n", " (u')', u'Fpt', u'O'),\n", " (u',', u'Fc', u'O'),\n", " (u'25', u'Z', u'O'),\n", " (u'may', u'NC', u'O'),\n", " (u'(', u'Fpa', u'O'),\n", " (u'EFE', u'NC', u'B-ORG'),\n", " (u')', u'Fpt', u'O'),\n", " (u'.', u'Fp', u'O')]" ] }, "execution_count": 4, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "train_sents[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Features\n", "\n", "Next, define some features. In this example we use word identity, word suffix, word shape and word POS tag; also, some information from nearby words is used. \n", "\n", "This makes a simple baseline, but you certainly can add and remove some features to get (much?) better results - experiment with it." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def word2features(sent, i):\n", " word = sent[i][0]\n", " postag = sent[i][1]\n", " features = [\n", " 'bias',\n", " 'word.lower=' + word.lower(),\n", " 'word[-3:]=' + word[-3:],\n", " 'word[-2:]=' + word[-2:],\n", " 'word.isupper=%s' % word.isupper(),\n", " 'word.istitle=%s' % word.istitle(),\n", " 'word.isdigit=%s' % word.isdigit(),\n", " 'postag=' + postag,\n", " 'postag[:2]=' + postag[:2],\n", " ]\n", " if i > 0:\n", " word1 = sent[i-1][0]\n", " postag1 = sent[i-1][1]\n", " features.extend([\n", " '-1:word.lower=' + word1.lower(),\n", " '-1:word.istitle=%s' % word1.istitle(),\n", " '-1:word.isupper=%s' % word1.isupper(),\n", " '-1:postag=' + postag1,\n", " '-1:postag[:2]=' + postag1[:2],\n", " ])\n", " else:\n", " features.append('BOS')\n", " \n", " if i < len(sent)-1:\n", " word1 = sent[i+1][0]\n", " postag1 = sent[i+1][1]\n", " features.extend([\n", " '+1:word.lower=' + word1.lower(),\n", " '+1:word.istitle=%s' % word1.istitle(),\n", " '+1:word.isupper=%s' % word1.isupper(),\n", " '+1:postag=' + postag1,\n", " '+1:postag[:2]=' + postag1[:2],\n", " ])\n", " else:\n", " features.append('EOS')\n", " \n", " return features\n", "\n", "\n", "def sent2features(sent):\n", " return [word2features(sent, i) for i in range(len(sent))]\n", "\n", "def sent2labels(sent):\n", " return [label for token, postag, label in sent]\n", "\n", "def sent2tokens(sent):\n", " return [token for token, postag, label in sent] " ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "This is what word2features extracts:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "['bias',\n", " u'word.lower=melbourne',\n", " u'word[-3:]=rne',\n", " u'word[-2:]=ne',\n", " 'word.isupper=False',\n", " 'word.istitle=True',\n", " 'word.isdigit=False',\n", " u'postag=NP',\n", " u'postag[:2]=NP',\n", " 'BOS',\n", " u'+1:word.lower=(',\n", " '+1:word.istitle=False',\n", " '+1:word.isupper=False',\n", " u'+1:postag=Fpa',\n", " u'+1:postag[:2]=Fp']" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sent2features(train_sents[0])[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Extract the features from the data:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.24 s, sys: 287 ms, total: 2.53 s\n", "Wall time: 2.53 s\n" ] } ], "source": [ "%%time\n", "X_train = [sent2features(s) for s in train_sents]\n", "y_train = [sent2labels(s) for s in train_sents]\n", "\n", "X_test = [sent2features(s) for s in test_sents]\n", "y_test = [sent2labels(s) for s in test_sents]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the model\n", "\n", "To train the model, we create pycrfsuite.Trainer, load the training data and call 'train' method. 
\n", "First, create a pycrfsuite.Trainer and load the training data into CRFsuite:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3.48 s, sys: 90.2 ms, total: 3.57 s\n", "Wall time: 3.56 s\n" ] } ], "source": [ "%%time\n", "trainer = pycrfsuite.Trainer(verbose=False)\n", "\n", "for xseq, yseq in zip(X_train, y_train):\n", "    trainer.append(xseq, yseq)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set the training parameters. We will use the L-BFGS training algorithm (the default) with Elastic Net (L1 + L2) regularization." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [], "source": [ "trainer.set_params({\n", "    'c1': 1.0,   # coefficient for L1 penalty\n", "    'c2': 1e-3,  # coefficient for L2 penalty\n", "    'max_iterations': 50,  # stop earlier\n", "\n", "    # include transitions that are possible, but not observed\n", "    'feature.possible_transitions': True\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Possible parameters for the default training algorithm:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "['feature.minfreq',\n", " 'feature.possible_states',\n", " 'feature.possible_transitions',\n", " 'c1',\n", " 'c2',\n", " 'max_iterations',\n", " 'num_memories',\n", " 'epsilon',\n", " 'period',\n", " 'delta',\n", " 'linesearch',\n", " 'max_linesearch']" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.params()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train the model:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 18.8 s, sys: 102 ms, total: 18.9 s\n", "Wall time: 19.2 s\n" ] } ], "source": [ 
"%%time\n", "trainer.train('conll2002-esp.crfsuite')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "trainer.train saves model to a file:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-rw-r--r-- 1 gsh25 staff 600K Jun 22 14:56 ./conll2002-esp.crfsuite\r\n" ] } ], "source": [ "!ls -lh ./conll2002-esp.crfsuite" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also get information about the final state of the model by looking at the trainer's logparser. If we had tagged our input data using the optional group argument in add, and had used the optional holdout argument during train, there would be information about the trainer's performance on the holdout set as well. " ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "{'active_features': 11346,\n", " 'error_norm': 1262.912078,\n", " 'feature_norm': 79.110017,\n", " 'linesearch_step': 1.0,\n", " 'linesearch_trials': 1,\n", " 'loss': 14807.577946,\n", " 'num': 50,\n", " 'scores': {},\n", " 'time': 0.342}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.logparser.last_iteration" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also get this information for every step using trainer.logparser.iterations" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "50 {'loss': 14807.577946, 'error_norm': 1262.912078, 'linesearch_trials': 1, 'active_features': 11346, 'num': 50, 'time': 0.342, 'scores': {}, 'linesearch_step': 1.0, 'feature_norm': 79.110017}\n" ] } ], "source": [ "print len(trainer.logparser.iterations), trainer.logparser.iterations[-1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Make predictions\n", "\n", 
"To use the trained model, create pycrfsuite.Tagger, open the model and use \"tag\" method:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tagger = pycrfsuite.Tagger()\n", "tagger.open('conll2002-esp.crfsuite')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's tag a sentence to see how it works:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "La Coruña , 23 may ( EFECOM ) .\n", "\n", "Predicted: B-LOC I-LOC O O O O B-ORG O O\n", "Correct: B-LOC I-LOC O O O O B-ORG O O\n" ] } ], "source": [ "example_sent = test_sents[0]\n", "print(' '.join(sent2tokens(example_sent)), end='\\n\\n')\n", "\n", "print(\"Predicted:\", ' '.join(tagger.tag(sent2features(example_sent))))\n", "print(\"Correct: \", ' '.join(sent2labels(example_sent)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate the model" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def bio_classification_report(y_true, y_pred):\n", " \"\"\"\n", " Classification report for a list of BIO-encoded sequences.\n", " It computes token-level metrics and discards \"O\" labels.\n", " \n", " Note that it requires scikit-learn 0.15+ (or a version from github master)\n", " to calculate averages properly!\n", " \"\"\"\n", " lb = LabelBinarizer()\n", " y_true_combined = lb.fit_transform(list(chain.from_iterable(y_true)))\n", " y_pred_combined = lb.transform(list(chain.from_iterable(y_pred)))\n", " \n", " tagset = set(lb.classes_) - {'O'}\n", " tagset = sorted(tagset, key=lambda tag: tag.split('-', 1)[::-1])\n", " class_indices = {cls: idx for idx, cls in enumerate(lb.classes_)}\n", " \n", " return classification_report(\n", " 
y_true_combined,\n", "        y_pred_combined,\n", "        labels=[class_indices[cls] for cls in tagset],\n", "        target_names=tagset,\n", "    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Predict entity labels for all sentences in our testing set ('testb' Spanish data):" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 598 ms, sys: 17.4 ms, total: 616 ms\n", "Wall time: 615 ms\n" ] } ], "source": [ "%%time\n", "y_pred = [tagger.tag(xseq) for xseq in X_test]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "...and check the result. Note that this report is not comparable to the results in the CoNLL 2002 papers, because here we check per-token results (not per-entity). Per-entity numbers will be worse." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "             precision    recall  f1-score   support\n", "\n", "      B-LOC       0.78      0.75      0.76      1084\n", "      I-LOC       0.87      0.93      0.90       634\n", "     B-MISC       0.69      0.47      0.56       339\n", "     I-MISC       0.87      0.93      0.90       634\n", "      B-ORG       0.82      0.87      0.84       735\n", "      I-ORG       0.87      0.93      0.90       634\n", "      B-PER       0.61      0.49      0.54       557\n", "      I-PER       0.87      0.93      0.90       634\n", "\n", "avg / total       0.81      0.81      0.80      5251\n", "\n" ] } ], "source": [ "print(bio_classification_report(y_test, y_pred))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Let's check what the classifier learned" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Top likely transitions:\n", "B-ORG -> I-ORG 8.631963\n", "I-ORG -> I-ORG 7.833706\n", "B-PER -> I-PER 6.998706\n", "B-LOC -> I-LOC 6.913675\n", "I-MISC -> I-MISC 6.129735\n", "B-MISC -> I-MISC 5.538291\n", "I-LOC -> I-LOC 4.983567\n", "I-PER -> I-PER 3.748358\n", "B-ORG -> B-LOC 1.727090\n", "B-PER -> B-LOC 1.388267\n", "B-LOC -> 
B-LOC 1.240278\n", "O -> O 1.197929\n", "O -> B-ORG 1.097062\n", "I-PER -> B-LOC 1.083332\n", "O -> B-MISC 1.046113\n", "\n", "Top unlikely transitions:\n", "I-PER -> B-ORG -2.056130\n", "I-LOC -> I-ORG -2.143940\n", "B-ORG -> I-MISC -2.167501\n", "I-PER -> I-ORG -2.369380\n", "B-ORG -> I-PER -2.378110\n", "I-MISC -> I-PER -2.458788\n", "B-LOC -> I-PER -2.516414\n", "I-ORG -> I-MISC -2.571973\n", "I-LOC -> B-PER -2.697791\n", "I-LOC -> I-PER -3.065950\n", "I-ORG -> I-PER -3.364434\n", "O -> I-PER -7.322841\n", "O -> I-MISC -7.648246\n", "O -> I-ORG -8.024126\n", "O -> I-LOC -8.333815\n" ] } ], "source": [ "from collections import Counter\n", "info = tagger.info()\n", "\n", "def print_transitions(trans_features):\n", "    for (label_from, label_to), weight in trans_features:\n", "        print(\"%-6s -> %-7s %0.6f\" % (label_from, label_to, weight))\n", "\n", "print(\"Top likely transitions:\")\n", "print_transitions(Counter(info.transitions).most_common(15))\n", "\n", "print(\"\\nTop unlikely transitions:\")\n", "print_transitions(Counter(info.transitions).most_common()[-15:])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that, for example, it is very likely that the beginning of an organization name (B-ORG) will be followed by a token inside an organization name (I-ORG), but transitions to I-ORG from tokens that are not part of an organization name are penalized. 
Also note the I-PER -> B-LOC transition: its positive weight means that the model thinks a person name is often followed by a location.\n", "\n", "Check the state features:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Top positive:\n", "8.886516 B-ORG word.lower=efe-cantabria\n", "8.743642 B-ORG word.lower=psoe-progresistas\n", "5.769032 B-LOC -1:word.lower=cantabria\n", "5.195429 I-LOC -1:word.lower=calle\n", "5.116821 O word.lower=mayo\n", "4.990871 O -1:word.lower=día\n", "4.910915 I-ORG -1:word.lower=l\n", "4.721572 B-MISC word.lower=diversia\n", "4.676259 B-ORG word.lower=telefónica\n", "4.334354 B-ORG word[-2:]=-e\n", "4.149862 B-ORG word.lower=amena\n", "4.141370 B-ORG word.lower=terra\n", "3.942852 O word.istitle=False\n", "3.926397 B-ORG word.lower=continente\n", "3.924672 B-ORG word.lower=acesa\n", "3.888706 O word.lower=euro\n", "3.856445 B-PER -1:word.lower=según\n", "3.812373 B-MISC word.lower=exteriores\n", "3.807582 I-MISC -1:word.lower=1.9\n", "3.807098 B-MISC word.lower=sanidad\n", "\n", "Top negative:\n", "-1.965379 O word.lower=fundación\n", "-1.981541 O -1:word.lower=británica\n", "-2.118347 O word.lower=061\n", "-2.190653 B-PER word[-3:]=nes\n", "-2.226373 B-ORG postag=SP\n", "-2.226373 B-ORG postag[:2]=SP\n", "-2.260972 O word[-3:]=uia\n", "-2.384920 O -1:word.lower=sección\n", "-2.483009 O word[-2:]=s.\n", "-2.535050 I-LOC BOS\n", "-2.583123 O -1:word.lower=sánchez\n", "-2.585756 O postag[:2]=NP\n", "-2.585756 O postag=NP\n", "-2.588899 O word[-2:]=om\n", "-2.738583 O -1:word.lower=carretera\n", "-2.913103 O word.istitle=True\n", "-2.926560 O word[-2:]=nd\n", "-2.946862 I-PER -1:word.lower=san\n", "-2.954094 B-PER -1:word.lower=del\n", "-3.529449 O word.isupper=True\n" ] } ], "source": [ "def print_state_features(state_features):\n", "    for (attr, label), weight in state_features:\n", "        print(\"%0.6f %-6s %s\" % (weight, label, attr))
\n", "\n", "print(\"Top positive:\")\n", "print_state_features(Counter(info.state_features).most_common(20))\n", "\n", "print(\"\\nTop negative:\")\n", "print_state_features(Counter(info.state_features).most_common()[-20:])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some observations:\n", "\n", "* **8.743642 B-ORG word.lower=psoe-progresistas** - the model memorized the names of some entities; maybe it is overfitting, maybe our features are not adequate, or maybe memorizing is indeed helpful;\n", "* **5.195429 I-LOC -1:word.lower=calle** - \"calle\" means \"street\" in Spanish; the model learns that if the previous word was \"calle\" then the token is likely part of a location;\n", "* **-3.529449 O word.isupper=True**, **-2.913103 O word.istitle=True** - UPPERCASED or TitleCased words are likely entities of some kind;\n", "* **-2.585756 O postag=NP** - proper nouns (NP is the proper noun tag in the Spanish tagset) are often entities." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What to do next\n", "\n", "1. Load the 'testa' Spanish data.\n", "2. Use it to develop better features and to find the best model parameters.\n", "3. Apply the model to the 'testb' data again.\n", "\n", "The model in this notebook is just a starting point; you can certainly do better!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.11" } }, "nbformat": 4, "nbformat_minor": 0 }