{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Named Entity Recognition (NER) with TensorFlow" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Slides: https://docs.google.com/presentation/d/1eUEOTSeUnR2Sz1uDF4e3YvBaxQ1kLUjog9qkhgBEMV0/edit?usp=sharing" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import pandas as pd\n", "import numpy as np\n", "import csv" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data Preparation\n", "\n", "The data comes from the Kaggle entity-annotated corpus. Download `ner_dataset.csv` from the ZIP archive at https://www.kaggle.com/abhinavwalia95/entity-annotated-corpus/data\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "validation_sentence = 'While speaking on Channels Television on Thursday April 5 2018 Adesina said the fund is not just to intensify the military fight against Boko Haram but to fight other forms of insecurity in the country'\n", "\n", "validation_tags = ['O', 'O', 'O', 'B-ORG', 'I-ORG', 'O', 'B-TIM', 'I-TIM', 'I-TIM', 'I-TIM', 'B-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O',\n", "                   'O', 'B-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below we parse the file to load sentences and tags into separate lists. 
We also keep only sentences of at most `max_length` (50) words, which leaves room for our 35-word validation sentence above." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "sentences = []\n", "tags = []\n", "max_length = 50\n", "\n", "with open('data/ner_dataset.csv', 'rb') as csvfile:\n", "    ner_data = csv.reader(csvfile, delimiter=',')\n", "    sentence = []\n", "    tag = []\n", "    for row in ner_data:\n", "\n", "        # skip the header row\n", "        if row[3] == 'Tag':\n", "            continue\n", "\n", "        sentence.append(row[1])\n", "        tag.append(row[3].upper())\n", "\n", "        # a full stop marks the end of a sentence\n", "        if row[1] == '.':\n", "            if len(sentence) <= max_length:\n", "                sentences.append(sentence)\n", "                tags.append(tag)\n", "            sentence = []\n", "            tag = []" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below are sample entries of `sentences` and `tags`." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[['Thousands', 'of', 'demonstrators', 'have', 'marched', 'through', 'London', 'to', 'protest', 'the', 'war', 'in', 'Iraq', 'and', 'demand', 'the', 'withdrawal', 'of', 'British', 'troops', 'from', 'that', 'country', '.'], ['Families', 'of', 'soldiers', 'killed', 'in', 'the', 'conflict', 'joined', 'the', 'protesters', 'who', 'carried', 'banners', 'with', 'such', 'slogans', 'as', '\"', 'Bush', 'Number', 'One', 'Terrorist', '\"', 'and', '\"', 'Stop', 'the', 'Bombings', '.']]\n", "\n", "[['O', 'O', 'O', 'O', 'O', 'O', 'B-GEO', 'O', 'O', 'O', 'O', 'O', 'B-GEO', 'O', 'O', 'O', 'O', 'O', 'B-GPE', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']]\n" ] } ], "source": [ "print sentences[:2]\n", "print\n", "print tags[:2]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll need to create a vocabulary from our sentences, i.e. a set of unique words. 
We'll do the same for the tags." ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "unique_tags = list(set(t for tagset in tags for t in tagset))\n", "vocabulary = list(set(word for sentence in sentences for word in sentence))" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[' ', 'I-ART', 'B-EVE', 'B-GPE', 'B-TIM', 'I-TIM', 'B-ORG', 'B-ART', 'O', 'B-GEO', 'I-GPE', 'I-EVE', 'I-GEO', 'B-PER', 'I-PER', 'I-ORG', 'I-NAT', 'B-NAT']\n" ] } ], "source": [ "print unique_tags" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['heavily-fortified', 'mid-week', '1,800', 'land-for-peace', 'Pronk', 'woods', 'Safarova', 'Nampo', 'hanging', 'trawling']\n", "Number of words in vocabulary 34930\n" ] } ], "source": [ "print vocabulary[:10]\n", "print 'Number of words in vocabulary', len(vocabulary)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# 70/30 train/test split; both halves share one index so no sentence is dropped\n", "split = int(.7 * len(sentences))\n", "\n", "train_sentences = sentences[:split]\n", "train_tags = tags[:split]\n", "\n", "test_sentences = sentences[split:]\n", "test_tags = tags[split:]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(33341, 14290, 47631)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(train_sentences), len(test_sentences), len(sentences)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model Architecture\n", "A simple LSTM network with a softmax layer at the end.\n", "\n", "Important NOTE: If you want to run the network using a one-hot encoding of the words, make sure `batch_size` is set to something low, because every word becomes a vector as long as the vocabulary (34,930 entries here). Higher values might result in your computer freezing. I tried it on my Core i5 laptop with 8GB of RAM and it wasn't pleasant. 
So stick with the default `batch_size` of 8, or go lower." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Parameters\n", "learning_rate = 0.001\n", "batch_size = 8\n", "target_size = len(unique_tags)\n", "display_size = 50\n", "\n", "# Network Parameters\n", "n_features = len(vocabulary)\n", "sequence_length = max_length\n", "n_units = 64\n", "\n", "tf.reset_default_graph()\n", "\n", "# tf Graph input\n", "X = tf.placeholder('float', [None, max_length, n_features], name='X')\n", "Y = tf.placeholder('float', [None, max_length, target_size], name='Y')\n", "\n", "# Define weights\n", "weights = {\n", "    'out': tf.Variable(tf.random_normal([n_units, target_size]))\n", "}\n", "biases = {\n", "    'out': tf.Variable(tf.random_normal([target_size]))\n", "}" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/anaconda2/lib/python2.7/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. 
In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", " from ._conv import register_converters as _register_converters\n" ] } ], "source": [ "\n", "# n_units is the LSTM hidden size defined in the parameters cell\n", "cell = tf.contrib.rnn.LSTMCell(n_units, state_is_tuple=True)\n", "\n", "output, _ = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)\n", "\n", "# flatten to one row per word so a single dense layer scores every position\n", "output = tf.reshape(output, [-1, n_units])\n", "\n", "prediction = tf.matmul(output, weights['out']) + biases['out']\n", "cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=tf.reshape(Y, [-1, target_size])))\n", "\n", "minimize = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cross_entropy)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of batches: 4167\n" ] } ], "source": [ "init = tf.global_variables_initializer()\n", "num_batches = int(len(train_sentences)) / batch_size\n", "epoch = 1\n", "print 'Number of batches:', num_batches" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "33341" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(train_sentences)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run graph using one-hot encoding of words" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "with tf.Session() as sess:\n", "    sess.run(init)\n", "    for i in range(epoch):\n", "        # restart from the first sentence once per epoch, not once per batch\n", "        ptr = 0\n", "        for j in range(num_batches):\n", "            batch_X = []\n", "            batch_Y = []\n", "            sequence_length = []\n", "            for _ in range(batch_size):\n", "                x, y = (train_sentences[ptr: ptr + 1],\n", "                        train_tags[ptr: ptr + 1])\n", "\n", "                x_one_hot = []\n", "\n", "                for s in x[0]:\n", "                    # one-hot vector for the word (avoids building a full identity matrix)\n", "                    one_hot = np.zeros(len(vocabulary))\n", "                    one_hot[vocabulary.index(s)] = 1\n", "                    x_one_hot.append(one_hot)\n", "\n", "                sequence_length.append(len(x_one_hot))\n", "\n", "                # pad the sentence to max_length with all-zero vectors\n", "                for remainder in range(max_length - len(x_one_hot)):\n", "                    
x_one_hot.append([0]*len(vocabulary))\n", "\n", "                batch_X.append(x_one_hot)\n", "\n", "                y_one_hot = []\n", "\n", "                for t in y[0]:\n", "                    y_one_hot.append(np.eye(target_size)[unique_tags.index(t)])\n", "\n", "                # pad the tags to max_length with the 'O' (outside) tag\n", "                for remainder in range(max_length - len(y_one_hot)):\n", "                    y_one_hot.append(np.eye(target_size)[unique_tags.index('O')])\n", "\n", "                batch_Y.append(y_one_hot)\n", "\n", "                ptr += 1\n", "\n", "            # this graph has no sequence-length placeholder, so only X and Y are fed\n", "            _, entropy, preds = sess.run([minimize, cross_entropy, prediction],\n", "                                         {X: np.array(batch_X).reshape(batch_size, max_length, len(vocabulary)),\n", "                                          Y: np.array(batch_Y).reshape(batch_size, max_length, target_size)})\n", "\n", "            if j % display_size == 0:\n", "                print 'Loss at batch {0}'.format(j), entropy\n", "\n", "        print \"Epoch \", str(i)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Word Embeddings\n", "We'll use Google's pre-trained word2vec vectors, which you can grab from https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit.\n", "To load the word embeddings, we'll need another tool, `gensim`." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "from gensim.models import word2vec, KeyedVectors" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Load the word vectors like so. This operation takes a good while on my Core i5 laptop." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "w2v = KeyedVectors.load_word2vec_format('/Users/h/Projects/Machine-Learning/GoogleNews-vectors-negative300.bin.gz', binary=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below is how `boy` is represented according to the embedding" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 2.35351562e-01, 1.65039062e-01, 9.32617188e-02, -1.28906250e-01,\n", " 1.59912109e-02, 3.61328125e-02, -1.16699219e-01, -7.32421875e-02,\n", " 1.38671875e-01, 1.15356445e-02, 1.87500000e-01, -2.91015625e-01,\n", " 1.70898438e-02, -1.84570312e-01, -2.87109375e-01, 2.54821777e-03,\n", " -2.19726562e-01, 1.77734375e-01, -1.20605469e-01, 5.39550781e-02,\n", " 3.78417969e-02, 2.49023438e-01, 1.76757812e-01, 2.69775391e-02,\n", " 1.21093750e-01, -3.51562500e-01, -5.83496094e-02, 1.22070312e-01,\n", " 5.97656250e-01, -1.60156250e-01, 1.08398438e-01, -2.40478516e-02,\n", " -1.16699219e-01, 3.58886719e-02, -2.37304688e-01, 1.15234375e-01,\n", " 5.27343750e-01, -2.18750000e-01, -4.54101562e-02, 3.30078125e-01,\n", " 3.75976562e-02, -5.51757812e-02, 3.26171875e-01, 6.74438477e-03,\n", " 3.71093750e-01, 3.68652344e-02, 6.68945312e-02, 5.17578125e-02,\n", " -4.76074219e-02, -7.91015625e-02, 4.46777344e-02, 1.67968750e-01,\n", " 5.51757812e-02, -2.91015625e-01, 2.59765625e-01, -1.00097656e-01,\n", " -1.09863281e-01, -9.15527344e-03, 2.63671875e-02, -3.44238281e-02,\n", " 9.37500000e-02, 3.53515625e-01, 8.39843750e-02, -7.75146484e-03,\n", " 8.64257812e-02, -5.24902344e-02, -5.59082031e-02, -8.59375000e-02,\n", " 5.37109375e-02, -1.47094727e-02, 3.63769531e-02, 4.68750000e-02,\n", " -3.39843750e-01, 1.28906250e-01, -1.22558594e-01, 4.57031250e-01,\n", " 1.27929688e-01, -2.89062500e-01, 1.56250000e-01, 3.73535156e-02,\n", " 2.75390625e-01, -1.28784180e-02, -1.50390625e-01, -1.64062500e-01,\n", " -3.39843750e-01, 
8.00781250e-02, -9.21630859e-03, 2.78320312e-02,\n", " 9.32617188e-02, 2.25830078e-02, -1.62353516e-02, -8.25195312e-02,\n", " -1.90429688e-02, -3.49121094e-02, 9.42382812e-02, 3.66210938e-02,\n", " 6.39648438e-02, 2.00195312e-01, -4.05273438e-02, -1.08886719e-01,\n", " -3.93676758e-03, -2.55859375e-01, 6.78710938e-02, -1.89453125e-01,\n", " 1.72851562e-01, -1.73828125e-01, 2.07031250e-01, -1.59179688e-01,\n", " 2.85339355e-03, -1.80664062e-01, -6.93359375e-02, 2.05078125e-01,\n", " 5.93261719e-02, -2.17773438e-01, -1.36718750e-01, -4.91333008e-03,\n", " -1.38671875e-01, -7.47070312e-02, -3.54003906e-02, 1.13769531e-01,\n", " 3.07617188e-02, -1.05957031e-01, -3.30078125e-01, -2.72216797e-02,\n", " -1.94091797e-02, 9.52148438e-02, 8.69140625e-02, -2.16064453e-02,\n", " -6.98242188e-02, -1.73828125e-01, -1.60156250e-01, -2.44140625e-01,\n", " 9.82666016e-03, 2.24609375e-02, -2.13867188e-01, 1.91406250e-01,\n", " 2.01171875e-01, 2.72216797e-02, 2.81982422e-02, 2.42187500e-01,\n", " 3.55468750e-01, -5.32226562e-02, 1.78710938e-01, 6.78710938e-02,\n", " -6.73828125e-02, 3.49609375e-01, -1.92382812e-01, -1.00097656e-02,\n", " -2.05078125e-01, -1.59179688e-01, 3.76953125e-01, -2.15820312e-01,\n", " -2.36328125e-01, 6.49414062e-02, -1.39770508e-02, 4.22363281e-02,\n", " 2.51464844e-02, -1.00585938e-01, 1.37695312e-01, -2.43164062e-01,\n", " 1.20605469e-01, 2.03857422e-02, 3.12500000e-01, 1.09863281e-01,\n", " -1.04980469e-01, -9.13085938e-02, 2.21679688e-01, -1.04003906e-01,\n", " 1.25976562e-01, 5.10253906e-02, 6.39648438e-02, -1.15722656e-01,\n", " -3.19824219e-02, -8.34960938e-02, -1.97265625e-01, -2.33154297e-02,\n", " 1.94335938e-01, 2.24609375e-01, -2.30468750e-01, 4.17480469e-02,\n", " 6.49414062e-02, -1.70898438e-01, 7.86132812e-02, -3.58886719e-02,\n", " -1.66015625e-01, 2.25585938e-01, 1.23535156e-01, 1.08398438e-01,\n", " 1.15722656e-01, 7.37304688e-02, -1.56250000e-02, -5.85937500e-02,\n", " -8.93554688e-02, 1.30859375e-01, 1.90429688e-01, -3.58886719e-02,\n", 
" -1.36718750e-02, -1.88476562e-01, -1.48437500e-01, -2.51953125e-01,\n", " -1.22558594e-01, -2.75390625e-01, -1.54296875e-01, -2.83203125e-01,\n", " 1.10839844e-01, -2.46093750e-01, 1.89453125e-01, -2.50244141e-02,\n", " 8.59375000e-02, -1.17675781e-01, -2.46582031e-02, -1.32812500e-01,\n", " 1.00097656e-01, -2.45117188e-01, -2.02148438e-01, -7.56835938e-02,\n", " 6.03027344e-02, 1.72851562e-01, -6.59179688e-02, 6.78710938e-02,\n", " 6.98242188e-02, -4.10156250e-02, 2.14843750e-01, 7.17773438e-02,\n", " -4.57763672e-03, -4.04357910e-04, 8.59375000e-02, -2.55859375e-01,\n", " -4.32128906e-02, -1.31835938e-01, 2.05078125e-02, -2.46093750e-01,\n", " -1.28906250e-01, 1.23535156e-01, -1.48437500e-01, 5.15136719e-02,\n", " -1.55273438e-01, -1.70898438e-01, 1.92382812e-01, 2.16796875e-01,\n", " 5.81054688e-02, -1.28906250e-01, -1.43554688e-01, -7.78198242e-03,\n", " -1.15234375e-01, 4.08203125e-01, -3.37890625e-01, 8.64257812e-02,\n", " 2.08007812e-01, 2.35595703e-02, 1.36718750e-01, -4.71191406e-02,\n", " 9.91210938e-02, 1.18164062e-01, 1.19140625e-01, 1.24511719e-01,\n", " 4.66308594e-02, 5.41992188e-02, -2.11914062e-01, -8.20312500e-02,\n", " -5.17578125e-02, 2.03857422e-02, -1.59179688e-01, -1.76757812e-01,\n", " 8.54492188e-02, 1.38671875e-01, -1.01562500e-01, 2.61230469e-02,\n", " -1.88476562e-01, -1.57470703e-02, 1.21093750e-01, -9.66796875e-02,\n", " 2.13623047e-02, -6.68945312e-02, -2.69775391e-02, 3.51562500e-02,\n", " 1.68945312e-01, 1.55639648e-02, -1.25976562e-01, -1.44531250e-01,\n", " 1.78710938e-01, -7.42187500e-02, 2.72216797e-02, 4.98046875e-01,\n", " -6.03027344e-02, -1.35742188e-01, -1.62109375e-01, 9.57031250e-02,\n", " -1.84326172e-02, 3.90625000e-01, 1.90429688e-02, -1.03149414e-02,\n", " -1.15234375e-01, -2.91015625e-01, -5.95703125e-02, -5.37109375e-02,\n", " -7.42187500e-02, -2.65625000e-01, -1.03027344e-01, 1.35742188e-01],\n", " dtype=float32)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"w2v.word_vec('boy')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run graph with words represented as word2vec\n", "The setup is similar to the previous one, except that `n_features` is now the dimension of the vector returned by word2vec, the LSTM is bidirectional, and padded positions are masked out of the loss." ] }, { "cell_type": "code", "execution_count": 151, "metadata": {}, "outputs": [], "source": [ "# Parameters\n", "learning_rate = 1e-5\n", "batch_size = 32\n", "target_size = len(unique_tags)\n", "display_size = 50\n", "\n", "# Network Parameters\n", "n_features = 300 # dimension of the vector returned by word2vec\n", "sequence_length = max_length\n", "n_units = 128\n", "\n", "tf.reset_default_graph()\n", "\n", "# tf Graph input\n", "X = tf.placeholder('float', [None, max_length, n_features], name='X')\n", "Y = tf.placeholder('int32', [None], name='Y')\n", "sequence_lengths = tf.placeholder('int32', [None])\n", "\n", "# Define weights\n", "weights = {\n", "    'out': tf.Variable(tf.truncated_normal([2 * n_units, target_size]))\n", "}\n", "biases = {\n", "    'out': tf.Variable(tf.constant(0.1, shape=[target_size]))\n", "}" ] }, { "cell_type": "code", "execution_count": 152, "metadata": {}, "outputs": [], "source": [ "\n", "with tf.variable_scope('forward'):\n", "    lstm_fw_cell = tf.contrib.rnn.LSTMCell(n_units, state_is_tuple=True)\n", "\n", "with tf.variable_scope('backward'):\n", "    lstm_bw_cell = tf.contrib.rnn.LSTMCell(n_units, state_is_tuple=True)\n", "\n", "(output_fw, output_bw), states = tf.nn.bidirectional_dynamic_rnn(cell_fw=lstm_fw_cell,\n", "                                                                 cell_bw=lstm_bw_cell,\n", "                                                                 inputs=X,\n", "                                                                 sequence_length=sequence_lengths,\n", "                                                                 dtype=tf.float32)\n", "# concatenate forward and backward outputs, then flatten to one row per word\n", "output = tf.concat([output_fw, output_bw], axis=2)\n", "output = tf.reshape(output, [-1, 2 * n_units])\n", "prediction = tf.matmul(output, weights['out']) + biases['out']\n", "flattened_prediction = tf.reshape(prediction, [-1, target_size])\n", "losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=flattened_prediction, labels=Y)\n", "mask = 
tf.sequence_mask(sequence_lengths, maxlen=max_length)\n", "losses = tf.boolean_mask(tf.reshape(losses, [-1, max_length]), mask)\n", "loss = tf.reduce_mean(losses)\n", "\n", "minimize = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)" ] }, { "cell_type": "code", "execution_count": 153, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of batches: 1041\n" ] } ], "source": [ "init = tf.global_variables_initializer()\n", "num_batches = int(len(train_sentences)) / batch_size\n", "epoch = 3\n", "print 'Number of batches:', num_batches" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss at batch 0 2.8496134\n", "Loss at batch 50 2.740778\n", "Loss at batch 100 2.4188466\n", "Loss at batch 150 2.1874278\n", "Loss at batch 200 1.9713194\n", "Loss at batch 250 2.1214645\n", "Loss at batch 300 2.0436077\n" ] } ], "source": [ "with tf.Session() as sess:\n", "    sess.run(init)\n", "    for i in range(epoch):\n", "        ptr = 0\n", "        for j in range(num_batches):\n", "            batch_X = []\n", "            batch_Y = []\n", "            sequence_length = []\n", "            for _ in range(batch_size):\n", "                x, y = (train_sentences[ptr: ptr + 1],\n", "                        train_tags[ptr: ptr + 1])\n", "\n", "                x_word_vector = []\n", "\n", "                sequence_length.append(len(x[0]))\n", "\n", "                for s in x[0]:\n", "                    try:\n", "                        x_word_vector.append(w2v.word_vec(s))\n", "                    except KeyError:\n", "                        #if word isn't in the word2vec, use zeroes\n", "                        x_word_vector.append([0]*n_features)\n", "\n", "                for remainder in range(max_length - len(x_word_vector)):\n", "                    #pad sentence remainder with zeroes\n", "                    x_word_vector.append([0]*n_features)\n", "\n", "                batch_X.append(x_word_vector)\n", "\n", "                y_word_vector = []\n", "\n", "                for t in y[0]:\n", "                    y_word_vector.append(unique_tags.index(t))\n", "\n", "                #pad the tags too; these positions are masked out of the loss\n", "                for remainder in range(max_length - len(y_word_vector)):\n", "                    y_word_vector.append(0)\n", "\n", "                batch_Y.append(y_word_vector)\n", "\n", "                ptr 
+= 1\n", "\n", "            _, entropy, preds = sess.run([minimize, loss, prediction],{X: np.array(batch_X).reshape(batch_size, max_length, n_features),\n", "                                         Y: np.array(batch_Y).reshape(-1),\n", "                                         sequence_lengths: np.array(sequence_length)})\n", "\n", "            if j % display_size == 0:\n", "                print 'Loss at batch {0}'.format(j), entropy\n", "\n", "        print \"Epoch \", str(i)\n", "\n", "    # save the trained weights so the prediction cell can reuse them\n", "    # (the checkpoint path is an arbitrary choice, not part of the original notebook)\n", "    tf.train.Saver().save(sess, './ner_model.ckpt')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The obvious benefit of using word2vec is that the network runs faster and converges more quickly. It runs faster because we've reduced each word's feature vector from the size of the vocabulary (34,930 dimensions here) to only 300, the dimension of the vector returned by word2vec." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prediction" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with tf.Session() as sess:\n", "    # restore the trained weights; running `init` here would reset them to random values\n", "    tf.train.Saver().restore(sess, './ner_model.ckpt')\n", "\n", "    valid_words = validation_sentence.split(' ')\n", "    valid_X = []\n", "\n", "    for word in valid_words:\n", "        try:\n", "            valid_X.append(w2v.word_vec(word))\n", "        except KeyError:\n", "            #if word isn't in the word2vec, use zeroes\n", "            valid_X.append([0]*n_features)\n", "\n", "    for remainder in range(max_length - len(valid_X)):\n", "        #pad sentence remainder with zeroes\n", "        valid_X.append([0]*n_features)\n", "\n", "    valid_Y = []\n", "\n", "    for t in validation_tags:\n", "        valid_Y.append(unique_tags.index(t))\n", "\n", "    for remainder in range(max_length - len(valid_Y)):\n", "        valid_Y.append(0)\n", "\n", "    preds = sess.run([prediction],{X: np.array(valid_X).reshape(1, max_length, n_features), Y: np.array(valid_Y), sequence_lengths: [len(valid_words)]})\n", "\n", "    preds = np.array(preds).reshape(max_length, target_size)\n", "\n", "    # only the first len(valid_words) rows correspond to real words; the rest is padding\n", "    for i, word in enumerate(valid_words):\n", "        print 'Word:', word\n", "        print 'Actual:', validation_tags[i]\n", "        print 'Predicted:', unique_tags[np.argmax(preds[i])]\n", "        print" ] 
}, { "cell_type": "markdown", "metadata": {}, "source": [ "### Things to try\n", "- Add dropout\n", "- Replace softmax with a linear-chain CRF\n", "- Try other word representations; GloVe?\n", "- Tune batch size, learning rate, number of units in the LSTM cell, number of epochs\n", "- Try GRU units\n", "- Add MOAR layers!!!\n", "- Use longer sentences\n", "\n", "\n", "More importantly, train on a better dataset. As I mentioned, NER is domain-specific. Our validation sentence contains details perhaps specific to Nigeria:\n", " - the name Adesina and\n", " - Channels Television" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Resources\n", "- Sequence Tagging with Tensorflow https://guillaumegenthial.github.io/sequence-tagging-with-tensorflow.html" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.14" } }, "nbformat": 4, "nbformat_minor": 2 }