{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Deep Learning for Natural Language Processing with Pytorch\n", "This tutorial will walk you through the key ideas of deep learning programming using Pytorch.\n", "Many of the concepts (such as the computation graph abstraction and autograd) are not unique to Pytorch and are relevant to any deep learning tool kit out there.\n", "\n", "I am writing this tutorial to focus specifically on NLP for people who have never written code in any deep learning framework (e.g, TensorFlow, Theano, Keras, Dynet). It assumes working knowledge of core NLP problems: part-of-speech tagging, language modeling, etc. It also assumes familiarity with neural networks at the level of an intro AI class (such as one from the Russel and Norvig book). Usually, these courses cover the basic backpropagation algorithm on feed-forward neural networks, and make the point that they are chains of compositions of linearities and non-linearities. This tutorial aims to get you started writing deep learning code, given you have this prerequisite knowledge.\n", "\n", "Note this is about *models*, not data. For all of the models, I just create a few test examples with small dimensionality so you can see how the weights change as it trains. If you have some real data you want to try, you should be able to rip out any of the models from this notebook and use them on it." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import torch.autograd as autograd\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "\n", "torch.manual_seed(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 1. Introduction to Torch's tensor library" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "All of deep learning is computations on tensors, which are generalizations of a matrix that can be indexed in more than 2 dimensions. We will see exactly what this means in-depth later. First, lets look what we can do with tensors." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Creating Tensors\n", "Tensors can be created from Python lists with the torch.Tensor() function." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " 1\n", " 2\n", " 3\n", "[torch.FloatTensor of size 3]\n", "\n", "\n", " 1 2 3\n", " 4 5 6\n", "[torch.FloatTensor of size 2x3]\n", "\n", "\n", "(0 ,.,.) = \n", " 1 2\n", " 3 4\n", "\n", "(1 ,.,.) = \n", " 5 6\n", " 7 8\n", "[torch.FloatTensor of size 2x2x2]\n", "\n" ] } ], "source": [ "# Create a torch.Tensor object with the given data. It is a 1D vector\n", "V_data = [1., 2., 3.]\n", "V = torch.Tensor(V_data)\n", "print V\n", "\n", "# Creates a matrix\n", "M_data = [[1., 2., 3.], [4., 5., 6]]\n", "M = torch.Tensor(M_data)\n", "print M\n", "\n", "# Create a 3D tensor of size 2x2x2.\n", "T_data = [[[1.,2.], [3.,4.]],\n", " [[5.,6.], [7.,8.]]]\n", "T = torch.Tensor(T_data)\n", "print T" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What is a 3D tensor anyway?\n", "Think about it like this.\n", "If you have a vector, indexing into the vector gives you a scalar. If you have a matrix, indexing into the matrix gives you a vector. 
If you have a 3D tensor, then indexing into the tensor gives you a matrix!\n", "\n", "A note on terminology: when I say \"tensor\" in this tutorial, it refers to any torch.Tensor object. Vectors and matrices are special cases of torch.Tensors, where their dimension is 1 and 2 respectively. When I am talking about 3D tensors, I will explicitly use the term \"3D tensor\"." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0\n", "\n", " 1\n", " 2\n", " 3\n", "[torch.FloatTensor of size 3]\n", "\n", "\n", " 1 2\n", " 3 4\n", "[torch.FloatTensor of size 2x2]\n", "\n" ] } ], "source": [ "# Index into V and get a scalar\n", "print V[0]\n", "\n", "# Index into M and get a vector\n", "print M[0]\n", "\n", "# Index into T and get a matrix\n", "print T[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also create tensors of other datatypes. The default, as you can see, is Float.\n", "To create a tensor of integer types, try torch.LongTensor(). Check the documentation for more data types, but Float and Long will be the most common." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can create a tensor with random data and the supplied dimensionality with torch.randn()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "(0 ,.,.) = \n", " -2.9718 1.7070 -0.4305 -2.2820 0.5237\n", " 0.0004 -1.2039 3.5283 0.4434 0.5848\n", " 0.8407 0.5510 0.3863 0.9124 -0.8410\n", " 1.2282 -1.8661 1.4146 -1.8781 -0.4674\n", "\n", "(1 ,.,.) = \n", " -0.7576 0.4215 -0.4827 -1.1198 0.3056\n", " 1.0386 0.5206 -0.5006 1.2182 0.2117\n", " -1.0613 -1.9441 -0.9596 0.5489 -0.9901\n", " -0.3826 1.5037 1.8267 0.5561 1.6445\n", "\n", "(2 ,.,.) = \n", " 0.4973 -1.5067 1.7661 -0.3569 -0.1713\n", " 0.4068 -0.4284 -1.1299 1.4274 -1.4027\n", " 1.4825 -1.1559 1.6190 0.9581 0.7747\n", " 0.1940 0.1687 0.3061 1.0743 -1.0327\n", "[torch.FloatTensor of size 3x4x5]\n", "\n" ] } ], "source": [ "x = torch.randn((3, 4, 5))\n", "print x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Operations with Tensors\n", "You can operate on tensors in the ways you would expect." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " 5\n", " 7\n", " 9\n", "[torch.FloatTensor of size 3]\n", "\n" ] } ], "source": [ "x = torch.Tensor([ 1., 2., 3. ])\n", "y = torch.Tensor([ 4., 5., 6. ])\n", "z = x + y\n", "print z" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [the documentation](http://pytorch.org/docs/torch.html) for a complete list of the massive number of operations available to you. They expand beyond just mathematical operations.\n", "\n", "One helpful operation that we will make use of later is concatenation." 
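, "\n", "Before that, here is a small sketch (with shapes chosen arbitrarily) of two more operations worth knowing: matrix multiplication with torch.mm, and summing all entries with .sum()." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A small sketch of two more common tensor operations (shapes chosen arbitrarily)\n", "a = torch.randn(2, 3)\n", "b = torch.randn(3, 4)\n", "print torch.mm(a, b) # matrix multiplication: (2x3) times (3x4) gives a (2x4) tensor\n", "print a.sum() # sum of all the entries of a\n"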
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " 1.0930 0.7769 -1.3128 0.7099 0.9944\n", "-0.2694 -0.6491 -0.1373 -0.2954 -0.7725\n", "-0.2215 0.5074 -0.6794 -1.6115 0.5230\n", "-0.8890 0.2620 0.0302 0.0013 -1.3987\n", " 1.4666 -0.1028 -0.0097 -0.8420 -0.2067\n", "[torch.FloatTensor of size 5x5]\n", "\n", "\n", " 1.0672 0.1732 -0.6873 0.3620 0.3776 -0.2443 -0.5850 2.0812\n", " 0.3111 0.2358 -1.0658 -0.1186 0.4903 0.8349 0.8894 0.4148\n", "[torch.FloatTensor of size 2x8]\n", "\n" ] } ], "source": [ "# By default, it concatenates along the first axis (concatenates rows)\n", "x_1 = torch.randn(2, 5)\n", "y_1 = torch.randn(3, 5)\n", "z_1 =torch.cat([x_1, y_1])\n", "print z_1\n", "\n", "# Concatenate columns:\n", "x_2 = torch.randn(2, 3)\n", "y_2 = torch.randn(2, 5)\n", "z_2 = torch.cat([x_2, y_2], 1) # second arg specifies which axis to concat along\n", "print z_2\n", "\n", "# If your tensors are not compatible, torch will complain. Uncomment to see the error\n", "# torch.cat([x_1, x_2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Reshaping Tensors\n", "Use the .view() method to reshape a tensor.\n", "This method receives heavy use, because many neural network components expect their inputs to have a certain shape.\n", "Often you will need to reshape before passing your data to the component." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "(0 ,.,.) = \n", " 0.0507 -0.9644 -2.0111 0.5245\n", " 2.1332 -0.0822 0.8388 -1.3233\n", " 0.0701 1.2200 0.4251 -1.2328\n", "\n", "(1 ,.,.) = \n", " -0.6195 1.5133 1.9954 -0.6585\n", " -0.4139 -0.2250 -0.6890 0.9882\n", " 0.7404 -2.0990 1.2582 -0.3990\n", "[torch.FloatTensor of size 2x3x4]\n", "\n", "\n", "\n", "Columns 0 to 9 \n", " 0.0507 -0.9644 -2.0111 0.5245 2.1332 -0.0822 0.8388 -1.3233 0.0701 1.2200\n", "-0.6195 1.5133 1.9954 -0.6585 -0.4139 -0.2250 -0.6890 0.9882 0.7404 -2.0990\n", "\n", "Columns 10 to 11 \n", " 0.4251 -1.2328\n", " 1.2582 -0.3990\n", "[torch.FloatTensor of size 2x12]\n", "\n", "\n", "\n", "Columns 0 to 9 \n", " 0.0507 -0.9644 -2.0111 0.5245 2.1332 -0.0822 0.8388 -1.3233 0.0701 1.2200\n", "-0.6195 1.5133 1.9954 -0.6585 -0.4139 -0.2250 -0.6890 0.9882 0.7404 -2.0990\n", "\n", "Columns 10 to 11 \n", " 0.4251 -1.2328\n", " 1.2582 -0.3990\n", "[torch.FloatTensor of size 2x12]\n", "\n" ] } ], "source": [ "x = torch.randn(2, 3, 4)\n", "print x\n", "print x.view(2, 12) # Reshape to 2 rows, 12 columns\n", "print x.view(2, -1) # Same as above. If one of the dimensions is -1, its size can be inferred" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "# 2. Computation Graphs and Automatic Differentiation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The concept of a computation graph is essential to efficient deep learning programming, because it allows you to not have to write the back propagation gradients yourself. A computation graph is simply a specification of how your data is combined to give you the output. Since the graph totally specifies what parameters were involved with which operations, it contains enough information to compute derivatives. This probably sounds vague, so lets see what is going on using the fundamental class of Pytorch: autograd.Variable.\n", "\n", "First, think from a programmers perspective. 
What is stored in the torch.Tensor objects we were creating above?\n", "Obviously the data and the shape, and maybe a few other things. But when we added two tensors together, we got an output tensor. All this output tensor knows is its data and shape. It has no idea that it was the sum of two other tensors (it could have been read in from a file, it could be the result of some other operation, etc.)\n", "\n", "The Variable class keeps track of how it was created. Lets see it in action." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " 1\n", " 2\n", " 3\n", "[torch.FloatTensor of size 3]\n", "\n", "\n", " 5\n", " 7\n", " 9\n", "[torch.FloatTensor of size 3]\n", "\n", "\n" ] } ], "source": [ "# Variables wrap tensor objects\n", "x = autograd.Variable( torch.Tensor([1., 2., 3]), requires_grad=True )\n", "# You can access the data with the .data attribute\n", "print x.data\n", "\n", "# You can also do all the same operations you did with tensors with Variables.\n", "y = autograd.Variable( torch.Tensor([4., 5., 6]), requires_grad=True )\n", "z = x + y\n", "print z.data\n", "\n", "# BUT z knows something extra.\n", "print z.grad_fn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So Variables know what created them. z knows that it wasn't read in from a file, it wasn't the result of a multiplication or exponential or whatever. And if you keep following z.grad_fn, you will find yourself at x and y.\n", "\n", "But how does that help us compute a gradient?" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable containing:\n", " 21\n", "[torch.FloatTensor of size 1]\n", "\n", "\n" ] } ], "source": [ "# Lets sum up all the entries in z\n", "s = z.sum()\n", "print s\n", "print s.grad_fn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So now, what is the derivative of this sum with respect to the first component of x? In math, we want\n", "$$ \\frac{\\partial s}{\\partial x_0} $$\n", "Well, s knows that it was created as a sum of the tensor z. z knows that it was the sum x + y.\n", "So \n", "$$ s = \\overbrace{x_0 + y_0}^\\text{$z_0$} + \\overbrace{x_1 + y_1}^\\text{$z_1$} + \\overbrace{x_2 + y_2}^\\text{$z_2$} $$\n", "And so s contains enough information to determine that the derivative we want is 1!\n", "\n", "Of course this glosses over the challenge of how to actually compute that derivative. The point here is that s is carrying along enough information that it is possible to compute it. In reality, the developers of Pytorch program the sum() and + operations to know how to compute their gradients, and run the back propagation algorithm. An in-depth discussion of that algorithm is beyond the scope of this tutorial." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lets have Pytorch compute the gradient, and see that we were right: (note if you run this block multiple times, the gradient will increment. 
That is because Pytorch *accumulates* the gradient into the .grad property, since for many models this is very convenient.)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable containing:\n", " 1\n", " 1\n", " 1\n", "[torch.FloatTensor of size 3]\n", "\n" ] } ], "source": [ "s.backward() # calling .backward() on any variable will run backprop, starting from it.\n", "print x.grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Understanding what is going on in the block below is crucial for being a successful programmer in deep learning." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "None\n" ] } ], "source": [ "x = torch.randn((2,2))\n", "y = torch.randn((2,2))\n", "z = x + y # These are Tensor types, and backprop would not be possible\n", "\n", "var_x = autograd.Variable( x )\n", "var_y = autograd.Variable( y )\n", "var_z = var_x + var_y # var_z contains enough information to compute gradients, as we saw above\n", "print var_z.grad_fn\n", "\n", "var_z_data = var_z.data # Get the wrapped Tensor object out of var_z...\n", "new_var_z = autograd.Variable( var_z_data ) # Re-wrap the tensor in a new variable\n", "\n", "# ... does new_var_z have information to backprop to x and y?\n", "# NO!\n", "print new_var_z.grad_fn\n", "# And how could it? We yanked the tensor out of var_z (that is what var_z.data is). This tensor\n", "# doesn't know anything about how it was computed. We pass it into new_var_z, and this is all the information\n", "# new_var_z gets. If var_z_data doesn't know how it was computed, there's no way new_var_z will.\n", "# In essence, we have broken the variable away from its past history" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is the basic, extremely important rule for computing with autograd.Variables (note that this is more general than Pytorch: there is an equivalent object in every major deep learning toolkit):\n", "\n", "**If you want the error from your loss function to backpropagate to a component of your network, you MUST NOT break the Variable chain from that component to your loss Variable. If you do, the loss will have no idea your component exists, and its parameters can't be updated.**\n", "\n", "I say this in bold, because this error can creep up on you in very subtle ways (I will show some such ways below), and it will not cause your code to crash or complain, so you must be careful." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3. Deep Learning Building Blocks: Affine maps, non-linearities and objectives" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Deep learning consists of composing linearities with non-linearities in clever ways. The introduction of non-linearities allows for powerful models. In this section, we will play with these core components, make up an objective function, and see how the model is trained." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Affine Maps\n", "One of the core workhorses of deep learning is the affine map, which is a function $f(x)$ where\n", "$$ f(x) = Ax + b $$ for a matrix $A$ and vectors $x, b$. The parameters to be learned here are $A$ and $b$. Often, $b$ is referred to as the *bias* term."
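, "\n", "To make this concrete, here is a small hand-built sketch of $f(x) = Ax + b$ with made-up numbers (nn.Linear, used below, packages $A$ and $b$ for you)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A hand-built affine map f(x) = Ax + b, with made-up numbers\n", "A = torch.Tensor([[1., 0., -1.], [2., 1., 0.]]) # a 2x3 matrix\n", "b = torch.Tensor([0.5, -0.5]) # a bias vector in R^2\n", "x = torch.Tensor([1., 2., 3.]) # an input vector in R^3\n", "print torch.mv(A, x) + b # f(x) = Ax + b, a vector in R^2\n"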
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pytorch and most other deep learning frameworks do things a little differently than traditional linear algebra. It maps the rows of the input instead of the columns. That is, the $i$'th row of the output below is the mapping of the $i$'th row of the input under $A$, plus the bias term. Look at the example below." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable containing:\n", " 0.4825 0.0247 0.4566\n", "-0.0652 -0.7002 -0.4353\n", "[torch.FloatTensor of size 2x3]\n", "\n" ] } ], "source": [ "lin = nn.Linear(5, 3) # maps from R^5 to R^3, parameters A, b\n", "data = autograd.Variable( torch.randn(2, 5) ) # data is 2x5. A maps from 5 to 3... can we map \"data\" under A?\n", "print lin(data) # yes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Non-Linearities\n", "First, note the following fact, which will explain why we need non-linearities in the first place.\n", "Suppose we have two affine maps $f(x) = Ax + b$ and $g(x) = Cx + d$. What is $f(g(x))$?\n", "$$ f(g(x)) = A(Cx + d) + b = ACx + (Ad + b) $$\n", "$AC$ is a matrix and $Ad + b$ is a vector, so we see that composing affine maps gives you an affine map.\n", "\n", "From this, you can see that if you wanted your neural network to be long chains of affine compositions, that this adds no new power to your model than just doing a single affine map.\n", "\n", "If we introduce non-linearities in between the affine layers, this is no longer the case, and we can build much more powerful models.\n", "\n", "There are a few core non-linearities. $\\tanh(x), \\sigma(x), \\text{ReLU}(x)$ are the most common.\n", "You are probably wondering: \"why these functions? I can think of plenty of other non-linearities.\"\n", "The reason for this is that they have gradients that are easy to compute, and computing gradients is essential for learning. For example\n", "$$ \\frac{d\\sigma}{dx} = \\sigma(x)(1 - \\sigma(x)) $$\n", "\n", "A quick note: although you may have learned some neural networks in your intro to AI class where $\\sigma(x)$ was the default non-linearity, typically people shy away from it in practice. This is because the gradient *vanishes* very quickly as the absolute value of the argument grows. Small gradients means it is hard to learn. Most people default to tanh or ReLU." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable containing:\n", "-1.0246 -1.0300\n", "-1.0129 0.0055\n", "[torch.FloatTensor of size 2x2]\n", "\n", "Variable containing:\n", "1.00000e-03 *\n", " 0.0000 0.0000\n", " 0.0000 5.5350\n", "[torch.FloatTensor of size 2x2]\n", "\n" ] } ], "source": [ "# In pytorch, most non-linearities are in torch.functional (we have it imported as F)\n", "# Note that non-linearites typically don't have parameters like affine maps do.\n", "# That is, they don't have weights that are updated during training.\n", "data = autograd.Variable( torch.randn(2, 2) )\n", "print data\n", "print F.relu(data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Softmax and Probabilities\n", "The function $\\text{Softmax}(x)$ is also just a non-linearity, but it is special in that it usually is the last operation done in a network. This is because it takes in a vector of real numbers and returns a probability distribution. Its definition is as follows. 
Let $x$ be a vector of real numbers (positive, negative, whatever, there are no constraints). Then the $i$'th component of $\\text{Softmax}(x)$ is\n", "$$ \\frac{\\exp(x_i)}{\\sum_j \\exp(x_j)} $$\n", "It should be clear that the output is a probability distribution: each element is non-negative and the sum over all components is 1.\n", "\n", "You could also think of it as just applying an element-wise exponentiation operator to the input to make everything non-negative and then dividing by the normalization constant." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable containing:\n", "-0.9347\n", "-0.9882\n", " 1.3801\n", "-0.1173\n", " 0.9317\n", "[torch.FloatTensor of size 5]\n", "\n", "Variable containing:\n", " 0.0481\n", " 0.0456\n", " 0.4867\n", " 0.1089\n", " 0.3108\n", "[torch.FloatTensor of size 5]\n", "\n", "Variable containing:\n", " 1\n", "[torch.FloatTensor of size 1]\n", "\n", "Variable containing:\n", "-3.0350\n", "-3.0885\n", "-0.7201\n", "-2.2176\n", "-1.1686\n", "[torch.FloatTensor of size 5]\n", "\n" ] } ], "source": [ "# Softmax is also in torch.functional\n", "data = autograd.Variable( torch.randn(5) )\n", "print data\n", "print F.softmax(data)\n", "print F.softmax(data).sum() # Sums to 1 because it is a distribution!\n", "print F.log_softmax(data) # there's also log_softmax" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Objective Functions\n", "The objective function is the function that your network is being trained to minimize (in which case it is often called a *loss function* or *cost function*).\n", "Training proceeds by first choosing a training instance, running it through your neural network, and then computing the loss of the output. The parameters of the model are then updated by taking the derivative of the loss function. Intuitively, if your model is completely confident in its answer, and its answer is wrong, your loss will be high. If it is very confident in its answer, and its answer is correct, the loss will be low.\n", "\n", "The idea behind minimizing the loss function on your training examples is that your network will hopefully generalize well and have small loss on unseen examples in your dev set, test set, or in production.\n", "An example loss function is the *negative log likelihood loss*, which is a very common objective for multi-class classification. For supervised multi-class classification, this means training the network to minimize the negative log probability of the correct output (or equivalently, maximize the log probability of the correct output)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 4. Optimization and Training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So we know how to compute a loss function for an instance. What do we do with that?\n", "We saw earlier that autograd.Variables know how to compute gradients with respect to the things that were used to compute them. Well, since our loss is an autograd.Variable, we can compute gradients with respect to all of the parameters used to compute it! Then we can perform standard gradient updates. Let $\\theta$ be our parameters, $L(\\theta)$ the loss function, and $\\eta$ a positive learning rate. Then:\n", "\n", "$$ \\theta^{(t+1)} = \\theta^{(t)} - \\eta \\nabla_\\theta L(\\theta) $$\n", "\n", "There is a huge collection of algorithms and active research in attempting to do something more than just this vanilla gradient update. 
Many attempt to vary the learning rate based on what is happening at train time. You don't need to worry about what specifically these algorithms are doing unless you are really interested. Torch provides many in the torch.optim package, and they are all completely transparent: using the simplest gradient update looks the same in code as using the more complicated algorithms. Trying different update algorithms and different parameters for the update algorithms (like different initial learning rates) is important in optimizing your network's performance. Often, just replacing vanilla SGD with an optimizer like Adam or RMSProp will boost performance noticeably." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 5. Creating Network Components in Pytorch\n", "Before we move on to our focus on NLP, let's do an annotated example of building a network in Pytorch using only affine maps and non-linearities. We will also see how to compute a loss function, using Pytorch's built-in negative log likelihood loss, and update parameters by backpropagation.\n", "\n", "All network components should inherit from nn.Module and override the forward() method. That is about it, as far as the boilerplate is concerned. Inheriting from nn.Module provides functionality to your component. For example, it keeps track of its trainable parameters, and you can swap it between CPU and GPU with the .cuda() or .cpu() functions, etc.\n", "\n", "Let's write an annotated example of a network that takes in a sparse bag-of-words representation and outputs a probability distribution over two labels: \"English\" and \"Spanish\". This model is just logistic regression." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example: Logistic Regression Bag-of-Words classifier\n", "Our model will map a sparse BoW representation to log probabilities over labels. We assign each word in the vocab an index. For example, say our entire vocab is two words \"hello\" and \"world\", with indices 0 and 1 respectively.\n", "The BoW vector for the sentence \"hello hello hello hello\" is\n", "$$ \\left[ 4, 0 \\right] $$\n", "For \"hello world world hello\", it is \n", "$$ \\left[ 2, 2 \\right] $$\n", "etc.\n", "In general, it is\n", "$$ \\left[ \\text{Count}(\\text{hello}), \\text{Count}(\\text{world}) \\right] $$\n", "\n", "Denote this BoW vector as $x$.\n", "The output of our network is:\n", "$$ \\log \\text{Softmax}(Ax + b) $$\n", "That is, we pass the input through an affine map and then do log softmax."
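, "\n", "As a quick sanity check of that description, here is a throwaway sketch that builds the BoW count vectors for the tiny \"hello\"/\"world\" vocabulary above (the real data and vocabulary for our classifier are built in the next cell)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A throwaway sketch of the hello/world BoW vectors described above\n", "toy_vocab = { \"hello\": 0, \"world\": 1 }\n", "def toy_bow_vector(sentence, vocab):\n", "    vec = torch.zeros(len(vocab))\n", "    for word in sentence.split():\n", "        vec[vocab[word]] += 1\n", "    return vec\n", "print toy_bow_vector(\"hello hello hello hello\", toy_vocab) # [ 4, 0 ]\n", "print toy_bow_vector(\"hello world world hello\", toy_vocab) # [ 2, 2 ]\n"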
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'en': 3, 'No': 9, 'buena': 14, 'it': 7, 'at': 22, 'sea': 12, 'cafeteria': 5, 'Yo': 23, 'la': 4, 'to': 8, 'creo': 10, 'is': 16, 'a': 18, 'good': 19, 'get': 20, 'idea': 15, 'que': 11, 'not': 17, 'me': 0, 'on': 25, 'gusta': 1, 'lost': 21, 'Give': 6, 'una': 13, 'si': 24, 'comer': 2}\n" ] } ], "source": [ "data = [ (\"me gusta comer en la cafeteria\".split(), \"SPANISH\"),\n", " (\"Give it to me\".split(), \"ENGLISH\"),\n", " (\"No creo que sea una buena idea\".split(), \"SPANISH\"),\n", " (\"No it is not a good idea to get lost at sea\".split(), \"ENGLISH\") ]\n", "\n", "test_data = [ (\"Yo creo que si\".split(), \"SPANISH\"),\n", " (\"it is lost on me\".split(), \"ENGLISH\")]\n", "\n", "# word_to_ix maps each word in the vocab to a unique integer, which will be its\n", "# index into the Bag of words vector\n", "word_to_ix = {}\n", "for sent, _ in data + test_data:\n", " for word in sent:\n", " if word not in word_to_ix:\n", " word_to_ix[word] = len(word_to_ix)\n", "print word_to_ix\n", "\n", "VOCAB_SIZE = len(word_to_ix)\n", "NUM_LABELS = 2" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class BoWClassifier(nn.Module): # inheriting from nn.Module!\n", " \n", " def __init__(self, num_labels, vocab_size):\n", " # calls the init function of nn.Module. Dont get confused by syntax,\n", " # just always do it in an nn.Module\n", " super(BoWClassifier, self).__init__()\n", " \n", " # Define the parameters that you will need. In this case, we need A and b,\n", " # the parameters of the affine mapping.\n", " # Torch defines nn.Linear(), which provides the affine map.\n", " # Make sure you understand why the input dimension is vocab_size\n", " # and the output is num_labels!\n", " self.linear = nn.Linear(vocab_size, num_labels)\n", " \n", " # NOTE! The non-linearity log softmax does not have parameters! 
So we don't need\n", " # to worry about that here\n", " \n", " def forward(self, bow_vec):\n", " # Pass the input through the linear layer,\n", " # then pass that through log_softmax.\n", " # Many non-linearities and other functions are in torch.nn.functional\n", " return F.log_softmax(self.linear(bow_vec))" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def make_bow_vector(sentence, word_to_ix):\n", " vec = torch.zeros(len(word_to_ix))\n", " for word in sentence:\n", " vec[word_to_ix[word]] += 1\n", " return vec.view(1, -1)\n", "\n", "def make_target(label, label_to_ix):\n", " return torch.LongTensor([label_to_ix[label]])" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Parameter containing:\n", "\n", "Columns 0 to 9 \n", " 0.1553 0.0992 -0.0282 0.1496 0.1823 -0.1915 0.0641 -0.0007 0.0477 -0.1672\n", "-0.1950 0.1070 0.0459 -0.1361 -0.0680 0.0308 0.0106 -0.1926 0.1514 0.0820\n", "\n", "Columns 10 to 19 \n", "-0.1511 0.1126 0.1763 -0.1710 -0.0196 -0.0568 0.0307 0.1733 -0.0360 -0.0471\n", "-0.0560 -0.0115 0.1602 0.1038 0.0484 -0.0128 -0.1899 -0.0906 0.1684 0.1301\n", "\n", "Columns 20 to 25 \n", "-0.1031 0.1031 0.1582 0.1065 0.0289 -0.0779\n", " 0.0749 0.0201 0.1951 -0.1686 -0.1285 -0.0108\n", "[torch.FloatTensor of size 2x26]\n", "\n", "Parameter containing:\n", "-0.1423\n", " 0.0952\n", "[torch.FloatTensor of size 2]\n", "\n" ] } ], "source": [ "model = BoWClassifier(NUM_LABELS, VOCAB_SIZE)\n", "\n", "# the model knows its parameters. The first output below is A, the second is b.\n", "# Whenever you assign a component to a class variable in the __init__ function of a module,\n", "# which was done with the line\n", "# self.linear = nn.Linear(...)\n", "# Then through some Python magic from the Pytorch devs, your module (in this case, BoWClassifier)\n", "# will store knowledge of the nn.Linear's parameters\n", "for param in model.parameters():\n", " print param" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable containing:\n", "-0.5357 -0.8801\n", "[torch.FloatTensor of size 1x2]\n", "\n" ] } ], "source": [ "# To run the model, pass in a BoW vector, but wrapped in an autograd.Variable\n", "sample = data[0]\n", "bow_vector = make_bow_vector(sample[0], word_to_ix)\n", "log_probs = model(autograd.Variable(bow_vector))\n", "print log_probs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Which of the above values corresponds to the log probability of ENGLISH, and which to SPANISH? We never defined it, but we need to if we want to train the thing." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": true }, "outputs": [], "source": [ "label_to_ix = { \"SPANISH\": 0, \"ENGLISH\": 1 }" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So lets train! To do this, we pass instances through to get log probabilities, compute a loss function, compute the gradient of the loss function, and then update the parameters with a gradient step. Loss functions are provided by Torch in the nn package. nn.NLLLoss() is the negative log likelihood loss we want. It also defines optimization functions in torch.optim. Here, we will just use SGD.\n", "\n", "Note that the *input* to NLLLoss is a vector of log probabilities, and a target label. It doesn't compute the log probabilities for us. 
This is why the last layer of our network is log softmax.\n", "The loss function nn.CrossEntropyLoss() is the same as NLLLoss(), except it does the log softmax for you." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable containing:\n", "-0.5874 -0.8114\n", "[torch.FloatTensor of size 1x2]\n", "\n", "Variable containing:\n", "-0.4588 -0.9999\n", "[torch.FloatTensor of size 1x2]\n", "\n", "Variable containing:\n", "-0.1511\n", "-0.0560\n", "[torch.FloatTensor of size 2]\n", "\n" ] } ], "source": [ "# Run on test data before we train, just to see a before-and-after\n", "for instance, label in test_data:\n", " bow_vec = autograd.Variable(make_bow_vector(instance, word_to_ix))\n", " log_probs = model(bow_vec)\n", " print log_probs\n", "print next(model.parameters())[:,word_to_ix[\"creo\"]] # Print the matrix column corresponding to \"creo\"" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "collapsed": true }, "outputs": [], "source": [ "loss_function = nn.NLLLoss()\n", "optimizer = optim.SGD(model.parameters(), lr=0.1)\n", "\n", "# Usually you want to pass over the training data several times.\n", "# 100 is much bigger than on a real data set, but real datasets have more than\n", "# two instances. Usually, somewhere between 5 and 30 epochs is reasonable.\n", "for epoch in xrange(100):\n", " for instance, label in data:\n", " # Step 1. Remember that Pytorch accumulates gradients. We need to clear them out\n", " # before each instance\n", " model.zero_grad()\n", " \n", " # Step 2. Make our BOW vector and also we must wrap the target in a Variable\n", " # as an integer. For example, if the target is SPANISH, then we wrap the integer\n", " # 0. The loss function then knows that the 0th element of the log probabilities is\n", " # the log probability corresponding to SPANISH\n", " bow_vec = autograd.Variable(make_bow_vector(instance, word_to_ix))\n", " target = autograd.Variable(make_target(label, label_to_ix))\n", " \n", " # Step 3. Run our forward pass.\n", " log_probs = model(bow_vec)\n", " \n", " # Step 4. Compute the loss, gradients, and update the parameters by calling\n", " # optimizer.step()\n", " loss = loss_function(log_probs, target)\n", " loss.backward()\n", " optimizer.step()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable containing:\n", "-0.0904 -2.4483\n", "[torch.FloatTensor of size 1x2]\n", "\n", "Variable containing:\n", "-2.2816 -0.1077\n", "[torch.FloatTensor of size 1x2]\n", "\n", "Variable containing:\n", " 0.3345\n", "-0.5416\n", "[torch.FloatTensor of size 2]\n", "\n" ] } ], "source": [ "for instance, label in test_data:\n", " bow_vec = autograd.Variable(make_bow_vector(instance, word_to_ix))\n", " log_probs = model(bow_vec)\n", " print log_probs\n", "print next(model.parameters())[:,word_to_ix[\"creo\"]] # Index corresponding to Spanish goes up, English goes down!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We got the right answer! You can see that the log probability for Spanish is much higher in the first example, and the log probability for English is much higher in the second for the test data, as it should be.\n", "\n", "Now you see how to make a Pytorch component, pass some data through it and do gradient updates.\n", "We are ready to dig deeper into what deep NLP has to offer." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 6. 
Word Embeddings: Encoding Lexical Semantics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Word embeddings are dense vectors of real numbers, one per word in your vocabulary.\n", "In NLP, it is almost always the case that your features are words! But how should you represent a word in a computer?\n", "You could store its ascii character representation, but that only tells you what the word *is*, it doesn't say much about what it *means* (you might be able to derive its part of speech from its affixes, or properties from its capitalization, but not much). Even more, in what sense could you combine these representations?\n", "We often want dense outputs from our neural networks, where the inputs are $|V|$ dimensional, where $V$ is our vocabulary, but often the outputs are only a few dimensional (if we are only predicting a handful of labels, for instance). How do we get from a massive dimensional space to a smaller dimensional space?\n", "\n", "How about instead of ascii representations, we use a one-hot encoding? That is, we represent the word $w$ by\n", "$$ \\overbrace{\\left[ 0, 0, \\dots, 1, \\dots, 0, 0 \\right]}^\\text{|V| elements} $$\n", "where the 1 is in a location unique to $w$. Any other word will have a 1 in some other location, and a 0 everywhere else.\n", "\n", "There is an enormous drawback to this representation, besides just how huge it is. It basically treats all words as independent entities with no relation to each other. What we really want is some notion of *similarity* between words. Why? Let's see an example." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Suppose we are building a language model. Suppose we have seen the sentences\n", "* The mathematician ran to the store.\n", "* The physicist ran to the store.\n", "* The mathematician solved the open problem.\n", "\n", "in our training data.\n", "Now suppose we get a new sentence never before seen in our training data:\n", "* The physicist solved the open problem.\n", "\n", "Our language model might do OK on this sentence, but wouldn't it be much better if we could use the following two facts:\n", "* We have seen mathematician and physicist in the same role in a sentence. Somehow they have a semantic relation.\n", "* We have seen mathematician in the same role in this new unseen sentence as we are now seeing physicist.\n", "\n", "and then infer that physicist is actually a good fit in the new unseen sentence? This is what we mean by a notion of similarity: we mean *semantic similarity*, not simply having similar orthographic representations. It is a technique to combat the sparsity of linguistic data, by connecting the dots between what we have seen and what we haven't. This example of course relies on a fundamental linguistic assumption: that words appearing in similar contexts are related to each other semantically. This is called the [distributional hypothesis](https://en.wikipedia.org/wiki/Distributional_semantics)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Getting Dense Word Embeddings\n", "\n", "How can we solve this problem? That is, how could we actually encode semantic similarity in words?\n", "Maybe we think up some semantic attributes. For example, we see that both mathematicians and physicists can run, so maybe we give these words a high score for the \"is able to run\" semantic attribute. 
Think of some other attributes, and imagine what you might score some common words on those attributes.\n", "\n", "If each attribute is a dimension, then we might give each word a vector, like this:\n", "$$ q_\\text{mathematician} = \\left[ \\overbrace{2.3}^\\text{can run},\n", "\\overbrace{9.4}^\\text{likes coffee}, \\overbrace{-5.5}^\\text{majored in Physics}, \\dots \\right] $$\n", "$$ q_\\text{physicist} = \\left[ \\overbrace{2.5}^\\text{can run},\n", "\\overbrace{9.1}^\\text{likes coffee}, \\overbrace{6.4}^\\text{majored in Physics}, \\dots \\right] $$\n", "\n", "Then we can get a measure of similarity between these words by doing:\n", "$$ \\text{Similarity}(\\text{physicist}, \\text{mathematician}) = q_\\text{physicist} \\cdot q_\\text{mathematician} $$\n", "\n", "Although it is more common to normalize by the lengths:\n", "$$ \\text{Similarity}(\\text{physicist}, \\text{mathematician}) = \\frac{q_\\text{physicist} \\cdot q_\\text{mathematician}}\n", "{\\| q_\\text{physicist} \\| \\| q_\\text{mathematician} \\|} = \\cos (\\phi) $$\n", "Where $\\phi$ is the angle between the two vectors. That way, extremely similar words (words whose embeddings point in the same direction) will have similarity 1. Extremely dissimilar words should have similarity -1." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can think of the sparse one-hot vectors from the beginning of this section as a special case of these new vectors we have defined, where each word basically has similarity 0, and we gave each word some unique semantic attribute. These new vectors are *dense*, which is to say their entries are (typically) non-zero.\n", "\n", "But these new vectors are a big pain: you could think of thousands of different semantic attributes that might be relevant to determining similarity, and how on earth would you set the values of the different attributes? Central to the idea of deep learning is that the neural network learns representations of the features, rather than requiring the programmer to design them herself. So why not just let the word embeddings be parameters in our model, and have them be updated during training? This is exactly what we will do. We will have some *latent semantic attributes* that the network can, in principle, learn. Note that the word embeddings will probably not be interpretable. That is, although with our hand-crafted vectors above we can see that mathematicians and physicists are similar in that they both like coffee, if we allow a neural network to learn the embeddings and see that both mathematicians and physicists have a large value in the second dimension, it is not clear what that means. They are similar in some latent semantic dimension, but this probably has no interpretation to us." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In summary, **word embeddings are a representation of the *semantics* of a word, efficiently encoding semantic information that might be relevant to the task at hand**. You can embed other things too: part of speech tags, parse trees, anything! The idea of feature embeddings is central to the field." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Word Embeddings in Pytorch\n", "Before we get to a worked example and an exercise, a few quick notes about how to use embeddings in Pytorch and in deep learning programming in general.\n", "Similar to how we defined a unique index for each word when making one-hot vectors, we also need to define an index for each word when using embeddings. 
These will be keys into a lookup table. That is, embeddings are stored as a $|V| \\times D$ matrix, where $D$ is the dimensionality of the embeddings, such that the word assigned index $i$ has its embedding stored in the $i$'th row of the matrix. In all of my code, the mapping from words to indices is a dictionary named word_to_ix.\n", "\n", "The module that allows you to use embeddings is torch.nn.Embedding, which takes two arguments: the vocabulary size, and the dimensionality of the embeddings.\n", "\n", "To index into this table, you must use torch.LongTensor (since the indices are integers, not floats)." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable containing:\n", " 0.5952 -0.2683 -0.3664 -1.0555 0.6623\n", "[torch.FloatTensor of size 1x5]\n", "\n" ] } ], "source": [ "word_to_ix = { \"hello\": 0, \"world\": 1 }\n", "embeds = nn.Embedding(2, 5) # 2 words in vocab, 5 dimensional embeddings\n", "lookup_tensor = torch.LongTensor([word_to_ix[\"hello\"]])\n", "hello_embed = embeds( autograd.Variable(lookup_tensor) )\n", "print hello_embed" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### An Example: N-Gram Language Modeling\n", "Recall that in an n-gram language model, given a sequence of words $w$, we want to compute\n", "$$ P(w_i | w_{i-1}, w_{i-2}, \\dots, w_{i-n+1} ) $$\n", "Where $w_i$ is the ith word of the sequence.\n", "\n", "In this example, we will compute the loss function on some training examples and update the parameters with backpropagation." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[(['When', 'forty'], 'winters'), (['forty', 'winters'], 'shall'), (['winters', 'shall'], 'besiege')]\n" ] } ], "source": [ "CONTEXT_SIZE = 2\n", "EMBEDDING_DIM = 10\n", "# We will use Shakespeare Sonnet 2\n", "test_sentence = \"\"\"When forty winters shall besiege thy brow,\n", "And dig deep trenches in thy beauty's field,\n", "Thy youth's proud livery so gazed on now,\n", "Will be a totter'd weed of small worth held:\n", "Then being asked, where all thy beauty lies,\n", "Where all the treasure of thy lusty days;\n", "To say, within thine own deep sunken eyes,\n", "Were an all-eating shame, and thriftless praise.\n", "How much more praise deserv'd thy beauty's use,\n", "If thou couldst answer 'This fair child of mine\n", "Shall sum my count, and make my old excuse,'\n", "Proving his beauty by succession thine!\n", "This were to be new made when thou art old,\n", "And see thy blood warm when thou feel'st it cold.\"\"\".split()\n", "# we should tokenize the input, but we will ignore that for now\n", "# build a list of tuples. 
Each tuple is ([ word_i-2, word_i-1 ], target word)\n", "trigrams = [ ([test_sentence[i], test_sentence[i+1]], test_sentence[i+2]) for i in xrange(len(test_sentence) - 2) ]\n", "print trigrams[:3] # print the first 3, just so you can see what they look like" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "collapsed": true }, "outputs": [], "source": [ "vocab = set(test_sentence)\n", "word_to_ix = { word: i for i, word in enumerate(vocab) }" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class NGramLanguageModeler(nn.Module):\n", " \n", " def __init__(self, vocab_size, embedding_dim, context_size):\n", " super(NGramLanguageModeler, self).__init__()\n", " self.embeddings = nn.Embedding(vocab_size, embedding_dim)\n", " self.linear1 = nn.Linear(context_size * embedding_dim, 128)\n", " self.linear2 = nn.Linear(128, vocab_size)\n", " \n", " def forward(self, inputs):\n", " embeds = self.embeddings(inputs).view((1, -1))\n", " out = F.relu(self.linear1(embeds))\n", " out = self.linear2(out)\n", " log_probs = F.log_softmax(out)\n", " return log_probs" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[\n", " 518.8207\n", "[torch.FloatTensor of size 1]\n", ", \n", " 516.3852\n", "[torch.FloatTensor of size 1]\n", ", \n", " 513.9670\n", "[torch.FloatTensor of size 1]\n", ", \n", " 511.5646\n", "[torch.FloatTensor of size 1]\n", ", \n", " 509.1782\n", "[torch.FloatTensor of size 1]\n", ", \n", " 506.8095\n", "[torch.FloatTensor of size 1]\n", ", \n", " 504.4555\n", "[torch.FloatTensor of size 1]\n", ", \n", " 502.1131\n", "[torch.FloatTensor of size 1]\n", ", \n", " 499.7835\n", "[torch.FloatTensor of size 1]\n", ", \n", " 497.4669\n", "[torch.FloatTensor of size 1]\n", "]\n" ] } ], "source": [ "losses = []\n", "loss_function = nn.NLLLoss()\n", "model = NGramLanguageModeler(len(vocab), EMBEDDING_DIM, CONTEXT_SIZE)\n", "optimizer = optim.SGD(model.parameters(), lr=0.001)\n", "\n", "for epoch in xrange(10):\n", " total_loss = torch.Tensor([0])\n", " for context, target in trigrams:\n", " \n", " # Step 1. Prepare the inputs to be passed to the model (i.e, turn the words\n", " # into integer indices and wrap them in variables)\n", " context_idxs = map(lambda w: word_to_ix[w], context)\n", " context_var = autograd.Variable( torch.LongTensor(context_idxs) )\n", " \n", " # Step 2. Recall that torch *accumulates* gradients. Before passing in a new instance,\n", " # you need to zero out the gradients from the old instance\n", " model.zero_grad()\n", " \n", " # Step 3. Run the forward pass, getting log probabilities over next words\n", " log_probs = model(context_var)\n", " \n", " # Step 4. Compute your loss function. (Again, Torch wants the target word wrapped in a variable)\n", " loss = loss_function(log_probs, autograd.Variable(torch.LongTensor([word_to_ix[target]])))\n", " \n", " # Step 5. Do the backward pass and update the gradient\n", " loss.backward()\n", " optimizer.step()\n", " \n", " total_loss += loss.data\n", " losses.append(total_loss)\n", "print losses # The loss decreased every iteration over the training data!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Exercise: Computing Word Embeddings: Continuous Bag-of-Words" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Continuous Bag-of-Words model (CBOW) is frequently used in NLP deep learning. 
It is a model that tries to predict words given the context of a few words before and a few words after the target word. This is distinct from language modeling, since CBOW is not sequential and does not have to be probabilistic. Typically, CBOW is used to quickly train word embeddings, and these embeddings are used to initialize the embeddings of some more complicated model. Usually, this is referred to as *pretraining embeddings*. It almost always helps performance by a couple of percent.\n", "\n", "The CBOW model is as follows. Given a target word $w_i$ and a context window of $N$ words on each side, $w_{i-1}, \\dots, w_{i-N}$ and $w_{i+1}, \\dots, w_{i+N}$, referring to all context words collectively as $C$, CBOW tries to minimize\n", "$$ -\\log p(w_i | C) = -\\log \\text{Softmax}(A(\\sum_{w \\in C} q_w) + b)_{w_i} $$\n", "where the subscript picks out the component of the output corresponding to the target word $w_i$, and $q_w$ is the embedding for word $w$.\n", "\n", "Implement this model in Pytorch by filling in the class below. Some tips:\n", "* Think about which parameters you need to define.\n", "* Make sure you know what shape each operation expects. Use .view() if you need to reshape." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[(['We', 'are', 'to', 'study'], 'about'), (['are', 'about', 'study', 'the'], 'to'), (['about', 'to', 'the', 'idea'], 'study'), (['to', 'study', 'idea', 'of'], 'the'), (['study', 'the', 'of', 'a'], 'idea')]\n" ] } ], "source": [ "CONTEXT_SIZE = 2 # 2 words to the left, 2 to the right\n", "raw_text = \"\"\"We are about to study the idea of a computational process. Computational processes are abstract\n", "beings that inhabit computers. As they evolve, processes manipulate other abstract\n", "things called data. The evolution of a process is directed by a pattern of rules\n", "called a program. People create programs to direct processes. In effect,\n", "we conjure the spirits of the computer with our spells.\"\"\".split()\n", "word_to_ix = { word: i for i, word in enumerate(set(raw_text)) }\n", "data = []\n", "for i in xrange(2, len(raw_text) - 2):\n", " context = [ raw_text[i-2], raw_text[i-1], raw_text[i+1], raw_text[i+2] ]\n", " target = raw_text[i]\n", " data.append( (context, target) )\n", "print data[:5]" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class CBOW(nn.Module):\n", " \n", " def __init__(self):\n", " pass\n", " \n", " def forward(self, inputs):\n", " pass" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", " 0\n", " 13\n", " 47\n", " 4\n", "[torch.LongTensor of size 4]" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# create your model and train. here are some functions to help you make the data ready for use by your module\n", "def make_context_vector(context, word_to_ix):\n", " idxs = map(lambda w: word_to_ix[w], context)\n", " tensor = torch.LongTensor(idxs)\n", " return autograd.Variable(tensor)\n", "\n", "make_context_vector(data[0][0], word_to_ix) # example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 7. 
Sequence Models and Long-Short Term Memory Networks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "At this point, we have seen various feed-forward networks.\n", "That is, there is no state maintained by the network at all.\n", "This might not be the behavior we want.\n", "Sequence models are central to NLP: they are models where there is some sort of dependence through time between your inputs.\n", "The classical example of a sequence model is the Hidden Markov Model for part-of-speech tagging. Another example is the conditional random field.\n", "\n", "A recurrent neural network is a network that maintains some kind of state.\n", "For example, its output could be used as part of the next input, so that information can propogate along as the network passes over the sequence.\n", "In the case of an LSTM, for each element in the sequence, there is a corresponding *hidden state* $h_t$, which in principle can contain information from arbitrary points earlier in the sequence.\n", "We can use the hidden state to predict words in a language model, part-of-speech tags, and a myriad of other things." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### LSTM's in Pytorch\n", "\n", "Before getting to the example, note a few things.\n", "Pytorch's LSTM expects all of its inputs to be 3D tensors.\n", "The semantics of the axes of these tensors is important.\n", "The first axis is the sequence itself, the second indexes instances in the mini-batch, and the third indexes elements of the input.\n", "We haven't discussed mini-batching, so lets just ignore that and assume we will always have just 1 dimension on the second axis.\n", "If we want to run the sequence model over the sentence \"The cow jumped\", our input should look like\n", "$$ \n", "\\begin{bmatrix}\n", "\\overbrace{q_\\text{The}}^\\text{row vector} \\\\\n", "q_\\text{cow} \\\\\n", "q_\\text{jumped}\n", "\\end{bmatrix}\n", "$$\n", "Except remember there is an additional 2nd dimension with size 1.\n", "\n", "In addition, you could go through the sequence one at a time, in which case the 1st axis will have size 1 also.\n", "\n", "Let's see a quick example." ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable containing:\n", "(0 ,.,.) = \n", " -0.1488 0.0412 0.0348\n", "\n", "(1 ,.,.) = \n", " -0.1873 0.0244 0.1239\n", "\n", "(2 ,.,.) = \n", " -0.1872 -0.0043 0.2133\n", "\n", "(3 ,.,.) = \n", " -0.1322 0.0667 0.3406\n", "\n", "(4 ,.,.) = \n", " -0.1737 0.0413 0.2257\n", "[torch.FloatTensor of size 5x1x3]\n", "\n", "(Variable containing:\n", "(0 ,.,.) = \n", " -0.1737 0.0413 0.2257\n", "[torch.FloatTensor of size 1x1x3]\n", ", Variable containing:\n", "(0 ,.,.) = \n", " -0.4370 0.0700 0.4025\n", "[torch.FloatTensor of size 1x1x3]\n", ")\n" ] } ], "source": [ "lstm = nn.LSTM(3, 3) # Input dim is 3, output dim is 3\n", "inputs = [ autograd.Variable(torch.randn((1,3))) for _ in xrange(5) ] # make a sequence of length 5\n", "\n", "# initialize the hidden state. 
\n", "hidden = (autograd.Variable(torch.randn(1,1,3)), autograd.Variable(torch.randn((1,1,3))))\n", "for i in inputs:\n", " # Step through the sequence one element at a time.\n", " # after each step, hidden contains the hidden state.\n", " out, hidden = lstm(i.view(1,1,-1), hidden)\n", " \n", "# alternatively, we can do the entire sequence all at once.\n", "# the first value returned by LSTM is all of the hidden states throughout the sequence.\n", "# the second is just the most recent hidden state (compare the last slice of \"out\" with \"hidden\" below,\n", "# they are the same)\n", "# The reason for this is that:\n", "# \"out\" will give you access to all hidden states in the sequence\n", "# \"hidden\" will allow you to continue the sequence and backpropogate, by passing it as an argument\n", "# to the lstm at a later time\n", "inputs = torch.cat(inputs).view(len(inputs), 1, -1) # Add the extra 2nd dimension\n", "hidden = (autograd.Variable(torch.randn(1,1,3)), autograd.Variable(torch.randn((1,1,3)))) # clean out hidden state\n", "out, hidden = lstm(inputs, hidden)\n", "print out\n", "print hidden" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example: An LSTM for Part-of-Speech Tagging" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this section, we will use an LSTM to get part of speech tags.\n", "We will not use Viterbi or Forward-Backward or anything like that, but as a (challenging) exercise to the reader, think about how Viterbi could be used after you have seen what is going on.\n", "\n", "The model is as follows: let our input sentence be $w_1, \\dots, w_M$, where $w_i \\in V$, our vocab.\n", "Also, let $T$ be our tag set, and $y_i$ the tag of word $w_i$. Denote our prediction of the tag of word $w_i$ by $\\hat{y}_i$.\n", "\n", "This is a structure prediction, model, where our output is a sequence $\\hat{y}_1, \\dots, \\hat{y}_M$, where $\\hat{y}_i \\in T$.\n", "\n", "To do the prediction, pass an LSTM over the sentence. Denote the hidden state at timestep $i$ as $h_i$. Also, assign each tag a unique index (like how we had word_to_ix in the word embeddings section).\n", "Then our prediction rule for $\\hat{y}_i$ is\n", "$$ \\hat{y}_i = \\text{argmax}_j \\ (\\log \\text{Softmax}(Ah_i + b))_j $$\n", "That is, take the log softmax of the affine map of the hidden state, and the predicted tag is the tag that has the maximum value in this vector. Note this implies immediately that the dimensionality of the target space of $A$ is $|T|$." 
] }, { "cell_type": "code", "execution_count": 34, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def prepare_sequence(seq, to_ix):\n", " idxs = map(lambda w: to_ix[w], seq)\n", " tensor = torch.LongTensor(idxs)\n", " return autograd.Variable(tensor)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'Everybody': 5, 'ate': 2, 'apple': 4, 'that': 7, 'read': 6, 'dog': 1, 'book': 8, 'the': 3, 'The': 0}\n" ] } ], "source": [ "training_data = [\n", " (\"The dog ate the apple\".split(), [\"DET\", \"NN\", \"V\", \"DET\", \"NN\"]),\n", " (\"Everybody read that book\".split(), [\"NN\", \"V\", \"DET\", \"NN\"])\n", "]\n", "word_to_ix = {}\n", "for sent, tags in training_data:\n", " for word in sent:\n", " if word not in word_to_ix:\n", " word_to_ix[word] = len(word_to_ix)\n", "print word_to_ix\n", "tag_to_ix = {\"DET\": 0, \"NN\": 1, \"V\": 2}\n", "\n", "# These will usually be more like 32 or 64 dimensional.\n", "# We will keep them small, so we can see how the weights change as we train.\n", "EMBEDDING_DIM = 6\n", "HIDDEN_DIM = 6" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class LSTMTagger(nn.Module):\n", " \n", " def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):\n", " super(LSTMTagger, self).__init__()\n", " self.hidden_dim = hidden_dim\n", " \n", " self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)\n", " \n", " # The LSTM takes word embeddings as inputs, and outputs hidden states\n", " # with dimensionality hidden_dim.\n", " self.lstm = nn.LSTM(embedding_dim, hidden_dim)\n", " \n", " # The linear layer that maps from hidden state space to tag space\n", " self.hidden2tag = nn.Linear(hidden_dim, tagset_size)\n", " self.hidden = self.init_hidden()\n", " \n", " def init_hidden(self):\n", " # Before we've done anything, we dont have any hidden state.\n", " # Refer to the Pytorch documentation to see exactly why they have this dimensionality.\n", " # The axes semantics are (num_layers, minibatch_size, hidden_dim)\n", " return (autograd.Variable(torch.zeros(1, 1, self.hidden_dim)),\n", " autograd.Variable(torch.zeros(1, 1, self.hidden_dim)))\n", " \n", " def forward(self, sentence):\n", " embeds = self.word_embeddings(sentence)\n", " lstm_out, self.hidden = self.lstm(embeds.view(len(sentence), 1, -1), self.hidden)\n", " tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))\n", " tag_scores = F.log_softmax(tag_space)\n", " return tag_scores" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "collapsed": true }, "outputs": [], "source": [ "model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix))\n", "loss_function = nn.NLLLoss()\n", "optimizer = optim.SGD(model.parameters(), lr=0.1)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable containing:\n", "-1.0104 -1.0989 -1.1949\n", "-0.9876 -1.0711 -1.2556\n", "-1.0028 -1.0819 -1.2235\n", "-1.0130 -1.0500 -1.2486\n", "-1.0932 -1.1502 -1.0548\n", "[torch.FloatTensor of size 5x3]\n", "\n" ] } ], "source": [ "# See what the scores are before training\n", "# Note that element i,j of the output is the score for tag j for word i.\n", "inputs = prepare_sequence(training_data[0][0], word_to_ix)\n", "tag_scores = model(inputs)\n", "print tag_scores" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "collapsed": 
true }, "outputs": [], "source": [ "for epoch in xrange(300): # again, normally you would NOT do 300 epochs, it is toy data\n", " for sentence, tags in training_data:\n", " # Step 1. Remember that Pytorch accumulates gradients. We need to clear them out\n", " # before each instance\n", " model.zero_grad()\n", " \n", " # Also, we need to clear out the hidden state of the LSTM, detaching it from its\n", " # history on the last instance.\n", " model.hidden = model.init_hidden()\n", " \n", " # Step 2. Get our inputs ready for the network, that is, turn them into Variables\n", " # of word indices.\n", " sentence_in = prepare_sequence(sentence, word_to_ix)\n", " targets = prepare_sequence(tags, tag_to_ix)\n", " \n", " # Step 3. Run our forward pass.\n", " tag_scores = model(sentence_in)\n", " \n", " # Step 4. Compute the loss, gradients, and update the parameters by calling\n", " # optimizer.step()\n", " loss = loss_function(tag_scores, targets)\n", " loss.backward()\n", " optimizer.step()" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable containing:\n", "-0.3814 -1.1630 -5.3981\n", "-3.9328 -0.0260 -5.1055\n", "-3.0100 -2.7411 -0.1208\n", "-0.1156 -2.4364 -3.8311\n", "-3.6673 -0.0336 -4.8994\n", "[torch.FloatTensor of size 5x3]\n", "\n" ] } ], "source": [ "# See what the scores are after training\n", "inputs = prepare_sequence(training_data[0][0], word_to_ix)\n", "tag_scores = model(inputs)\n", "# The sentence is \"the dog ate the apple\". i,j corresponds to score for tag j for word i.\n", "# The predicted tag is the maximum scoring tag.\n", "# Here, we can see the predicted sequence below is 0 1 2 0 1\n", "# since 0 is index of the maximum value of row 1,\n", "# 1 is the index of maximum value of row 2, etc.\n", "# Which is DET NOUN VERB DET NOUN, the correct sequence!\n", "print tag_scores" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Exercise: Augmenting the LSTM part-of-speech tagger with character-level features\n", "In the example above, each word had an embedding, which served as the inputs to our sequence model.\n", "Let's augment the word embeddings with a representation derived from the characters of the word.\n", "We expect that this should help significantly, since character-level information like affixes have\n", "a large bearing on part-of-speech. For example, words with the affix *-ly* are almost always tagged as adverbs in English.\n", "\n", "Do do this, let $c_w$ be the character-level representation of word $w$. Let $x_w$ be the word embedding as before.\n", "Then the input to our sequence model is the concatenation of $x_w$ and $c_w$. So if $x_w$ has dimension 5, and $c_w$ dimension 3, then our LSTM should accept an input of dimension 8.\n", "\n", "To get the character level representation, do an LSTM over the characters of a word, and let $c_w$ be the final hidden state of this LSTM.\n", "Hints:\n", "* There are going to be two LSTM's in your new model. The original one that outputs POS tag scores, and the new one that outputs a character-level representation of each word.\n", "* To do a sequence model over characters, you will have to embed characters. The character embeddings will be the input to the character LSTM." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 8. 
Advanced: Dynamic Toolkits, Dynamic Programming, and the BiLSTM-CRF" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "### Dynamic versus Static Deep Learning Toolkits\n", "\n", "Pytorch is a *dynamic* neural network kit. Another example of a dynamic kit is [Dynet](https://github.com/clab/dynet) (I mention this because working with Pytorch and Dynet is similar. If you see an example in Dynet, it will probably help you implement it in Pytorch). The opposite is the *static* toolkit, which includes Theano, Keras, TensorFlow, etc.\n", "The core difference is the following:\n", "* In a static toolkit, you define a computation graph once, compile it, and then stream instances to it.\n", "* In a dynamic toolkit, you define a computation graph *for each instance*. It is never compiled and is executed on the fly.\n", "\n", "Without a lot of experience, it is difficult to appreciate the difference.\n", "As one example, suppose we want to build a constituency parser.\n", "Suppose our model involves roughly the following steps:\n", "* Build the tree bottom up.\n", "* Tag the leaf nodes (the words of the sentence).\n", "* From there, use a neural network and the embeddings of the words\n", "to find combinations that form constituents. Whenever you form a new constituent,\n", "use some sort of technique to get an embedding of the constituent.\n", "In this case, our network architecture will depend completely on the input sentence.\n", "In the sentence \"The green cat scratched the wall\", at some point in the model, we will want to combine\n", "the span $(i,j,r) = (1, 3, \\text{NP})$ (that is, an NP constituent spans word 1 to word 3, in this case \"The green cat\").\n", "\n", "However, another sentence might be \"Somewhere, the big fat cat scratched the wall\". In this sentence, we will want to form the constituent $(2, 4, \\text{NP})$ at some point.\n", "The constituents we will want to form will depend on the instance. If we just compile the computation graph once, as in a static toolkit, it will be exceptionally difficult or impossible to program this logic. In a dynamic toolkit though, there isn't just one pre-defined computation graph. There can be a new computation graph for each instance, so this problem goes away.\n", "\n", "Dynamic toolkits also have the advantage of being easier to debug, and of code that more closely resembles the host language (by that I mean that Pytorch and Dynet look more like actual Python code than Keras or Theano).\n", "\n", "I mention this distinction here, because the exercise in this section is to implement a model that closely resembles the structured perceptron, and I believe this model would be difficult to implement in a static toolkit. I think that the advantage of dynamic toolkits for linguistic structure prediction cannot be overstated." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Bi-LSTM Conditional Random Field Discussion\n", "\n", "For this section, we will see a full, complicated example of a Bi-LSTM Conditional Random Field for named-entity recognition. The LSTM tagger above is typically sufficient for part-of-speech tagging, but a sequence model like the CRF is really essential for strong performance on NER. Familiarity with CRFs is assumed. Although this name sounds scary, the model is just a CRF in which an LSTM provides the features. This is an advanced model though, far more complicated than any earlier model in this tutorial. If you want to skip it, that is fine. 
To see if you're ready, see if you can:\n", "\n", "* Write the recurrence for the Viterbi variable at step $i$ for tag $k$.\n", "* Modify the above recurrence to compute the forward variables instead.\n", "* Modify the above recurrence again to compute the forward variables in log-space (hint: log-sum-exp).\n", "\n", "If you can do those three things, you should be able to understand the code below.\n", "Recall that the CRF computes a conditional probability. Let $y$ be a tag sequence and $x$ an input sequence of words. Then we compute\n", "$$ P(y|x) = \\frac{\\exp(\\text{Score}(x, y))}{\\sum_{y'} \\exp(\\text{Score}(x, y'))} $$\n", "\n", "where the score is determined by defining some log potentials $\\log \\psi_i(x,y)$ such that\n", "$$ \\text{Score}(x,y) = \\sum_i \\log \\psi_i(x,y) $$\n", "To make the partition function tractable, the potentials must look only at local features.\n", "\n", "In the Bi-LSTM CRF, we define two kinds of potentials: emission and transition. The emission potential for the word at index $i$ comes from the hidden state of the Bi-LSTM at timestep $i$. The transition scores are stored in a $|T| \\times |T|$ matrix $\\textbf{P}$, where $T$ is the tag set. In my implementation, $\\textbf{P}_{j,k}$ is the score of transitioning to tag $j$ from tag $k$. So:\n", "\n", "$$ \\text{Score}(x,y) = \\sum_i \\log \\psi_\\text{EMIT}(y_i \\rightarrow x_i) + \\log \\psi_\\text{TRANS}(y_{i-1} \\rightarrow y_i) $$\n", "$$ = \\sum_i h_i[y_i] + \\textbf{P}_{y_i, y_{i-1}} $$\n", "where in this second expression, we think of the tags as being assigned unique non-negative indices.\n", "\n", "If the above discussion was too brief, you can check out [this](http://www.cs.columbia.edu/%7Emcollins/crf.pdf) write-up from Michael Collins on CRFs.\n", "\n", "### The Forward Algorithm in Log-Space and the Log-Sum-Exp Trick\n", "\n", "As hinted at above, computing the forward variables requires using a log-sum-exp. I want to explain why, since it was a little confusing to me at first, and many resources just present the forward algorithm in potential space. The recurrence for the forward variable at the $i$'th word for the tag $j$, $\\alpha_i(j)$, is\n", "$$ \\alpha_i(j) = \\sum_{j' \\in T} \\psi_\\text{EMIT}(j \\rightarrow x_i) \\times \\psi_\\text{TRANS}(j' \\rightarrow j) \\times \\alpha_{i-1}(j') $$\n", "\n", "This is numerically unstable, and underflow is likely. It is also inconvenient to force our model to produce proper non-negative potentials. We instead want to compute $\\log \\alpha_i(j)$. Multiplying potentials corresponds to adding log potentials, so that part is easy. But then we have to sum over tags, and there is no simple corresponding operation in log space. Instead, we transform out of log space (exponentiate), do the sum over tags, and then transform back to log space (take the log). This is broken down in the revised recurrence below:\n", "\n", "$$ \\log \\alpha_i(j) = \\log \\overbrace{\\sum_{j' \\in T} \\exp{(\\log \\psi_\\text{EMIT}(j \\rightarrow x_i) + \\log \\psi_\\text{TRANS}(j' \\rightarrow j) + \\log \\alpha_{i-1}(j'))}}^\\text{transform out of log-space and compute forward variable} $$\n", "\n", "If you apply elementary exponential and logarithm identities to the expression under the overbrace, you will see that it computes the same thing as the first recurrence and then just takes the logarithm. 
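To get a feel for the numbers: with a 40-word sentence and log potentials around $-5$, the products in the first recurrence are on the order of $e^{-400} \\approx 10^{-174}$, which is far below the smallest positive single-precision float (roughly $10^{-38}$) and would simply underflow to zero.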
Log-sum-exp appears a fair bit in machine learning, and there is a [well-known trick](https://en.wikipedia.org/wiki/LogSumExp) to computing it in a numerically stable way, namely $\\log \\sum_j \\exp(x_j) = \\max_j x_j + \\log \\sum_j \\exp(x_j - \\max_j x_j)$. I use this trick in my log_sum_exp function below (I don't think Pytorch provides this function in its library).\n", "\n", "### Implementation Notes\n", "\n", "The example below implements the forward algorithm in log space to compute the partition function, and the Viterbi algorithm to decode. Backpropagation will compute the gradients automatically for us. We don't have to do anything by hand.\n", "\n", "The implementation is not optimized. If you understand what is going on, you'll quickly see that iterating over the next tag in the forward algorithm could be done in one big vectorized operation. I wanted the code to be more readable. If you want to make the relevant change, you could probably use this tagger for real tasks." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example: Bidirectional LSTM Conditional Random Field for Named-Entity Recognition" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Helper functions to make the code more readable.\n", "def to_scalar(var):\n", " # returns a python number\n", " return var.view(-1).data.tolist()[0]\n", "\n", "def argmax(vec):\n", " # return the argmax as a python int\n", " _, idx = torch.max(vec, 1)\n", " return to_scalar(idx)\n", "\n", "# Compute log sum exp in a numerically stable way for the forward algorithm\n", "def log_sum_exp(vec):\n", " max_score = vec[0, argmax(vec)]\n", " max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])\n", " return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))\n", " \n", "\n", "class BiLSTM_CRF(nn.Module):\n", " \n", " def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim):\n", " super(BiLSTM_CRF, self).__init__()\n", " self.embedding_dim = embedding_dim\n", " self.hidden_dim = hidden_dim\n", " self.vocab_size = vocab_size\n", " self.tag_to_ix = tag_to_ix\n", " self.tagset_size = len(tag_to_ix)\n", " \n", " self.word_embeds = nn.Embedding(vocab_size, embedding_dim)\n", " self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1, bidirectional=True)\n", " \n", " # Maps the output of the LSTM into tag space.\n", " self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)\n", " \n", " # Matrix of transition parameters. 
Entry i,j is the score of transitioning *to* i *from* j.\n", " self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))\n", " \n", " # These two statements enforce the constraint that we never transfer *to* the start tag,\n", " # and we never transfer *from* the stop tag (the model would probably learn this anyway,\n", " # so this enforcement is likely unimportant)\n", " self.transitions.data[tag_to_ix[START_TAG], :] = -10000\n", " self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000\n", " \n", " self.hidden = self.init_hidden()\n", " \n", " def init_hidden(self):\n", " # (2, 1, hidden_dim // 2): one initial state per direction of the bidirectional LSTM\n", " return ( autograd.Variable( torch.randn(2, 1, self.hidden_dim // 2)),\n", " autograd.Variable( torch.randn(2, 1, self.hidden_dim // 2)) )\n", " \n", " \n", " def _forward_alg(self, feats):\n", " # Do the forward algorithm to compute the partition function\n", " init_alphas = torch.Tensor(1, self.tagset_size).fill_(-10000.)\n", " # START_TAG has all of the score.\n", " init_alphas[0][self.tag_to_ix[START_TAG]] = 0.\n", " \n", " # Wrap in a variable so that we will get automatic backprop\n", " forward_var = autograd.Variable(init_alphas)\n", " \n", " # Iterate through the sentence\n", " for feat in feats:\n", " alphas_t = [] # The forward variables at this timestep\n", " for next_tag in xrange(self.tagset_size):\n", " # broadcast the emission score: it is the same regardless of the previous tag\n", " emit_score = feat[next_tag].view(1, -1).expand(1, self.tagset_size)\n", " # the ith entry of trans_score is the score of transitioning to next_tag from i\n", " trans_score = self.transitions[next_tag].view(1, -1)\n", " # The ith entry of next_tag_var is the value for the edge (i -> next_tag)\n", " # before we do log-sum-exp\n", " next_tag_var = forward_var + trans_score + emit_score\n", " # The forward variable for this tag is log-sum-exp of all the scores.\n", " alphas_t.append(log_sum_exp(next_tag_var))\n", " forward_var = torch.cat(alphas_t).view(1, -1)\n", " terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]\n", " alpha = log_sum_exp(terminal_var)\n", " return alpha\n", " \n", " def _get_lstm_features(self, sentence):\n", " self.hidden = self.init_hidden()\n", " embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)\n", " lstm_out, self.hidden = self.lstm(embeds, self.hidden)\n", " lstm_out = lstm_out.view(len(sentence), self.hidden_dim)\n", " lstm_feats = self.hidden2tag(lstm_out)\n", " return lstm_feats\n", " \n", " def _score_sentence(self, feats, tags):\n", " # Gives the score of a provided tag sequence\n", " score = autograd.Variable( torch.Tensor([0]) )\n", " tags = torch.cat( [torch.LongTensor([self.tag_to_ix[START_TAG]]), tags] )\n", " for i, feat in enumerate(feats):\n", " score = score + self.transitions[tags[i+1], tags[i]] + feat[tags[i+1]]\n", " score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]]\n", " return score\n", " \n", " def _viterbi_decode(self, feats):\n", " backpointers = []\n", " \n", " # Initialize the viterbi variables in log space\n", " init_vvars = torch.Tensor(1, self.tagset_size).fill_(-10000.)\n", " init_vvars[0][self.tag_to_ix[START_TAG]] = 0\n", " \n", " # forward_var at step i holds the viterbi variables for step i-1 \n", " forward_var = autograd.Variable(init_vvars)\n", " for feat in feats:\n", " bptrs_t = [] # holds the backpointers for this step\n", " viterbivars_t = [] # holds the viterbi variables for this step\n", " \n", " for next_tag in range(self.tagset_size):\n", " # next_tag_var[i] holds the viterbi variable for tag i at the previous 
step,\n", " # plus the score of transitioning from tag i to next_tag.\n", " # We don't include the emission scores here because the max\n", " # does not depend on them (we add them in below)\n", " next_tag_var = forward_var + self.transitions[next_tag]\n", " best_tag_id = argmax(next_tag_var)\n", " bptrs_t.append(best_tag_id)\n", " viterbivars_t.append(next_tag_var[0][best_tag_id])\n", " # Now add in the emission scores, and assign forward_var to the set\n", " # of viterbi variables we just computed\n", " forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)\n", " backpointers.append(bptrs_t)\n", " \n", " # Transition to STOP_TAG\n", " terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]\n", " best_tag_id = argmax(terminal_var)\n", " path_score = terminal_var[0][best_tag_id]\n", " \n", " # Follow the back pointers to decode the best path.\n", " best_path = [best_tag_id]\n", " for bptrs_t in reversed(backpointers):\n", " best_tag_id = bptrs_t[best_tag_id]\n", " best_path.append(best_tag_id)\n", " # Pop off the start tag (we dont want to return that to the caller)\n", " start = best_path.pop()\n", " assert start == self.tag_to_ix[START_TAG] # Sanity check\n", " best_path.reverse()\n", " return path_score, best_path\n", " \n", " def neg_log_likelihood(self, sentence, tags):\n", " self.hidden = self.init_hidden()\n", " feats = self._get_lstm_features(sentence)\n", " forward_score = self._forward_alg(feats)\n", " gold_score = self._score_sentence(feats, tags)\n", " return forward_score - gold_score\n", " \n", " def forward(self, sentence): # dont confuse this with _forward_alg above.\n", " self.hidden = self.init_hidden()\n", " # Get the emission scores from the BiLSTM\n", " lstm_feats = self._get_lstm_features(sentence)\n", " \n", " # Find the best path, given the features.\n", " score, tag_seq = self._viterbi_decode(lstm_feats)\n", " return score, tag_seq\n" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "collapsed": true }, "outputs": [], "source": [ "START_TAG = \"\"\n", "STOP_TAG = \"\"\n", "EMBEDDING_DIM = 5\n", "HIDDEN_DIM = 4\n", "\n", "# Make up some training data\n", "training_data = [ (\n", " \"the wall street journal reported today that apple corporation made money\".split(),\n", " \"B I I I O O O B I O O\".split()\n", "), (\n", " \"georgia tech is a university in georgia\".split(),\n", " \"B I O O O O B\".split()\n", ") ]\n", "\n", "word_to_ix = {}\n", "for sentence, tags in training_data:\n", " for word in sentence:\n", " if word not in word_to_ix:\n", " word_to_ix[word] = len(word_to_ix)\n", " \n", "tag_to_ix = { \"B\": 0, \"I\": 1, \"O\": 2, START_TAG: 3, STOP_TAG: 4 }" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "collapsed": true }, "outputs": [], "source": [ "model = BiLSTM_CRF(len(word_to_ix), tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM)\n", "optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(Variable containing:\n", " 13.2216\n", "[torch.FloatTensor of size 1]\n", ", [2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 0])\n" ] } ], "source": [ "# Check predictions before training\n", "precheck_sent = prepare_sequence(training_data[0][0], word_to_ix)\n", "precheck_tags = torch.LongTensor([ tag_to_ix[t] for t in training_data[0][1] ])\n", "print model(precheck_sent)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Make 
sure prepare_sequence from earlier in the LSTM section is loaded\n", "for epoch in xrange(300): # again, normally you would NOT do 300 epochs, it is toy data\n", " for sentence, tags in training_data:\n", " # Step 1. Remember that Pytorch accumulates gradients. We need to clear them out\n", " # before each instance\n", " model.zero_grad()\n", " \n", " # Step 2. Get our inputs ready for the network, that is, turn them into Variables\n", " # of word indices.\n", " sentence_in = prepare_sequence(sentence, word_to_ix)\n", " targets = torch.LongTensor([ tag_to_ix[t] for t in tags ])\n", " \n", " # Step 3. Run our forward pass. The negative log likelihood is our loss.\n", " neg_log_likelihood = model.neg_log_likelihood(sentence_in, targets)\n", " \n", " # Step 4. Compute the gradients, and update the parameters by calling\n", " # optimizer.step()\n", " neg_log_likelihood.backward()\n", " optimizer.step()" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(Variable containing:\n", " 31.9680\n", "[torch.FloatTensor of size 1]\n", ", [0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2])\n" ] } ], "source": [ "# Check predictions after training\n", "precheck_sent = prepare_sequence(training_data[0][0], word_to_ix)\n", "print model(precheck_sent)\n", "# We got it!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Exercise: A new loss function for discriminative tagging\n", "It wasn't really necessary for us to create a computation graph when doing decoding, since we do not backpropagate from the Viterbi path score. Since we have it anyway, try training the tagger where the loss function is the difference between the Viterbi path score and the score of the gold-standard path. It should be clear that this function is non-negative and 0 when the predicted tag sequence is the correct tag sequence. This is essentially the *structured perceptron*.\n", "\n", "This modification should be short, since Viterbi and _score_sentence are already implemented. This is an example of the shape of the computation graph *depending on the training instance*. Although I haven't tried implementing this in a static toolkit, I imagine that it is possible but much less straightforward.\n", "\n", "Pick up some real data and do a comparison!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.14" } }, "nbformat": 4, "nbformat_minor": 2 }