{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Skip-gram with naive softmax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I recommend taking a look at these materials first:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture2.pdf\n",
    "* https://arxiv.org/abs/1301.3781\n",
    "* http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import nltk\n",
    "import random\n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "flatten = lambda l: [item for sublist in l for item in sublist]\n",
    "random.seed(1024)\n",
    "torch.manual_seed(1024)  # the embedding initialisation draws from torch's RNG, not `random`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.3.0.post4\n",
      "3.2.4\n"
     ]
    }
   ],
   "source": [
    "print(torch.__version__)\n",
    "print(nltk.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "USE_CUDA = torch.cuda.is_available()\n",
    "gpus = [0]\n",
    "if USE_CUDA:\n",
    "    # Selecting a device fails on CPU-only machines, so only do it when CUDA is available.\n",
    "    torch.cuda.set_device(gpus[0])\n",
    "\n",
    "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
    "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
    "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def getBatch(batch_size, train_data):\n",
    "    \"\"\"Shuffle `train_data` in place, then yield consecutive slices of `batch_size` items.\n",
    "\n",
    "    The last slice holds the remainder and may be shorter; an empty dataset yields nothing.\n",
    "    \"\"\"\n",
    "    random.shuffle(train_data)\n",
    "    for start in range(0, len(train_data), batch_size):\n",
    "        yield train_data[start:start + batch_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def prepare_sequence(seq, word2index):\n",
    "    \"\"\"Map each token of `seq` to its index (unknown tokens map to the '' entry) as a LongTensor Variable.\"\"\"\n",
    "    idxs = [word2index[w] if w in word2index else word2index[\"\"] for w in seq]\n",
    "    return Variable(LongTensor(idxs))\n",
    "\n",
    "def prepare_word(word, word2index):\n",
    "    \"\"\"Return a 1-element LongTensor Variable holding the index of `word` (or of '' if unknown).\"\"\"\n",
    "    idx = word2index[word] if word in word2index else word2index[\"\"]\n",
    "    return Variable(LongTensor([idx]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data loading and preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load corpus: Gutenberg"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you don't have the Gutenberg corpus yet, download it first with `nltk.download('gutenberg')`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['austen-emma.txt',\n",
       " 'austen-persuasion.txt',\n",
       " 'austen-sense.txt',\n",
       " 'bible-kjv.txt',\n",
       " 'blake-poems.txt',\n",
       " 'bryant-stories.txt',\n",
       " 'burgess-busterbrown.txt',\n",
       " 'carroll-alice.txt',\n",
       " 'chesterton-ball.txt',\n",
       " 'chesterton-brown.txt',\n",
       " 'chesterton-thursday.txt',\n",
       " 'edgeworth-parents.txt',\n",
       " 'melville-moby_dick.txt',\n",
       " 'milton-paradise.txt',\n",
       " 'shakespeare-caesar.txt',\n",
       " 'shakespeare-hamlet.txt',\n",
       " 'shakespeare-macbeth.txt',\n",
       " 'whitman-leaves.txt']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nltk.corpus.gutenberg.fileids()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:100]  # only the first 100 sentences, as a small test sample\n",
    "corpus = [[word.lower() for word in sent] for sent in corpus]"
   ]
  },
  {
"cell_type": "markdown", "metadata": {}, "source": [ "### Extract Stopwords from unigram distribution's tails" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "word_count = Counter(flatten(corpus))\n", "border = int(len(word_count) * 0.01) " ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "stopwords = word_count.most_common()[:border] + list(reversed(word_count.most_common()))[:border]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "stopwords = [s[0] for s in stopwords]" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "[',', '.', 'the', 'of', 'and', 'baleine', '--(', 'fat', 'oil', 'boiling']" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "stopwords" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Build vocab" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "vocab = list(set(flatten(corpus)) - set(stopwords))\n", "vocab.append('')" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "592 583\n" ] } ], "source": [ "print(len(set(flatten(corpus))), len(vocab))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "word2index = {'' : 0} \n", "\n", "for vo in vocab:\n", " if word2index.get(vo) is None:\n", " word2index[vo] = len(word2index)\n", "\n", "index2word = {v:k for k, v in word2index.items()} " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare train data " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "window data example" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "\n", "
" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true }, "outputs": [], "source": [ "WINDOW_SIZE = 3\n", "windows = flatten([list(nltk.ngrams([''] * WINDOW_SIZE + c + [''] * WINDOW_SIZE, WINDOW_SIZE * 2 + 1)) for c in corpus])" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "('', '', '', '[', 'moby', 'dick', 'by')" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "windows[0]" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[('[', 'moby'), ('[', 'dick'), ('[', 'by'), ('moby', '['), ('moby', 'dick'), ('moby', 'by')]\n" ] } ], "source": [ "train_data = []\n", "\n", "for window in windows:\n", " for i in range(WINDOW_SIZE * 2 + 1):\n", " if i == WINDOW_SIZE or window[i] == '': \n", " continue\n", " train_data.append((window[WINDOW_SIZE], window[i]))\n", "\n", "print(train_data[:WINDOW_SIZE * 2])" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": true }, "outputs": [], "source": [ "X_p = []\n", "y_p = []" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "('[', 'moby')" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_data[0]" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": true }, "outputs": [], "source": [ "for tr in train_data:\n", " X_p.append(prepare_word(tr[0], word2index).view(1, -1))\n", " y_p.append(prepare_word(tr[1], word2index).view(1, -1))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": true }, "outputs": [], "source": [ "train_data = list(zip(X_p, y_p))" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ 
"7606" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(train_data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Modeling" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "
" ] }, { "cell_type": "code", "execution_count": 59, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class Skipgram(nn.Module):\n", " \n", " def __init__(self, vocab_size, projection_dim):\n", " super(Skipgram,self).__init__()\n", " self.embedding_v = nn.Embedding(vocab_size, projection_dim)\n", " self.embedding_u = nn.Embedding(vocab_size, projection_dim)\n", "\n", " self.embedding_v.weight.data.uniform_(-1, 1) # init\n", " self.embedding_u.weight.data.uniform_(0, 0) # init\n", " #self.out = nn.Linear(projection_dim,vocab_size)\n", " def forward(self, center_words,target_words, outer_words):\n", " center_embeds = self.embedding_v(center_words) # B x 1 x D\n", " target_embeds = self.embedding_u(target_words) # B x 1 x D\n", " outer_embeds = self.embedding_u(outer_words) # B x V x D\n", " \n", " scores = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # Bx1xD * BxDx1 => Bx1\n", " norm_scores = outer_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # BxVxD * BxDx1 => BxV\n", " \n", " nll = -torch.mean(torch.log(torch.exp(scores)/torch.sum(torch.exp(norm_scores), 1).unsqueeze(1))) # log-softmax\n", " \n", " return nll # negative log likelihood\n", " \n", " def prediction(self, inputs):\n", " embeds = self.embedding_v(inputs)\n", " \n", " return embeds " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train " ] }, { "cell_type": "code", "execution_count": 60, "metadata": { "collapsed": true }, "outputs": [], "source": [ "EMBEDDING_SIZE = 30\n", "BATCH_SIZE = 256\n", "EPOCH = 100" ] }, { "cell_type": "code", "execution_count": 61, "metadata": { "collapsed": true }, "outputs": [], "source": [ "losses = []\n", "model = Skipgram(len(word2index), EMBEDDING_SIZE)\n", "if USE_CUDA:\n", " model = model.cuda()\n", "optimizer = optim.Adam(model.parameters(), lr=0.01)" ] }, { "cell_type": "code", "execution_count": 62, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch : 0, 
mean_loss : 6.20\n", "Epoch : 10, mean_loss : 4.38\n", "Epoch : 20, mean_loss : 3.48\n", "Epoch : 30, mean_loss : 3.31\n", "Epoch : 40, mean_loss : 3.26\n", "Epoch : 50, mean_loss : 3.24\n", "Epoch : 60, mean_loss : 3.22\n", "Epoch : 70, mean_loss : 3.22\n", "Epoch : 80, mean_loss : 3.21\n", "Epoch : 90, mean_loss : 3.20\n" ] } ], "source": [ "for epoch in range(EPOCH):\n", " for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n", " \n", " inputs, targets = zip(*batch)\n", " \n", " inputs = torch.cat(inputs) # B x 1\n", " targets = torch.cat(targets) # B x 1\n", " vocabs = prepare_sequence(list(vocab), word2index).expand(inputs.size(0), len(vocab)) # B x V\n", " model.zero_grad()\n", "\n", " loss = model(inputs, targets, vocabs)\n", " \n", " loss.backward()\n", " optimizer.step()\n", " \n", " losses.append(loss.data.tolist()[0])\n", "\n", " if epoch % 10 == 0:\n", " print(\"Epoch : %d, mean_loss : %.02f\" % (epoch,np.mean(losses)))\n", " losses = []" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test" ] }, { "cell_type": "code", "execution_count": 63, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def word_similarity(target, vocab):\n", " if USE_CUDA:\n", " target_V = model.prediction(prepare_word(target, word2index))\n", " else:\n", " target_V = model.prediction(prepare_word(target, word2index))\n", " similarities = []\n", " for i in range(len(vocab)):\n", " if vocab[i] == target: continue\n", " \n", " if USE_CUDA:\n", " vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n", " else:\n", " vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n", " cosine_sim = F.cosine_similarity(target_V, vector).data.tolist()[0] \n", " similarities.append([vocab[i], cosine_sim])\n", " return sorted(similarities, key=lambda x: x[1], reverse=True)[:10] # sort by similarity" ] }, { "cell_type": "code", "execution_count": 64, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ 
"'least'" ] }, "execution_count": 64, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test = random.choice(list(vocab))\n", "test" ] }, { "cell_type": "code", "execution_count": 65, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "[['at', 0.8147411346435547],\n", " ['every', 0.7143548130989075],\n", " ['case', 0.6975079774856567],\n", " ['secure', 0.6121522188186646],\n", " ['heart', 0.5974172949790955],\n", " ['including', 0.5867112278938293],\n", " ['please', 0.5557640194892883],\n", " ['has', 0.5536234974861145],\n", " ['while', 0.5366998314857483],\n", " ['you', 0.509368896484375]]" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "word_similarity(test, vocab)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }