{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 2. Skip-gram with negative sampling" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "I recommend you take a look at these material first." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf\n", "* http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.autograd import Variable\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", "import nltk\n", "import random\n", "import numpy as np\n", "from collections import Counter\n", "flatten = lambda l: [item for sublist in l for item in sublist]\n", "random.seed(1024)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.3.0.post4\n", "3.2.4\n" ] } ], "source": [ "print(torch.__version__)\n", "print(nltk.__version__)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "USE_CUDA = torch.cuda.is_available()\n", "gpus = [0]\n", "torch.cuda.set_device(gpus[0])\n", "\n", "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def getBatch(batch_size, train_data):\n", " random.shuffle(train_data)\n", " sindex = 0\n", " eindex = batch_size\n", " while eindex < len(train_data):\n", " batch = train_data[sindex: eindex]\n", " temp = eindex\n", " eindex = eindex + batch_size\n", " sindex = temp\n", " yield batch\n", " \n", " if eindex >= len(train_data):\n", " batch = train_data[sindex:]\n", " yield batch" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def prepare_sequence(seq, word2index):\n", " idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index[\"\"], seq))\n", " return Variable(LongTensor(idxs))\n", "\n", "def prepare_word(word, word2index):\n", " return Variable(LongTensor([word2index[word]]) if word2index.get(word) is not None else LongTensor([word2index[\"\"]]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data load and Preprocessing " ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]\n", "corpus = [[word.lower() for word in sent] for sent in corpus]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Exclude sparse words " ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "word_count = Counter(flatten(corpus))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "MIN_COUNT = 3\n", "exclude = []" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "for w, c in word_count.items():\n", " if c < MIN_COUNT:\n", " exclude.append(w)" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "### Prepare train data " ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "vocab = list(set(flatten(corpus)) - set(exclude))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ "word2index = {}\n", "for vo in vocab:\n", " if word2index.get(vo) is None:\n", " word2index[vo] = len(word2index)\n", " \n", "index2word = {v:k for k, v in word2index.items()}" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "WINDOW_SIZE = 5\n", "windows = flatten([list(nltk.ngrams([''] * WINDOW_SIZE + c + [''] * WINDOW_SIZE, WINDOW_SIZE * 2 + 1)) for c in corpus])\n", "\n", "train_data = []\n", "\n", "for window in windows:\n", " for i in range(WINDOW_SIZE * 2 + 1):\n", " if window[i] in exclude or window[WINDOW_SIZE] in exclude: \n", " continue # min_count\n", " if i == WINDOW_SIZE or window[i] == '': \n", " continue\n", " train_data.append((window[WINDOW_SIZE], window[i]))\n", "\n", "X_p = []\n", "y_p = []\n", "\n", "for tr in train_data:\n", " X_p.append(prepare_word(tr[0], word2index).view(1, -1))\n", " y_p.append(prepare_word(tr[1], word2index).view(1, -1))\n", " \n", "train_data = list(zip(X_p, y_p))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "50242" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(train_data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Build Unigram Distribution**0.75 " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$P(w)=U(w)^{3/4}/Z$$" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "Z = 0.001" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true }, "outputs": [], "source": [ "word_count = Counter(flatten(corpus))\n", "num_total_words = sum([c for w, c in word_count.items() if w not in exclude])" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": true }, "outputs": [], "source": [ "unigram_table = []\n", "\n", "for vo in vocab:\n", " unigram_table.extend([vo] * int(((word_count[vo]/num_total_words)**0.75)/Z))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "478 3500\n" ] } ], "source": [ "print(len(vocab), len(unigram_table))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Negative Sampling " ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def negative_sampling(targets, unigram_table, k):\n", " batch_size = targets.size(0)\n", " neg_samples = []\n", " for i in range(batch_size):\n", " nsample = []\n", " target_index = targets[i].data.cpu().tolist()[0] if USE_CUDA else targets[i].data.tolist()[0]\n", " while len(nsample) < k: # num of sampling\n", " neg = random.choice(unigram_table)\n", " if word2index[neg] == target_index:\n", " continue\n", " nsample.append(neg)\n", " neg_samples.append(prepare_sequence(nsample, word2index).view(1, -1))\n", " \n", " return torch.cat(neg_samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Modeling " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "
(figure removed) borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf
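\n", "\n", "Since the figure no longer renders, here is the objective it illustrated, the skip-gram negative-sampling loss from the Mikolov et al. paper linked above, for one (center, context) pair with $K$ sampled negatives:\n", "\n", "$$-\\log\\sigma(u_o^{\\top}v_c)-\\sum_{k=1}^{K}\\log\\sigma(-u_k^{\\top}v_c)$$\n", "\n", "Here $v_c$ is the center embedding (`embedding_v`), $u_o$ the positive context embedding (`embedding_u`), and $u_k$ the sampled negatives. Note that the implementation below sums the $K$ negative scores before applying $\\log\\sigma$, a slight simplification of this formula.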
" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class SkipgramNegSampling(nn.Module):\n", " \n", " def __init__(self, vocab_size, projection_dim):\n", " super(SkipgramNegSampling, self).__init__()\n", " self.embedding_v = nn.Embedding(vocab_size, projection_dim) # center embedding\n", " self.embedding_u = nn.Embedding(vocab_size, projection_dim) # out embedding\n", " self.logsigmoid = nn.LogSigmoid()\n", " \n", " initrange = (2.0 / (vocab_size + projection_dim))**0.5 # Xavier init\n", " self.embedding_v.weight.data.uniform_(-initrange, initrange) # init\n", " self.embedding_u.weight.data.uniform_(-0.0, 0.0) # init\n", " \n", " def forward(self, center_words, target_words, negative_words):\n", " center_embeds = self.embedding_v(center_words) # B x 1 x D\n", " target_embeds = self.embedding_u(target_words) # B x 1 x D\n", " \n", " neg_embeds = -self.embedding_u(negative_words) # B x K x D\n", " \n", " positive_score = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # Bx1\n", " negative_score = torch.sum(neg_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2), 1).view(negs.size(0), -1) # BxK -> Bx1\n", " \n", " loss = self.logsigmoid(positive_score) + self.logsigmoid(negative_score)\n", " \n", " return -torch.mean(loss)\n", " \n", " def prediction(self, inputs):\n", " embeds = self.embedding_v(inputs)\n", " \n", " return embeds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train " ] }, { "cell_type": "code", "execution_count": 68, "metadata": { "collapsed": true }, "outputs": [], "source": [ "EMBEDDING_SIZE = 30 \n", "BATCH_SIZE = 256\n", "EPOCH = 100\n", "NEG = 10 # Num of Negative Sampling" ] }, { "cell_type": "code", "execution_count": 69, "metadata": { "collapsed": true }, "outputs": [], "source": [ "losses = []\n", "model = SkipgramNegSampling(len(word2index), EMBEDDING_SIZE)\n", "if USE_CUDA:\n", " model = model.cuda()\n", "optimizer = optim.Adam(model.parameters(), lr=0.001)" ] }, { "cell_type": "code", "execution_count": 70, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch : 0, mean_loss : 1.06\n", "Epoch : 10, mean_loss : 0.86\n", "Epoch : 20, mean_loss : 0.79\n", "Epoch : 30, mean_loss : 0.74\n", "Epoch : 40, mean_loss : 0.71\n", "Epoch : 50, mean_loss : 0.69\n", "Epoch : 60, mean_loss : 0.67\n", "Epoch : 70, mean_loss : 0.65\n", "Epoch : 80, mean_loss : 0.64\n", "Epoch : 90, mean_loss : 0.63\n" ] } ], "source": [ "for epoch in range(EPOCH):\n", " for i,batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n", " \n", " inputs, targets = zip(*batch)\n", " \n", " inputs = torch.cat(inputs) # B x 1\n", " targets = torch.cat(targets) # B x 1\n", " negs = negative_sampling(targets, unigram_table, NEG)\n", " model.zero_grad()\n", "\n", " loss = model(inputs, targets, negs)\n", " \n", " loss.backward()\n", " optimizer.step()\n", " \n", " losses.append(loss.data.tolist()[0])\n", " if epoch % 10 == 0:\n", " print(\"Epoch : %d, mean_loss : %.02f\" % (epoch, np.mean(losses)))\n", " losses = []" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test " ] }, { "cell_type": "code", "execution_count": 71, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def word_similarity(target, vocab):\n", " if USE_CUDA:\n", " target_V = model.prediction(prepare_word(target, word2index))\n", " else:\n", " target_V = model.prediction(prepare_word(target, word2index))\n", " similarities = []\n", " for i in 
range(len(vocab)):\n", " if vocab[i] == target: \n", " continue\n", " \n", " if USE_CUDA:\n", " vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n", " else:\n", " vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n", " \n", " cosine_sim = F.cosine_similarity(target_V, vector).data.tolist()[0]\n", " similarities.append([vocab[i], cosine_sim])\n", " return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]" ] }, { "cell_type": "code", "execution_count": 212, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "'passengers'" ] }, "execution_count": 212, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test = random.choice(list(vocab))\n", "test" ] }, { "cell_type": "code", "execution_count": 213, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "[['am', 0.7353377342224121],\n", " ['passenger', 0.7154150605201721],\n", " ['cook', 0.6829826831817627],\n", " ['new', 0.6648461818695068],\n", " ['bedford', 0.6283411383628845],\n", " ['besides', 0.5972960591316223],\n", " ['themselves', 0.5964340567588806],\n", " ['grow', 0.5957046151161194],\n", " ['tell', 0.5952941179275513],\n", " ['get', 0.5943044424057007]]" ] }, "execution_count": 213, "metadata": {}, "output_type": "execute_result" } ], "source": [ "word_similarity(test, vocab)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }