{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 6. Recurrent Neural Networks and Language Models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture8.pdf\n", "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture9.pdf\n", "* http://colah.github.io/posts/2015-08-Understanding-LSTMs/\n", "* https://github.com/pytorch/examples/tree/master/word_language_model\n", "* https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/language_model" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.autograd import Variable\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", "import nltk\n", "import random\n", "import numpy as np\n", "from collections import Counter, OrderedDict\n", "import nltk\n", "from copy import deepcopy\n", "flatten = lambda l: [item for sublist in l for item in sublist]\n", "random.seed(1024)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "USE_CUDA = torch.cuda.is_available()\n", "gpus = [0]\n", "torch.cuda.set_device(gpus[0])\n", "\n", "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def prepare_sequence(seq, to_index):\n", " idxs = list(map(lambda w: to_index[w] if to_index.get(w) is not None else to_index[\"\"], seq))\n", " return LongTensor(idxs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data load and Preprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Penn TreeBank" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def prepare_ptb_dataset(filename, word2index=None):\n", " corpus = open(filename, 'r', encoding='utf-8').readlines()\n", " corpus = flatten([co.strip().split() + [''] for co in corpus])\n", " \n", " if word2index == None:\n", " vocab = list(set(corpus))\n", " word2index = {'': 0}\n", " for vo in vocab:\n", " if word2index.get(vo) is None:\n", " word2index[vo] = len(word2index)\n", " \n", " return prepare_sequence(corpus, word2index), word2index" ] }, { "cell_type": "code", "execution_count": 175, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# borrowed code from https://github.com/pytorch/examples/tree/master/word_language_model\n", "\n", "def batchify(data, bsz):\n", " # Work out how cleanly we can divide the dataset into bsz parts.\n", " nbatch = data.size(0) // bsz\n", " # Trim off any extra elements that wouldn't cleanly fit (remainders).\n", " data = data.narrow(0, 0, nbatch * bsz)\n", " # Evenly divide the data across the bsz batches.\n", " data = data.view(bsz, -1).contiguous()\n", " if USE_CUDA:\n", " data = data.cuda()\n", " return data" ] }, { "cell_type": "code", "execution_count": 176, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def getBatch(data, seq_length):\n", " for i in range(0, data.size(1) - seq_length, seq_length):\n", " inputs = Variable(data[:, i: i + seq_length])\n", " targets = Variable(data[:, (i + 1): (i + 1) + seq_length].contiguous())\n", " yield (inputs, targets)" ] }, { 
"cell_type": "code", "execution_count": 177, "metadata": { "collapsed": true }, "outputs": [], "source": [ "train_data, word2index = prepare_ptb_dataset('../dataset/ptb/ptb.train.txt',)\n", "dev_data , _ = prepare_ptb_dataset('../dataset/ptb/ptb.valid.txt', word2index)\n", "test_data, _ = prepare_ptb_dataset('../dataset/ptb/ptb.test.txt', word2index)" ] }, { "cell_type": "code", "execution_count": 178, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "10000" ] }, "execution_count": 178, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(word2index)" ] }, { "cell_type": "code", "execution_count": 179, "metadata": { "collapsed": true }, "outputs": [], "source": [ "index2word = {v:k for k, v in word2index.items()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Modeling " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "
"*(figure)* borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture8.pdf" ] },
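{ "cell_type": "markdown", "metadata": {}, "source": [ "The model below is a standard RNN language model: an embedding layer feeds an LSTM, and a linear layer projects each hidden state onto the vocabulary. Training minimizes the cross-entropy of the next word, so the perplexity reported during training is simply the exponentiated mean loss over the $N$ predicted tokens (exactly `np.exp(np.mean(losses))` in the training loop below):\n", "\n", "$$PPL = \\exp\\Big(-\\frac{1}{N}\\sum_{i=1}^{N} \\log p(w_i \\mid w_{<i})\\Big)$$" ] },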
" ] }, { "cell_type": "code", "execution_count": 180, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class LanguageModel(nn.Module): \n", " def __init__(self, vocab_size, embedding_size, hidden_size, n_layers=1, dropout_p=0.5):\n", "\n", " super(LanguageModel, self).__init__()\n", " self.n_layers = n_layers\n", " self.hidden_size = hidden_size\n", " self.embed = nn.Embedding(vocab_size, embedding_size)\n", " self.rnn = nn.LSTM(embedding_size, hidden_size, n_layers, batch_first=True)\n", " self.linear = nn.Linear(hidden_size, vocab_size)\n", " self.dropout = nn.Dropout(dropout_p)\n", " \n", " def init_weight(self):\n", " self.embed.weight = nn.init.xavier_uniform(self.embed.weight)\n", " self.linear.weight = nn.init.xavier_uniform(self.linear.weight)\n", " self.linear.bias.data.fill_(0)\n", " \n", " def init_hidden(self,batch_size):\n", " hidden = Variable(torch.zeros(self.n_layers,batch_size,self.hidden_size))\n", " context = Variable(torch.zeros(self.n_layers,batch_size,self.hidden_size))\n", " return (hidden.cuda(), context.cuda()) if USE_CUDA else (hidden, context)\n", " \n", " def detach_hidden(self, hiddens):\n", " return tuple([hidden.detach() for hidden in hiddens])\n", " \n", " def forward(self, inputs, hidden, is_training=False): \n", "\n", " embeds = self.embed(inputs)\n", " if is_training:\n", " embeds = self.dropout(embeds)\n", " out,hidden = self.rnn(embeds, hidden)\n", " return self.linear(out.contiguous().view(out.size(0) * out.size(1), -1)), hidden" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It takes for a while..." ] }, { "cell_type": "code", "execution_count": 181, "metadata": { "collapsed": true }, "outputs": [], "source": [ "EMBED_SIZE = 128\n", "HIDDEN_SIZE = 1024\n", "NUM_LAYER = 1\n", "LR = 0.01\n", "SEQ_LENGTH = 30 # for bptt\n", "BATCH_SIZE = 20\n", "EPOCH = 40\n", "RESCHEDULED = False" ] }, { "cell_type": "code", "execution_count": 182, "metadata": { "collapsed": true }, "outputs": [], "source": [ "train_data = batchify(train_data, BATCH_SIZE)\n", "dev_data = batchify(dev_data, BATCH_SIZE//2)\n", "test_data = batchify(test_data, BATCH_SIZE//2)" ] }, { "cell_type": "code", "execution_count": 185, "metadata": { "collapsed": true }, "outputs": [], "source": [ "model = LanguageModel(len(word2index), EMBED_SIZE, HIDDEN_SIZE, NUM_LAYER, 0.5)\n", "model.init_weight() \n", "if USE_CUDA:\n", " model = model.cuda()\n", "loss_function = nn.CrossEntropyLoss()\n", "optimizer = optim.Adam(model.parameters(), lr=LR)" ] }, { "cell_type": "code", "execution_count": 186, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[00/40] mean_loss : 9.45, Perplexity : 12712.23\n", "[00/40] mean_loss : 5.88, Perplexity : 358.21\n", "[00/40] mean_loss : 5.55, Perplexity : 256.44\n", "[01/40] mean_loss : 5.38, Perplexity : 217.46\n", "[01/40] mean_loss : 5.21, Perplexity : 182.41\n", "[01/40] mean_loss : 5.10, Perplexity : 164.39\n", "[02/40] mean_loss : 5.08, Perplexity : 160.87\n", "[02/40] mean_loss : 4.99, Perplexity : 147.18\n", "[02/40] mean_loss : 4.92, Perplexity : 136.52\n", "[03/40] mean_loss : 4.92, Perplexity : 136.64\n", "[03/40] mean_loss : 4.86, Perplexity : 129.32\n", "[03/40] mean_loss : 4.80, Perplexity : 121.46\n", "[04/40] mean_loss : 4.80, Perplexity : 121.91\n", "[04/40] mean_loss : 4.77, Perplexity : 117.64\n", "[04/40] mean_loss : 4.71, Perplexity : 111.22\n", "[05/40] mean_loss : 4.72, Perplexity : 112.01\n", 
"[05/40] mean_loss : 4.70, Perplexity : 109.46\n", "[05/40] mean_loss : 4.64, Perplexity : 103.96\n", "[06/40] mean_loss : 4.66, Perplexity : 105.25\n", "[06/40] mean_loss : 4.64, Perplexity : 103.63\n", "[06/40] mean_loss : 4.60, Perplexity : 99.00\n", "[07/40] mean_loss : 4.60, Perplexity : 99.89\n", "[07/40] mean_loss : 4.59, Perplexity : 98.97\n", "[07/40] mean_loss : 4.55, Perplexity : 94.97\n", "[08/40] mean_loss : 4.56, Perplexity : 95.54\n", "[08/40] mean_loss : 4.56, Perplexity : 95.67\n", "[08/40] mean_loss : 4.52, Perplexity : 91.98\n", "[09/40] mean_loss : 4.53, Perplexity : 92.61\n", "[09/40] mean_loss : 4.53, Perplexity : 92.79\n", "[09/40] mean_loss : 4.50, Perplexity : 89.63\n", "[10/40] mean_loss : 4.50, Perplexity : 90.13\n", "[10/40] mean_loss : 4.50, Perplexity : 90.19\n", "[10/40] mean_loss : 4.47, Perplexity : 87.11\n", "[11/40] mean_loss : 4.48, Perplexity : 88.11\n", "[11/40] mean_loss : 4.48, Perplexity : 88.26\n", "[11/40] mean_loss : 4.45, Perplexity : 86.05\n", "[12/40] mean_loss : 4.46, Perplexity : 86.81\n", "[12/40] mean_loss : 4.47, Perplexity : 87.03\n", "[12/40] mean_loss : 4.43, Perplexity : 84.04\n", "[13/40] mean_loss : 4.45, Perplexity : 85.27\n", "[13/40] mean_loss : 4.45, Perplexity : 85.83\n", "[13/40] mean_loss : 4.42, Perplexity : 83.33\n", "[14/40] mean_loss : 4.43, Perplexity : 84.15\n", "[14/40] mean_loss : 4.43, Perplexity : 84.31\n", "[14/40] mean_loss : 4.41, Perplexity : 82.29\n", "[15/40] mean_loss : 4.43, Perplexity : 83.82\n", "[15/40] mean_loss : 4.43, Perplexity : 83.70\n", "[15/40] mean_loss : 4.40, Perplexity : 81.59\n", "[16/40] mean_loss : 4.42, Perplexity : 83.06\n", "[16/40] mean_loss : 4.42, Perplexity : 83.29\n", "[16/40] mean_loss : 4.39, Perplexity : 80.89\n", "[17/40] mean_loss : 4.41, Perplexity : 82.44\n", "[17/40] mean_loss : 4.41, Perplexity : 82.51\n", "[17/40] mean_loss : 4.39, Perplexity : 80.59\n", "[18/40] mean_loss : 4.40, Perplexity : 81.59\n", "[18/40] mean_loss : 4.41, Perplexity : 82.21\n", "[18/40] mean_loss : 4.38, Perplexity : 79.87\n", "[19/40] mean_loss : 4.40, Perplexity : 81.43\n", "[19/40] mean_loss : 4.40, Perplexity : 81.67\n", "[19/40] mean_loss : 4.37, Perplexity : 79.28\n", "[20/40] mean_loss : 4.40, Perplexity : 81.18\n", "[20/40] mean_loss : 4.40, Perplexity : 81.17\n", "[20/40] mean_loss : 4.37, Perplexity : 79.11\n", "[21/40] mean_loss : 4.40, Perplexity : 81.44\n", "[21/40] mean_loss : 4.34, Perplexity : 76.43\n", "[21/40] mean_loss : 4.21, Perplexity : 67.17\n", "[22/40] mean_loss : 4.26, Perplexity : 70.84\n", "[22/40] mean_loss : 4.26, Perplexity : 70.75\n", "[22/40] mean_loss : 4.17, Perplexity : 64.99\n", "[23/40] mean_loss : 4.22, Perplexity : 68.36\n", "[23/40] mean_loss : 4.22, Perplexity : 67.82\n", "[23/40] mean_loss : 4.15, Perplexity : 63.74\n", "[24/40] mean_loss : 4.20, Perplexity : 66.66\n", "[24/40] mean_loss : 4.20, Perplexity : 66.43\n", "[24/40] mean_loss : 4.14, Perplexity : 62.85\n", "[25/40] mean_loss : 4.18, Perplexity : 65.53\n", "[25/40] mean_loss : 4.17, Perplexity : 64.99\n", "[25/40] mean_loss : 4.13, Perplexity : 61.94\n", "[26/40] mean_loss : 4.17, Perplexity : 64.61\n", "[26/40] mean_loss : 4.16, Perplexity : 64.34\n", "[26/40] mean_loss : 4.12, Perplexity : 61.27\n", "[27/40] mean_loss : 4.15, Perplexity : 63.73\n", "[27/40] mean_loss : 4.15, Perplexity : 63.32\n", "[27/40] mean_loss : 4.11, Perplexity : 60.87\n", "[28/40] mean_loss : 4.14, Perplexity : 62.96\n", "[28/40] mean_loss : 4.14, Perplexity : 63.01\n", "[28/40] mean_loss : 4.10, Perplexity : 60.33\n", 
"[29/40] mean_loss : 4.14, Perplexity : 62.54\n", "[29/40] mean_loss : 4.13, Perplexity : 62.36\n", "[29/40] mean_loss : 4.10, Perplexity : 60.06\n", "[30/40] mean_loss : 4.13, Perplexity : 62.05\n", "[30/40] mean_loss : 4.13, Perplexity : 61.91\n", "[30/40] mean_loss : 4.09, Perplexity : 59.46\n", "[31/40] mean_loss : 4.12, Perplexity : 61.45\n", "[31/40] mean_loss : 4.11, Perplexity : 61.24\n", "[31/40] mean_loss : 4.08, Perplexity : 59.12\n", "[32/40] mean_loss : 4.11, Perplexity : 61.03\n", "[32/40] mean_loss : 4.11, Perplexity : 60.88\n", "[32/40] mean_loss : 4.07, Perplexity : 58.69\n", "[33/40] mean_loss : 4.11, Perplexity : 60.71\n", "[33/40] mean_loss : 4.10, Perplexity : 60.57\n", "[33/40] mean_loss : 4.07, Perplexity : 58.38\n", "[34/40] mean_loss : 4.10, Perplexity : 60.33\n", "[34/40] mean_loss : 4.10, Perplexity : 60.23\n", "[34/40] mean_loss : 4.06, Perplexity : 58.06\n", "[35/40] mean_loss : 4.09, Perplexity : 60.00\n", "[35/40] mean_loss : 4.09, Perplexity : 59.74\n", "[35/40] mean_loss : 4.06, Perplexity : 57.75\n", "[36/40] mean_loss : 4.09, Perplexity : 59.58\n", "[36/40] mean_loss : 4.09, Perplexity : 59.47\n", "[36/40] mean_loss : 4.05, Perplexity : 57.59\n", "[37/40] mean_loss : 4.08, Perplexity : 59.30\n", "[37/40] mean_loss : 4.08, Perplexity : 59.11\n", "[37/40] mean_loss : 4.05, Perplexity : 57.11\n", "[38/40] mean_loss : 4.08, Perplexity : 58.98\n", "[38/40] mean_loss : 4.07, Perplexity : 58.70\n", "[38/40] mean_loss : 4.04, Perplexity : 57.10\n", "[39/40] mean_loss : 4.07, Perplexity : 58.79\n", "[39/40] mean_loss : 4.07, Perplexity : 58.58\n", "[39/40] mean_loss : 4.04, Perplexity : 56.79\n" ] } ], "source": [ "for epoch in range(EPOCH):\n", " total_loss = 0\n", " losses = []\n", " hidden = model.init_hidden(BATCH_SIZE)\n", " for i,batch in enumerate(getBatch(train_data, SEQ_LENGTH)):\n", " inputs, targets = batch\n", " hidden = model.detach_hidden(hidden)\n", " model.zero_grad()\n", " preds, hidden = model(inputs, hidden, True)\n", "\n", " loss = loss_function(preds, targets.view(-1))\n", " losses.append(loss.data[0])\n", " loss.backward()\n", " torch.nn.utils.clip_grad_norm(model.parameters(), 0.5) # gradient clipping\n", " optimizer.step()\n", "\n", " if i > 0 and i % 500 == 0:\n", " print(\"[%02d/%d] mean_loss : %0.2f, Perplexity : %0.2f\" % (epoch,EPOCH, np.mean(losses), np.exp(np.mean(losses))))\n", " losses = []\n", " \n", " # learning rate anealing\n", " # You can use http://pytorch.org/docs/master/optim.html#how-to-adjust-learning-rate\n", " if RESCHEDULED == False and epoch == EPOCH//2:\n", " LR *= 0.1\n", " optimizer = optim.Adam(model.parameters(), lr=LR)\n", " RESCHEDULED = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test " ] }, { "cell_type": "code", "execution_count": 189, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test Perpelexity : 155.89\n" ] } ], "source": [ "total_loss = 0\n", "hidden = model.init_hidden(BATCH_SIZE//2)\n", "for batch in getBatch(test_data, SEQ_LENGTH):\n", " inputs,targets = batch\n", " \n", " hidden = model.detach_hidden(hidden)\n", " model.zero_grad()\n", " preds, hidden = model(inputs, hidden)\n", " total_loss += inputs.size(1) * loss_function(preds, targets.view(-1)).data\n", "\n", "total_loss = total_loss[0]/test_data.size(1)\n", "print(\"Test Perpelexity : %5.2f\" % (np.exp(total_loss)))" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## Further topics" ] }, { "cell_type": "markdown", "metadata": 
{}, "source": [ "* Pointer Sentinel Mixture Models\n", "* Regularizing and Optimizing LSTM Language Models" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }