{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "import theano, theano.tensor as T\n", "import numpy as np\n", "import theano_lstm\n", "import random" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## A Nonsensical Language Model using Theano LSTM\n", "\n", "Today we will train a **nonsensical** language model!\n", "\n", "We will first collect some language data, convert it to numbers, and then feed it to a recurrent neural network and ask it to predict upcoming words. When we are done we will have a machine that can generate sentences from our made-up language ad infinitum!\n", "\n", "### Collect Language Data\n", "\n", "The first step is to get some data. Since our language is based on nonsense, we need to generate good nonsense using a sampler.\n", "\n", "Our sampler takes a probability table as input. For example, a language where people are equally likely to say \"a\" or \"b\" would be written as follows:\n", "\n", "    nonsense = Sampler({\"a\": 0.5, \"b\": 0.5})\n", "\n", "We get samples from this language like this:\n", "\n", "    word = nonsense()\n", "\n", "We overload the `__call__` method to get this syntactic sugar." 
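, "\n", "\n", "A sketch of the sampling rule that the `Sampler` cell below implements (the symbols $w_i$, $k_i$ and $u$ are introduced here for illustration only): the weights are normalized to $p_i = w_i / \\sum_j w_j$ (or $p_i = 1/n$ when a list of $n$ keys is given), a uniform draw $u \\in [0, 1)$ is taken with `random.random()`, and the first key, in the sampler's internal key order, whose cumulative probability reaches $u$ is returned. This is inverse-transform sampling:\n", "\n", "$$\\text{sample} = k_{i^*}, \\qquad i^* = \\min\\Big\\{\\, i \\;:\\; u \\le \\sum_{j \\le i} p_j \\Big\\}.$$\n", "\n", "For example, if the keys happen to be stored in the order (\"a\", \"b\") with $p = (0.5, 0.5)$, any draw $u \\le 0.5$ yields \"a\" and any larger draw yields \"b\"."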
] }, { "cell_type": "code", "execution_count": 192, "metadata": { "collapsed": true }, "outputs": [], "source": [ "## Fake dataset:\n", "\n", "class Sampler:\n", "    def __init__(self, prob_table):\n", "        total_prob = 0.0\n", "        if isinstance(prob_table, dict):\n", "            for key, value in prob_table.items():\n", "                total_prob += value\n", "        elif isinstance(prob_table, list):\n", "            prob_table_gen = {}\n", "            for key in prob_table:\n", "                prob_table_gen[key] = 1.0 / (float(len(prob_table)))\n", "            total_prob = 1.0\n", "            prob_table = prob_table_gen\n", "        else:\n", "            raise TypeError(\"__init__ takes either a dict or a list as its first argument\")\n", "        if total_prob <= 0.0:\n", "            raise ValueError(\"Probability is not strictly positive.\")\n", "        self._keys = []\n", "        self._probs = []\n", "        for key in prob_table:\n", "            self._keys.append(key)\n", "            self._probs.append(prob_table[key] / total_prob)\n", "\n", "    def __call__(self):\n", "        sample = random.random()\n", "        seen_prob = 0.0\n", "        for key, prob in zip(self._keys, self._probs):\n", "            if (seen_prob + prob) >= sample:\n", "                return key\n", "            else:\n", "                seen_prob += prob\n", "        # guard against floating-point round-off: fall back to the last key\n", "        return key" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Parts of Speech\n", "\n", "Now that we have a `Sampler`, we can create a few word groups (parts of speech), each with its own probability distribution over words:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "samplers = {\n", "    \"punctuation\": Sampler({\".\": 0.49, \",\": 0.5, \";\": 0.03, \"?\": 0.05, \"!\": 0.05}),\n", "    \"stop\": Sampler({\"the\": 10, \"from\": 5, \"a\": 9, \"they\": 3, \"he\": 3, \"it\": 2.5, \"she\": 2.7, \"in\": 4.5}),\n", "    \"noun\": Sampler([\"cat\", \"broom\", \"boat\", \"dog\", \"car\", \"wrangler\", \"mexico\", \"lantern\", \"book\", \"paper\", \"joke\", \"calendar\", \"ship\", \"event\"]),\n", "    \"verb\": Sampler([\"ran\", \"stole\", \"carried\", \"could\", \"would\", \"do\", \"can\", \"carry\", \"catapult\", \"jump\", \"duck\"]),\n", "    \"adverb\": Sampler([\"rapidly\", \"calmly\", \"cooly\", \"in jest\", \"fantastically\", \"angrily\", \"dazily\"])\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Simple Grammar\n", "\n", "To create sentences in our language we use a simple recursion that goes as follows:\n", "\n", "1. If the sentence we have ends with a full stop, then end at once! (Question marks and exclamation points do not end a sentence; the recursion simply continues after them.)\n", "2. Else our sentence should have:\n", "    * A stop word\n", "    * A noun\n", "    * An adverb (with prob 0.3), followed by a second adverb (with overall prob 0.3*0.3=0.09)\n", "    * A verb\n", "    * Another noun (with prob 0.2), extended with a dash and one more noun (with overall prob 0.2*0.1=0.02)\n", "3. If our sentence is now over 500 characters, add a full stop and end at once!\n", "4. Else add some punctuation and go back to (1)" ] }, { "cell_type": "code", "execution_count": 193, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def generate_nonsense(word = \"\"):\n", "    if word.endswith(\".\"):\n", "        return word\n", "    else:\n", "        if len(word) > 0:\n", "            word += \" \"\n", "        word += samplers[\"stop\"]()\n", "        word += \" \" + samplers[\"noun\"]()\n", "        if random.random() > 0.7:\n", "            word += \" \" + samplers[\"adverb\"]()\n", "            if random.random() > 0.7:\n", "                word += \" \" + samplers[\"adverb\"]()\n", "        word += \" \" + samplers[\"verb\"]()\n", "        if random.random() > 0.8:\n", "            word += \" \" + samplers[\"noun\"]()\n", "            if random.random() > 0.9:\n", "                word += \"-\" + samplers[\"noun\"]()\n", "        if len(word) > 500:\n", "            word += \".\"\n", "        else:\n", "            word += \" \" + samplers[\"punctuation\"]()\n", "        return generate_nonsense(word)\n", "\n", "def generate_dataset(total_size):\n", "    sentences = []\n", "    for _ in range(total_size):\n", "        sentences.append(generate_nonsense())\n", "    return sentences\n", "\n", "# generate dataset\n", "lines = generate_dataset(100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Utilities\n", "\n", "Now that we have a training corpus for our language model (optionally, you could gather an actual corpus from the web), we can create our first utility, `Vocab`, which holds the mapping from words to indices and performs the conversions from words to indices and vice versa:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "### Utilities:\n", "class Vocab:\n", "    __slots__ = [\"word2index\", \"index2word\", \"unknown\"]\n", "\n", "    def __init__(self, index2word = None):\n", "        self.word2index = {}\n", "        self.index2word = []\n", "\n", "        # add unknown word:\n", "        self.add_words([\"**UNKNOWN**\"])\n", "        self.unknown = 0\n", "\n", "        if index2word is not None:\n", "            self.add_words(index2word)\n", "\n", "    def add_words(self, words):\n", "        for word in words:\n", "            if word not in self.word2index:\n", "                self.word2index[word] = len(self.word2index)\n", "                self.index2word.append(word)\n", "\n", "    def __call__(self, line):\n", "        \"\"\"\n", "        Convert from numerical representation to words\n", "        and vice-versa.\n", "        \"\"\"\n", "        if isinstance(line, np.ndarray):\n", "            return \" \".join([self.index2word[word] for word in line])\n", "        if isinstance(line, list):\n", "            # a list of integer indices is converted back to words:\n", "            if len(line) > 0 and isinstance(line[0], (int, np.integer)):\n", "                return \" \".join([self.index2word[word] for word in line])\n", "            indices = np.zeros(len(line), dtype=np.int32)\n", "        else:\n", "            line = line.split(\" \")\n", "            indices = np.zeros(len(line), dtype=np.int32)\n", "\n", "        for i, word in enumerate(line):\n", "            indices[i] = self.word2index.get(word, self.unknown)\n", "\n", "        return indices\n", "\n", "    @property\n", "    def size(self):\n", "        return len(self.index2word)\n", "\n", "    def __len__(self):\n", "        return len(self.index2word)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a Mapping from numbers to words\n", "\n", "Now we can use the `Vocab` class to gather all the words 
and store their indices:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "vocab = Vocab()\n", "for line in lines:\n", "    vocab.add_words(line.split(\" \"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To send all our sentences to the neural network in one big chunk, we transform each sentence into a row vector of word indices and stack these rows into a single matrix. Since not all sentences have the same length, `pad_into_matrix` pads the shorter ones with 0s and also returns the true length of each sentence:" ] }, { "cell_type": "code", "execution_count": 168, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def pad_into_matrix(rows, padding = 0):\n", "    if len(rows) == 0:\n", "        return np.zeros([0, 0], dtype=np.int32), []\n", "    # use a list, not `map`: in Python 3 `map` returns a one-shot iterator,\n", "    # so `max` would exhaust it and the returned lengths would be empty\n", "    lengths = [len(row) for row in rows]\n", "    width = max(lengths)\n", "    height = len(rows)\n", "    mat = np.empty([height, width], dtype=rows[0].dtype)\n", "    mat.fill(padding)\n", "    for i, row in enumerate(rows):\n", "        mat[i, 0:len(row)] = row\n", "    return mat, lengths\n", "\n", "# transform into big numerical matrix of sentences:\n", "numerical_lines = []\n", "for line in lines:\n", "    numerical_lines.append(vocab(line))\n", "numerical_lines, numerical_lengths = pad_into_matrix(numerical_lines)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build a Recurrent Neural Network\n", "\n", "Now the real work is upon us! Thank goodness we have our language data ready. We now create a recurrent neural network by combining an embedding $E$ (one vector per word in our corpus) with a stack of recurrent cells, which together form a prediction function. Mathematically, we look for the embedding $E$ and parameters $\\Phi$ under which our network assigns the highest probability to each next word in the data:\n", "\n", "$$\\mathop{\\mathrm{argmax}}_{E, \\Phi} \\; {\\bf P}(w_{k+1} \\mid w_{k}, \\dots, w_{0}; E, \\Phi), \\qquad {\\bf P}(\\,\\cdot \\mid w_{k}, \\dots, w_{0}; E, \\Phi) = f(x, h)$$\n", "\n", "with $f(\\cdot, \\cdot)$ the function our recurrent neural network performs at each timestep, which takes as inputs:\n", "\n", "* an observation $x$, and\n", "* a previous state $h$,\n", "\n", "and outputs a probability distribution $\\hat{p}$ over the next word.\n", "\n", "Here $x = E[w_{k}]$ is our observation at time $k$ (the embedding of word $w_k$), $h$ is the internal state of our neural network, $\\Phi$ is the set of parameters used by our classifier and recurrent neural network, and $E$ is the embedding for our words.\n", "\n", "In practice we obtain $E$ and $\\Phi$ iteratively, using gradient descent on the error our network makes in its predictions. To do this we define our error as the [Kullback-Leibler divergence](http://en.wikipedia.org/wiki/Kullback–Leibler_divergence) (a measure of how much one probability distribution differs from another) between our estimate $\\hat{p} = {\\bf P}(w_{k+1} \\mid w_{k}, \\dots, w_{0}; E, \\Phi)$ and the actual distribution ${\\bf P}(w_{k+1} \\mid w_{k}, \\dots, w_{0})$ from the data (i.e. a probability distribution that is 1 for the observed next word $w_{k+1}$ and 0 elsewhere). With such a one-hot target, the KL divergence from the target to our estimate reduces to $-\\log \\hat{p}_{w_{k+1}}$, the negative log-probability the network assigns to the observed word.\n", "\n", "#### Theano LSTM StackedCells function\n", "\n", "To build this predictive model we make use of [theano_lstm](https://github.com/JonathanRaiman/theano_lstm), a Python module for building recurrent neural networks using Theano. The first step is to declare what kind of cells we want to use by choosing a celltype. There are many different celltypes we can use, but the most common these days are `RNN` and `LSTM` (the latter usually being the more effective of the two). For a more in-depth discussion of how these work I suggest checking out [arXiv](http://arxiv.org/find/all/1/all:+lstm/0/1/0/all/0/1), [Alex Graves' website](http://www.cs.toronto.edu/~graves/), or [Wikipedia](http://en.wikipedia.org/wiki/Long_short_term_memory). The `Model` class below defaults to `celltype=LSTM`, although the training run at the end of this notebook passes `celltype=RNN`:\n", "\n", "    self.model = StackedCells(input_size, celltype=celltype, layers=[hidden_size] * stack_size)\n", "\n", "Once we've declared what kind of cells we want to use, we can add an Embedding to map integers (indices) to vectors (in our case we map words to their indices, then indices to the word vectors we wish to train). Intuitively, this lets the network separate and recognize what it is \"seeing\" or \"receiving\" at each timestep. To add an Embedding we create `Embedding(vocabulary_size, size_of_embedding_vectors)` and insert it at the beginning of the `StackedCells`'s layers list (thereby telling `StackedCells` that this Embedding layer needs to be activated before the other ones):\n", "\n", "    # add an embedding\n", "    self.model.layers.insert(0, Embedding(vocab_size, input_size))\n", "\n", "The final output of our network needs to be a probability distribution over the next word (in other application areas this could be a sentiment classification, a decision, a topic, etc.), so we add another layer that maps the internal state of the recurrent cells to a probability distribution over all the words in our language. To ensure that our prediction is indeed a probability distribution we \"activate\" our layer with a Softmax: we exponentiate every value of the output, $q_i = e^{x_i}$, so that all values are positive, and then divide the output by its sum so that it sums to 1:\n", "\n", "$$p_i = \\frac{q_i}{\\sum_j q_j}\\text{, and }\\sum_i p_i = 1.$$\n", "\n", "    # add a classifier:\n", "    self.model.layers.append(Layer(hidden_size, vocab_size, activation = softmax))\n", "\n", "For convenience we wrap all of this in one class below.\n", "\n", "#### Prediction\n", "\n", "We have now defined our network. At each timestep we can produce a probability distribution for each input index:\n", "\n", "    def create_prediction(self, greedy=False):\n", "        def step(idx, *states):\n", "            # new hiddens are the states we need to pass to LSTMs\n", "            # from the past. Because the StackedCells also include\n", "            # the embeddings, and those have no state, we pass\n", "            # a \"None\" instead:\n", "            new_hiddens = [None] + list(states)\n", "\n", "            new_states = self.model.forward(idx, prev_hiddens = new_hiddens)\n", "            return new_states[1:]\n", "        ...\n", "\n", "Our inputs are a symbolic Theano integer matrix (one row per sentence, one column per timestep):\n", "\n", "        ...\n", "        # in a sequence forecasting scenario we take everything\n", "        # up to the second-to-last step (0 ... n - 2) and predict\n", "        # the subsequent steps (1 ... n - 1), hence:\n", "        inputs = self.input_mat[:, 0:-1]\n", "        num_examples = inputs.shape[0]\n", "        # pass this to Theano's recurrence relation function:\n", "        ...\n", "\n", "`theano.scan` receives our recurrence relation `step` from above, and also needs to know, through `outputs_info`, what will be output at each step. We give `outputs_info` a set of variables corresponding to the hidden states of our StackedCells. Some of the layers have no hidden state, so for those we simply pass `None` to Theano, while others do require an initial state. In those cases we wrap their initial state inside a dictionary:\n", "\n", "    def has_hidden(layer):\n", "        \"\"\"\n", "        Whether a layer has a trainable\n", "        initial hidden state.\n", "        \"\"\"\n", "        return hasattr(layer, 'initial_hidden_state')\n", "\n", "    def matrixify(vector, n):\n", "        return T.repeat(T.shape_padleft(vector), n, axis=0)\n", "\n", "    def initial_state(layer, dimensions = None):\n", "        \"\"\"\n", "        Initializes the recurrence relation with an initial hidden state\n", "        if needed, else replaces with a \"None\" to tell Theano that\n", "        the network **will** return something, but it does not need\n", "        to send it to the next step of the recurrence\n", "        \"\"\"\n", "        if dimensions is None:\n", "            return layer.initial_hidden_state if has_hidden(layer) else None\n", "        else:\n", "            return matrixify(layer.initial_hidden_state, dimensions) if has_hidden(layer) else None\n", "\n", "    def initial_state_with_taps(layer, dimensions = None):\n", "        \"\"\"Optionally wrap tensor variable into a dict with taps=[-1]\"\"\"\n", "        state = initial_state(layer, dimensions)\n", "        if state is not None:\n", "            return dict(initial=state, taps=[-1])\n", "        else:\n", "            return None\n", "\n", "Let's now create these initial states. Note how we skip the first layer (index 0, the embedding) by iterating over `self.model.layers[1:]`: there is no point in passing embeddings around in our recurrence, because in this network a word vector is only used at the timestep in which it is received:\n", "\n", "        # choose what gets output at each timestep:\n", "        outputs_info = [initial_state_with_taps(layer, num_examples) for layer in self.model.layers[1:]]\n", "        result, _ = theano.scan(fn=step,\n", "                                sequences=[inputs.T],\n", "                                outputs_info=outputs_info)\n", "\n", "        if greedy:\n", "            return result[0]\n", "        # softmaxes are the last layer of our network,\n", "        # and are at the end of our results list.\n", "        # We reorder the predictions to be indexed by:\n", "        # 1. what row / example\n", "        # 2. what timestep\n", "        # 3. softmax dimension\n", "        return result[-1].transpose((2,0,1))\n", "\n", "#### Error Function:\n", "\n", "Our error function uses `theano_lstm`'s `masked_loss` function, which lets us define ranges over which a predicted probability distribution should match a target distribution. We control these ranges by setting their start and end points; everything outside them is masked, i.e. we do not care what the network predicted there.\n", "\n", "In our case we care about the words our network predicts inside each sentence, but not about what it does where we padded short sentences with 0s to fill our matrix, since that happens outside the sentences we collected:\n", "\n", "    def create_cost_fun(self):\n", "        # create a cost function that\n", "        # takes each prediction at every timestep\n", "        # and guesses next timestep's value:\n", "        what_to_predict = self.input_mat[:, 1:]\n", "        # because some sentences are shorter, we\n", "        # place masks where the sentences end:\n", "        # (`for_how_long` is zero-indexed, e.g. an example going from `[2,3)`\n", "        # has this value set to 0, so we subtract 1):\n", "        for_how_long = self.for_how_long - 1\n", "        # all sentences start at T=0:\n", "        starting_when = T.zeros_like(self.for_how_long)\n", "\n", "        self.cost = masked_loss(self.predictions,\n", "                                what_to_predict,\n", "                                for_how_long,\n", "                                starting_when).sum()\n", "\n", "#### Training Function\n", "\n", "We now have a cost function. To perform gradient descent we now need to tell Theano how each parameter must be updated at every training epoch. We use `theano_lstm`'s `create_optimization_updates` function to generate the parameter updates, applying special gradient descent rules that accelerate and facilitate training (for instance by rescaling gradients that are too large or too small, which keeps our model numerically stable). In this example we use [Adadelta](http://arxiv.org/abs/1212.5701):\n", "\n", "    def create_training_function(self):\n", "        updates, _, _, _, _ = create_optimization_updates(self.cost, self.params, method=\"adadelta\")\n", "        self.update_fun = theano.function(\n", "            inputs=[self.input_mat, self.for_how_long],\n", "            outputs=self.cost,\n", "            updates=updates,\n", "            allow_input_downcast=True)\n", "\n", "PS: our parameters are obtained from `self.model.params`:\n", "\n", "    @property\n", "    def params(self):\n", "        return self.model.params\n", "\n", "### Final Code" ] }, { "cell_type": "code", "execution_count": 189, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from theano_lstm import Embedding, LSTM, RNN, StackedCells, Layer, create_optimization_updates, masked_loss\n", "\n", "def softmax(x):\n", "    \"\"\"\n", "    Wrapper for softmax, helps with\n", "    pickling, and removing one extra\n", "    dimension that Theano adds during\n", "    its exponential normalization.\n", "    \"\"\"\n", "    return T.nnet.softmax(x.T)\n", "\n", "def has_hidden(layer):\n", "    \"\"\"\n", "    Whether a layer has a trainable\n", "    initial hidden state.\n", "    \"\"\"\n", "    return hasattr(layer, 'initial_hidden_state')\n", "\n", "def matrixify(vector, n):\n", "    return T.repeat(T.shape_padleft(vector), n, axis=0)\n", "\n", "def initial_state(layer, dimensions = None):\n", "    \"\"\"\n", "    Initializes the recurrence relation with an initial hidden state\n", "    if needed, else replaces with a \"None\" to tell Theano that\n", "    the network **will** return something, but it does not need\n", "    to send it to the next step of the recurrence\n", "    \"\"\"\n", "    if dimensions is None:\n", "        return layer.initial_hidden_state if has_hidden(layer) else None\n", "    else:\n", "        return matrixify(layer.initial_hidden_state, dimensions) if has_hidden(layer) else None\n", "\n", "def initial_state_with_taps(layer, dimensions = None):\n", "    \"\"\"Optionally wrap tensor variable into a dict with taps=[-1]\"\"\"\n", "    state = initial_state(layer, dimensions)\n", "    if state is not None:\n", "        return dict(initial=state, taps=[-1])\n", "    else:\n", "        return None\n", "\n", "class Model:\n", "    \"\"\"\n", "    Simple predictive model for forecasting words from\n", "    sequence using LSTMs (or RNNs). Choose how many cells to stack,\n", "    what size their memory should be, and how many\n", "    words can be predicted.\n", "    \"\"\"\n", "    def __init__(self, hidden_size, input_size, vocab_size, stack_size=1, celltype=LSTM):\n", "        # declare model\n", "        self.model = StackedCells(input_size, celltype=celltype, layers=[hidden_size] * stack_size)\n", "        # add an embedding\n", "        self.model.layers.insert(0, Embedding(vocab_size, input_size))\n", "        # add a classifier:\n", "        self.model.layers.append(Layer(hidden_size, vocab_size, activation = softmax))\n", "        # inputs are matrices of indices,\n", "        # each row is a sentence, each column a timestep\n", "        self._stop_word = theano.shared(np.int32(999999999), name=\"stop word\")\n", "        self.for_how_long = T.ivector()\n", "        self.input_mat = T.imatrix()\n", "        self.priming_word = T.iscalar()\n", "        self.srng = T.shared_randomstreams.RandomStreams(np.random.randint(0, 1024))\n", "        # create symbolic variables for prediction:\n", "        self.predictions = self.create_prediction()\n", "        # create symbolic variable for greedy search:\n", "        self.greedy_predictions = self.create_prediction(greedy=True)\n", "        # create gradient training functions:\n", "        self.create_cost_fun()\n", "        self.create_training_function()\n", "        self.create_predict_function()\n", "\n", "    def stop_on(self, idx):\n", "        self._stop_word.set_value(idx)\n", "\n", "    @property\n", "    def params(self):\n", "        return self.model.params\n", "\n", "    def create_prediction(self, greedy=False):\n", "        def step(idx, *states):\n", "            # new hiddens are the states we need to pass to LSTMs\n", "            # from the past. Because the StackedCells also include\n", "            # the embeddings, and those have no state, we pass\n", "            # a \"None\" instead:\n", "            new_hiddens = [None] + list(states)\n", "\n", "            new_states = self.model.forward(idx, prev_hiddens = new_hiddens)\n", "            if greedy:\n", "                new_idxes = new_states[-1]\n", "                new_idx = new_idxes.argmax()\n", "                # provide a stopping condition for greedy search:\n", "                return ([new_idx.astype(self.priming_word.dtype)] + new_states[1:-1]), theano.scan_module.until(T.eq(new_idx,self._stop_word))\n", "            else:\n", "                return new_states[1:]\n", "        # in a sequence forecasting scenario we take everything\n", "        # up to the second-to-last step (0 ... n - 2) and predict\n", "        # the subsequent steps (1 ... n - 1), hence:\n", "        inputs = self.input_mat[:, 0:-1]\n", "        num_examples = inputs.shape[0]\n", "        # pass this to Theano's recurrence relation function:\n", "\n", "        # choose what gets output at each timestep:\n", "        if greedy:\n", "            outputs_info = [dict(initial=self.priming_word, taps=[-1])] + [initial_state_with_taps(layer) for layer in self.model.layers[1:-1]]\n", "            result, _ = theano.scan(fn=step,\n", "                                    n_steps=200,\n", "                                    outputs_info=outputs_info)\n", "        else:\n", "            outputs_info = [initial_state_with_taps(layer, num_examples) for layer in self.model.layers[1:]]\n", "            result, _ = theano.scan(fn=step,\n", "                                    sequences=[inputs.T],\n", "                                    outputs_info=outputs_info)\n", "\n", "        if greedy:\n", "            return result[0]\n", "        # softmaxes are the last layer of our network,\n", "        # and are at the end of our results list.\n", "        # We reorder the predictions to be indexed by:\n", "        # 1. what row / example\n", "        # 2. what timestep\n", "        # 3. softmax dimension\n", "        return result[-1].transpose((2,0,1))\n", "\n", "    def create_cost_fun(self):\n", "        # create a cost function that\n", "        # takes each prediction at every timestep\n", "        # and guesses next timestep's value:\n", "        what_to_predict = self.input_mat[:, 1:]\n", "        # because some sentences are shorter, we\n", "        # place masks where the sentences end:\n", "        # (`for_how_long` is zero-indexed, e.g. an example going from `[2,3)`\n", "        # has this value set to 0, so we subtract 1):\n", "        for_how_long = self.for_how_long - 1\n", "        # all sentences start at T=0:\n", "        starting_when = T.zeros_like(self.for_how_long)\n", "\n", "        self.cost = masked_loss(self.predictions,\n", "                                what_to_predict,\n", "                                for_how_long,\n", "                                starting_when).sum()\n", "\n", "    def create_predict_function(self):\n", "        self.pred_fun = theano.function(\n", "            inputs=[self.input_mat],\n", "            outputs=self.predictions,\n", "            allow_input_downcast=True\n", "        )\n", "\n", "        self.greedy_fun = theano.function(\n", "            inputs=[self.priming_word],\n", "            outputs=T.concatenate([T.shape_padleft(self.priming_word), self.greedy_predictions]),\n", "            allow_input_downcast=True\n", "        )\n", "\n", "    def create_training_function(self):\n", "        updates, _, _, _, _ = create_optimization_updates(self.cost, self.params, method=\"adadelta\")\n", "        self.update_fun = theano.function(\n", "            inputs=[self.input_mat, self.for_how_long],\n", "            outputs=self.cost,\n", "            updates=updates,\n", "            allow_input_downcast=True)\n", "\n", "    def __call__(self, x):\n", "        return self.pred_fun(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Construct model\n", "\n", "We now declare the model, parametrize it to use an RNN, and size its output layer to our vocabulary. We also tell the greedy reconstruction search that it can consider a sentence over when the symbol corresponding to a period appears:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# construct model & theano functions:\n", "model = Model(\n", "    input_size=10,\n", "    hidden_size=10,\n", "    vocab_size=len(vocab),\n", "    stack_size=1, # a bigger stack is possible, but compiles more slowly\n", "    celltype=RNN # use RNN or LSTM\n", ")\n", "model.stop_on(vocab.word2index[\".\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train Model\n", "\n", "We make 10,000 passes through our data, printing the error every 100 epochs; every 500 epochs we also output what the model considers to be a natural continuation of the sentence \"the\":\n" ] }, { "cell_type": "code", "execution_count": 191, "metadata": { "collapsed": false, "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0, error=3877.55\n", "the .\n", "epoch 100, error=3873.32\n", "epoch 200, error=3868.80\n", "epoch 300, error=3863.65\n", "epoch 400, error=3857.58\n", "epoch 500, error=3850.15\n", "the .\n", "epoch 600, error=3840.67\n", "epoch 700, error=3828.21\n", "epoch 800, error=3811.36\n", "epoch 900, error=3787.88\n", "epoch 1000, error=3754.51\n", "the .\n", "epoch 1100, error=3707.27\n", "epoch 1200, error=3652.82\n", "epoch 1300, error=3794.47\n", "epoch 1400, error=3633.05\n", "epoch 1500, error=3749.59\n", "the .\n", "epoch 1600, error=3622.81\n", "epoch 1700, error=3728.75\n", "epoch 1800, error=3615.40\n", "epoch 1900, error=3711.92\n", "epoch 2000, error=3608.67\n", "the .\n", "epoch 2100, error=3697.46\n", "epoch 2200, error=3602.14\n", "epoch 2300, error=3684.72\n", "epoch 2400, error=3595.66\n", "epoch 2500, error=3673.21\n", "the .\n", "epoch 2600, error=3589.14\n", "epoch 2700, error=3662.57\n", "epoch 2800, error=3582.49\n", "epoch 2900, error=3652.51\n", "epoch 3000, error=3575.61\n", "the .\n", 
"epoch 3100, error=3642.76\n", "epoch 3200, error=3568.39\n", "epoch 3300, error=3633.05\n", "epoch 3400, error=3560.71\n", "epoch 3500, error=3623.09\n", "the event .\n", "epoch 3600, error=3552.42\n", "epoch 3700, error=3612.54\n", "epoch 3800, error=3543.32\n", "epoch 3900, error=3601.00\n", "epoch 4000, error=3533.19\n", "the event .\n", "epoch 4100, error=3588.00\n", "epoch 4200, error=3521.72\n", "epoch 4300, error=3572.95\n", "epoch 4400, error=3508.52\n", "epoch 4500, error=3555.13\n", "the event .\n", "epoch 4600, error=3493.12\n", "epoch 4700, error=3533.71\n", "epoch 4800, error=3474.91\n", "epoch 4900, error=3507.69\n", "epoch 5000, error=3453.10\n", "the event .\n", "epoch 5100, error=3476.03\n", "epoch 5200, error=3426.79\n", "epoch 5300, error=3437.64\n", "epoch 5400, error=3394.89\n", "epoch 5500, error=3391.61\n", "the event .\n", "epoch 5600, error=3356.28\n", "epoch 5700, error=3337.37\n", "epoch 5800, error=3309.92\n", "epoch 5900, error=3274.99\n", "epoch 6000, error=3255.30\n", "the event .\n", "epoch 6100, error=3205.48\n", "epoch 6200, error=3192.82\n", "epoch 6300, error=3130.87\n", "epoch 6400, error=3124.29\n", "epoch 6500, error=3053.95\n", "the event stole , .\n", "epoch 6600, error=3052.72\n", "epoch 6700, error=2977.69\n", "epoch 6800, error=2981.38\n", "epoch 6900, error=2904.48\n", "epoch 7000, error=2912.80\n", "the event carried , .\n", "epoch 7100, error=2836.22\n", "epoch 7200, error=2848.95\n", "epoch 7300, error=2774.26\n", "epoch 7400, error=2790.40\n", "epoch 7500, error=2719.00\n", "the event carried , the wrangler ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the 
calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the calendar ran , the\n", "epoch 7600, error=2737.30\n", "epoch 7700, error=2670.22\n", "epoch 7800, error=2689.21\n", "epoch 7900, error=2627.33\n", "epoch 8000, error=2645.85\n", "the event carried , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a cat carry , a\n", "epoch 8100, error=2589.56\n", "epoch 8200, error=2607.03\n", "epoch 8300, error=2556.31\n", "epoch 8400, error=2572.67\n", "epoch 8500, error=2527.13\n", "the event carried , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the 
ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the\n", "epoch 8600, error=2542.56\n", "epoch 8700, error=2501.63\n", "epoch 8800, error=2516.40\n", "epoch 8900, error=2479.40\n", "epoch 9000, error=2493.71\n", "the event carried , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a cat jump , a\n", "epoch 9100, error=2459.98\n", "epoch 9200, error=2473.99\n", "epoch 9300, error=2442.94\n", "epoch 9400, error=2456.79\n", "epoch 9500, error=2427.89\n", "the event carried , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship ran , the ship 
ran , the ship ran , the ship ran , the\n", "epoch 9600, error=2441.67\n", "epoch 9700, error=2414.49\n", "epoch 9800, error=2428.29\n", "epoch 9900, error=2402.47\n" ] } ], "source": [ "# train:\n", "for i in range(10000):\n", "    error = model.update_fun(numerical_lines, numerical_lengths)\n", "    if i % 100 == 0:\n", "        print(\"epoch %(epoch)d, error=%(error).2f\" % ({\"epoch\": i, \"error\": error}))\n", "    if i % 500 == 0:\n", "        print(vocab(model.greedy_fun(vocab.word2index[\"the\"])))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.4.3" } }, "nbformat": 4, "nbformat_minor": 0 }