{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import torch\n", "from torch.autograd import Variable\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.utils.data as data_utils\n", "import operator\n", "\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## RNN intuition\n", "\n", "Let us assume that we have an input $x = [x_1, x_2, ..., x_N]$ and we need to learn the mapping for some output $y = [y_1, y_2, ..., y_N]$, where $N$ is variable for each instance. In this case we can't just use a simple feed forward neural network which maps $x \\rightarrow y$, as this will not work with variable length sequences. Furthermore, the number or parameters required for training such a network would be proportional to $size(x_i)*N$. This is a major memory cost. Additionally, if the sequence has some common mapping between $x_i$ and $y_i$, then we would be learning redundant weights for each pair in the sequence. This is where an RNN network is more useful. The basic idea is that each input $x_i$ is processed in a similar fashion using the same processing module and some additional context variable (which we will henseforth refer to as the **hidden state**). This hidden state should capture some information about the part of the sequence which has already been processed. Now at each step of the sequence we need to do the following:\n", "\n", "* Generate the output based on the previous hidden state and current input\n", "* Update the hidden state based on the previous hidden state and current input. \n", "\n", "The order of the above steps is not fixed and forms the basis of many RNN spin-offs. What is important, at each step, is to have a new output and a new hidden state. Sometimes, the hidden state and the outputs are the same, to make the network smaller. But the core idea remains same. Below we would like to formalize the general intuition of an RNN module. \n", "\n", "Initialize an initial hidden state $h_{0}$ with some initial value. \n", "\n", "At timestep n: \n", "$$\n", "\\begin{equation}\n", "h^{'}_{i} = f(x_{i},h_{i})\\\\\n", "y_{i} = g(x_{i},h^{'}_{i})\\\\\n", "h_{i+1} = h^{'}_{i}\\\\\n", "\\end{equation}\n", "$$\n", "\n", "Here $y_{i}$ is the output and $h^{'}_{i}$ is the intermediate hidden state." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class Input2Hidden(nn.Module):\n", " def __init__(self, x_dim, concat_layers=False):\n", " \"\"\"Input2Hidden module\n", " \n", " Args:\n", " x_dim: input vector dimension\n", " concat_layers: weather to concat input and hidden layers or sum them\n", " \"\"\"\n", " super(Input2Hidden, self).__init__()\n", " self.concat_layers = concat_layers\n", " input_dim = x_dim\n", " if self.concat_layers:\n", " input_dim = 2*x_dim\n", " self.linear_layer = nn.Linear(input_dim, x_dim)\n", " \n", " def forward(self, x, h):\n", " if self.concat_layers:\n", " cell_input = torch.cat([x,h], dim=1)\n", " else:\n", " cell_input = x + h\n", " assert isinstance(cell_input, Variable)\n", " logit = F.tanh(self.linear_layer(cell_input))\n", " return logit\n", " \n", " \n", "class Hidden2Output(nn.Module):\n", " def __init__(self, x_dim, out_dim, concat_layers=False):\n", " \"\"\"Hidden2Output module\n", " \n", " Args:\n", " x_dim: input vector dimension\n", " out_dim: output vector dimension\n", " concat_layers: weather to concat input and hidden layers or sum them\n", " \"\"\"\n", " super(Hidden2Output, self).__init__()\n", " input_dim = x_dim\n", " self.concat_layers = concat_layers\n", " if self.concat_layers:\n", " input_dim = 2*x_dim\n", " self.linear_layer = nn.Linear(input_dim, out_dim)\n", " \n", " def forward(self, x, h):\n", " if self.concat_layers:\n", " cell_input = torch.cat([x,h], dim=1)\n", " else:\n", " cell_input = x + h\n", " assert isinstance(cell_input, Variable)\n", " logit = F.tanh(self.linear_layer(cell_input))\n", " return logit\n", " \n", " \n", "class CustomRNNCell(nn.Module):\n", " def __init__(self, i2h, h2o):\n", " super(CustomRNNCell, self).__init__()\n", " self.i2h = i2h\n", " self.h2o = h2o\n", " \n", " def forward(self, x, h):\n", " assert isinstance(x, Variable)\n", " assert isinstance(h, Variable)\n", " h_prime = self.i2h(x,h)\n", " assert isinstance(h_prime, Variable)\n", " output = self.h2o(x,h_prime)\n", " return output, h_prime\n", " \n", "class Model(nn.Module):\n", " def __init__(self, embedding, rnn_cell):\n", " super(Model, self).__init__()\n", " self.embedding = embedding\n", " self.rnn_cell = rnn_cell\n", " self.loss_function = nn.CrossEntropyLoss()\n", " \n", " def forward(self, word_ids, hidden=None):\n", " if hidden is None:\n", " hidden = Variable(torch.zeros(\n", " word_ids.data.shape[0],self.embedding.embedding_dim))\n", " assert isinstance(hidden, Variable)\n", " embeddings = self.embedding(word_ids)\n", " max_seq_length = word_ids.data.shape[-1]\n", " outputs, hidden_states = [], []\n", " for i in range(max_seq_length):\n", " x = embeddings[:, i, :]\n", " assert isinstance(x, Variable)\n", " #print(\"x={}\\nhidden={}\".format(x,hidden))\n", " output, hidden = self.rnn_cell(x, hidden)\n", " assert isinstance(output, Variable)\n", " assert isinstance(hidden, Variable)\n", " #print(\"output: {}, hidden: {}\".format(output.data.shape, hidden.data.shape))\n", " outputs.append(output.unsqueeze(1))\n", " hidden_states.append(hidden.unsqueeze(1))\n", " outputs = torch.cat(outputs, 1)\n", " hidden_states = torch.cat(hidden_states, 1)\n", " assert isinstance(outputs, Variable)\n", " assert isinstance(hidden_states, Variable)\n", " return outputs, hidden_states\n", " \n", " def loss(self, word_ids, target_ids, hidden=None):\n", " outputs, hidden_states = self.forward(word_ids, hidden=hidden)\n", " outputs = outputs.view(-1, outputs.data.shape[-1])\n", " 
    "        target_ids = target_ids.view(-1)\n",
    "        assert isinstance(outputs, Variable)\n",
    "        assert isinstance(target_ids, Variable)\n",
    "        #print(\"output={}\\ttargets={}\".format(outputs.data.shape, target_ids.data.shape))\n",
    "        loss = self.loss_function(outputs, target_ids)\n",
    "        return loss\n",
    "\n",
    "    def predict(self, word_ids, hidden=None):\n",
    "        outputs, hidden_states = self.forward(word_ids, hidden=hidden)\n",
    "        outputs = outputs.view(-1, outputs.data.shape[-1])\n",
    "        max_scores, predictions = outputs.max(1)\n",
    "        predictions = predictions.view(*word_ids.data.shape)\n",
    "        #print(word_ids.data.shape, predictions.data.shape)\n",
    "        assert word_ids.data.shape == predictions.data.shape, \"word_ids: {}, predictions: {}\".format(\n",
    "            word_ids.data.shape, predictions.data.shape\n",
    "        )\n",
    "        return predictions\n",
    "\n",
    "\n",
    "def tensors2variables(*args, requires_grad=False):\n",
    "    # wrap raw tensors in Variables so they can flow through the model\n",
    "    return tuple(map(lambda x: Variable(x, requires_grad=requires_grad), args))\n",
    "\n",
    "def get_batch(tensor_types, *args, requires_grad=False):\n",
    "    # cast each batch array to its tensor type, then wrap it in a Variable\n",
    "    return tuple(map(lambda t, arg: Variable(t(arg), requires_grad=requires_grad), tensor_types, args))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learning to predict bit flip\n",
    "\n",
    "Let us take a simple example of using an RNN to predict the bit flip of an $N$-bit unsigned integer. In Python, for an integer `n` represented using $N$ bits, the unsigned bit flip can be written as `(~n) & ((1<