{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Doc2vec from scratch in PyTorch\n",
    "===============================\n",
    "\n",
    "Here we are implementing this useful algorithm with a library we know and trust. With luck this will be more accessible than reading the papers but more in-depth than typical \"install gensim and just do what I say\" tutorials, and still easy to understand for anyone whose maths skills have atrophied to nothing (like me). This is all based on the great work by [Nejc Ilenic](https://github.com/inejc/paragraph-vectors) and reading the referenced papers and gensim's source.\n",
    "\n",
    "`doc2vec` descends from `word2vec`, the basic form of which is that it is a model trained to predict the missing word in a context. Given sentences like \"the cat ___ on the mat\" it should predict \"sat\", and in doing so learn a useful representation of words. We can then extract the internal weights and re-use them as \"word embeddings\", vectors giving each word a position in N-dimensional space that is hopefully close to similar words and an appropriate distance from related words. \n",
    "\n",
    "`doc2vec` or \"Paragraph vectors\" extends the `word2vec` idea by simply adding a document id to each context. This helps the network learn associations between contexts and produces vectors that position each paragraph (document) in space."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we need to load the data. We'll begin by overfitting on a tiny dataset just to check all the parts fit together."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>tokens</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>In the week before their departure to Arrakis, when all the final scurrying about had reached a ...</td>\n",
       "      <td>[in, the, week, before, their, departure, to, arrakis, when, all, the, final, scurrying, about, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>It was a warm night at Castle Caladan, and the ancient pile of stone that had served the Atreide...</td>\n",
       "      <td>[it, was, a, warm, night, at, castle, caladan, and, the, ancient, pile, of, stone, that, had, se...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>The old woman was let in by the side door down the vaulted passage by Paul's room and she was al...</td>\n",
       "      <td>[the, old, woman, was, let, in, by, the, side, door, down, the, vaulted, passage, by, paul, room...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>By the half-light of a suspensor lamp, dimmed and hanging near the floor, the awakened boy could...</td>\n",
       "      <td>[by, the, half, light, of, a, suspensor, lamp, dimmed, and, hanging, near, the, floor, the, awak...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                  text  \\\n",
       "0  In the week before their departure to Arrakis, when all the final scurrying about had reached a ...   \n",
       "1  It was a warm night at Castle Caladan, and the ancient pile of stone that had served the Atreide...   \n",
       "2  The old woman was let in by the side door down the vaulted passage by Paul's room and she was al...   \n",
       "3  By the half-light of a suspensor lamp, dimmed and hanging near the floor, the awakened boy could...   \n",
       "\n",
       "                                                                                                tokens  \n",
       "0  [in, the, week, before, their, departure, to, arrakis, when, all, the, final, scurrying, about, ...  \n",
       "1  [it, was, a, warm, night, at, castle, caladan, and, the, ancient, pile, of, stone, that, had, se...  \n",
       "2  [the, old, woman, was, let, in, by, the, side, door, down, the, vaulted, passage, by, paul, room...  \n",
       "3  [by, the, half, light, of, a, suspensor, lamp, dimmed, and, hanging, near, the, floor, the, awak...  "
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import spacy\n",
    "\n",
    "nlp = spacy.load(\"en_core_web_sm\")\n",
    "\n",
    "pd.set_option(\"display.max_colwidth\", 100)\n",
    "\n",
    "example_df = pd.read_csv(\"data/example.csv\")\n",
    "\n",
    "def tokenize_text(df):\n",
    "    df[\"tokens\"] = df.text.str.lower().str.strip().apply(lambda x: [token.text.strip() for token in nlp(x) if token.text.isalnum()])\n",
    "    return df\n",
    "\n",
    "example_df = tokenize_text(example_df)\n",
    "\n",
    "example_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will need to construct a vocabulary so we can reference every word by an ID."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset comprises 4 documents and 106 unique words (over the limit of 1 occurrences)\n"
     ]
    }
   ],
   "source": [
    "from collections import Counter\n",
    "\n",
    "class Vocab:\n",
    "    def __init__(self, all_tokens, min_count=2):\n",
    "        self.min_count = min_count\n",
    "        self.freqs = {t:n for t, n in Counter(all_tokens).items() if n >= min_count}\n",
    "        self.words = sorted(self.freqs.keys())\n",
    "        self.word2idx = {w: i for i, w in enumerate(self.words)}\n",
    "        \n",
    "vocab = Vocab([tok for tokens in example_df.tokens for tok in tokens], min_count=1)\n",
    "\n",
    "print(f\"Dataset comprises {len(example_df)} documents and {len(vocab.words)} unique words (over the limit of {vocab.min_count} occurrences)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Words that appear extremely rarely can harm performance, so we add a simple mechanism to strip those out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>tokens</th>\n",
       "      <th>length</th>\n",
       "      <th>clean_tokens</th>\n",
       "      <th>clean_length</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>In the week before their departure to Arrakis, when all the final scurrying about had reached a ...</td>\n",
       "      <td>[in, the, week, before, their, departure, to, arrakis, when, all, the, final, scurrying, about, ...</td>\n",
       "      <td>32</td>\n",
       "      <td>[in, the, week, before, their, departure, to, arrakis, when, all, the, final, scurrying, about, ...</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>It was a warm night at Castle Caladan, and the ancient pile of stone that had served the Atreide...</td>\n",
       "      <td>[it, was, a, warm, night, at, castle, caladan, and, the, ancient, pile, of, stone, that, had, se...</td>\n",
       "      <td>39</td>\n",
       "      <td>[it, was, a, warm, night, at, castle, caladan, and, the, ancient, pile, of, stone, that, had, se...</td>\n",
       "      <td>39</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>The old woman was let in by the side door down the vaulted passage by Paul's room and she was al...</td>\n",
       "      <td>[the, old, woman, was, let, in, by, the, side, door, down, the, vaulted, passage, by, paul, room...</td>\n",
       "      <td>34</td>\n",
       "      <td>[the, old, woman, was, let, in, by, the, side, door, down, the, vaulted, passage, by, paul, room...</td>\n",
       "      <td>34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>By the half-light of a suspensor lamp, dimmed and hanging near the floor, the awakened boy could...</td>\n",
       "      <td>[by, the, half, light, of, a, suspensor, lamp, dimmed, and, hanging, near, the, floor, the, awak...</td>\n",
       "      <td>53</td>\n",
       "      <td>[by, the, half, light, of, a, suspensor, lamp, dimmed, and, hanging, near, the, floor, the, awak...</td>\n",
       "      <td>53</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                  text  \\\n",
       "0  In the week before their departure to Arrakis, when all the final scurrying about had reached a ...   \n",
       "1  It was a warm night at Castle Caladan, and the ancient pile of stone that had served the Atreide...   \n",
       "2  The old woman was let in by the side door down the vaulted passage by Paul's room and she was al...   \n",
       "3  By the half-light of a suspensor lamp, dimmed and hanging near the floor, the awakened boy could...   \n",
       "\n",
       "                                                                                                tokens  \\\n",
       "0  [in, the, week, before, their, departure, to, arrakis, when, all, the, final, scurrying, about, ...   \n",
       "1  [it, was, a, warm, night, at, castle, caladan, and, the, ancient, pile, of, stone, that, had, se...   \n",
       "2  [the, old, woman, was, let, in, by, the, side, door, down, the, vaulted, passage, by, paul, room...   \n",
       "3  [by, the, half, light, of, a, suspensor, lamp, dimmed, and, hanging, near, the, floor, the, awak...   \n",
       "\n",
       "   length  \\\n",
       "0      32   \n",
       "1      39   \n",
       "2      34   \n",
       "3      53   \n",
       "\n",
       "                                                                                          clean_tokens  \\\n",
       "0  [in, the, week, before, their, departure, to, arrakis, when, all, the, final, scurrying, about, ...   \n",
       "1  [it, was, a, warm, night, at, castle, caladan, and, the, ancient, pile, of, stone, that, had, se...   \n",
       "2  [the, old, woman, was, let, in, by, the, side, door, down, the, vaulted, passage, by, paul, room...   \n",
       "3  [by, the, half, light, of, a, suspensor, lamp, dimmed, and, hanging, near, the, floor, the, awak...   \n",
       "\n",
       "   clean_length  \n",
       "0            32  \n",
       "1            39  \n",
       "2            34  \n",
       "3            53  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def clean_tokens(df, vocab):\n",
    "    df[\"length\"] = df.tokens.apply(len)\n",
    "    df[\"clean_tokens\"] = df.tokens.apply(lambda x: [t for t in x if t in vocab.freqs.keys()])\n",
    "    df[\"clean_length\"] = df.clean_tokens.apply(len)\n",
    "    return df\n",
    "\n",
    "example_df = clean_tokens(example_df, vocab)\n",
    "example_df[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difficulty with our \"the cat _ on the mat\" problem is that the missing word could be any one in the vocabulary V and so the network would have |V| outputs for each input e.g. a huge vector containing zero for every word in the vocabulary and some positive number for \"sat\" if the network was perfectly trained. For calculating loss we need to turn that into a probabilty distribution, i.e. _softmax_ it. Computing the softmax for such a large vector is expensive.\n",
    "\n",
    "So the trick (one of many possible) we will use is _Noise Contrastive Estimation (NCE)_. We change our \"the cat _ on the mat\" problem into a multiple choice problem, asking the network to choose between \"sat\" and some random wrong answers like \"hopscotch\" and \"luxuriated\". This is easier to compute the softmax for since it's now a binary classifier (right or wrong answer) and the output is simply of a vector of size 1 + k where k is the number of random incorrect options.\n",
    "\n",
    "Happily, this alternative problem still learns equally useful word representations. We just need to adjust the examples and the loss function. There is a simplified version of the NCE loss function called _Negative Sampling (NEG)_ that we can use here.\n",
    "\n",
    "[Notes on Noise Contrastive Estimation and Negative Sampling (C. Dyer)](https://arxiv.org/abs/1410.8251) explains the derivation of the NCE and NEG loss functions.\n",
    "\n",
    "When we implement the loss function, we assume that the first element in a samples/scores vector is the score for the positive sample and the rest are negative samples. This convention saves us from having to pass around an auxiliary vector indicating which sample was positive."
   ]
  },
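  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reference point for the implementation, the negative sampling loss for a single example can be written out explicitly. This is a sketch in our own notation (the symbols $s_0$, $s_i$ and $k$ are labels introduced here for illustration, not names taken from the papers): $s_0$ is the score of the positive sample, $s_1 \\dots s_k$ are the scores of the $k$ noise samples, and $\\sigma$ is the sigmoid function.\n",
    "\n",
    "$$\\mathcal{L} = -\\log \\sigma(s_0) - \\sum_{i=1}^{k} \\log \\sigma(-s_i)$$\n",
    "\n",
    "The loss is small when $s_0$ is large and every $s_i$ is strongly negative. For a batch, we average $\\mathcal{L}$ over the examples."
   ]
  },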
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "class NegativeSampling(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(NegativeSampling, self).__init__()\n",
    "        self.log_sigmoid = nn.LogSigmoid()\n",
    "    def forward(self, scores):\n",
    "        batch_size = scores.shape[0]\n",
    "        n_negative_samples = scores.shape[1] - 1   # TODO average or sum the negative samples? Summing seems to be correct by the paper\n",
    "        positive = self.log_sigmoid(scores[:,0])\n",
    "        negatives = torch.sum(self.log_sigmoid(-scores[:,1:]), dim=1)\n",
    "        return -torch.sum(positive + negatives) / batch_size  # average for batch\n",
    "\n",
    "loss = NegativeSampling()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's helpful to play with some values to reassure ourselves that this function does the right thing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>scores</th>\n",
       "      <th>loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[1, -1, -1, -1]</td>\n",
       "      <td>tensor(1.2530)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[0.5, -1, -1, -1]</td>\n",
       "      <td>tensor(1.4139)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[0, -1, -1, -1]</td>\n",
       "      <td>tensor(1.6329)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[0, 0, 0, 0]</td>\n",
       "      <td>tensor(2.7726)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[0, 0, 0, 1]</td>\n",
       "      <td>tensor(3.3927)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>[0, 1, 1, 1]</td>\n",
       "      <td>tensor(4.6329)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>[0.5, 1, 1, 1]</td>\n",
       "      <td>tensor(4.4139)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>[1, 1, 1, 1]</td>\n",
       "      <td>tensor(4.2530)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              scores            loss\n",
       "0    [1, -1, -1, -1]  tensor(1.2530)\n",
       "1  [0.5, -1, -1, -1]  tensor(1.4139)\n",
       "2    [0, -1, -1, -1]  tensor(1.6329)\n",
       "3       [0, 0, 0, 0]  tensor(2.7726)\n",
       "4       [0, 0, 0, 1]  tensor(3.3927)\n",
       "5       [0, 1, 1, 1]  tensor(4.6329)\n",
       "6     [0.5, 1, 1, 1]  tensor(4.4139)\n",
       "7       [1, 1, 1, 1]  tensor(4.2530)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch \n",
    "\n",
    "data = [[[1, -1, -1, -1]],  # this dummy data uses -1 to 1, but the real model is unconstrained\n",
    "        [[0.5, -1, -1, -1]],\n",
    "        [[0, -1, -1, -1]],\n",
    "        [[0, 0, 0, 0]],\n",
    "        [[0, 0, 0, 1]],\n",
    "        [[0, 1, 1, 1]],\n",
    "        [[0.5, 1, 1, 1]],\n",
    "        [[1, 1, 1, 1]]]\n",
    "\n",
    "loss_df = pd.DataFrame(data, columns=[\"scores\"])\n",
    "loss_df[\"loss\"] = loss_df.scores.apply(lambda x: loss(torch.FloatTensor([x])))\n",
    "\n",
    "loss_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Higher scores for the positive sample (always the first element) reduce the loss but higher scores for the negative samples increase the loss. This looks like the right behaviour."
   ]
  },
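  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two of the rows above are easy to verify by hand, which makes a handy sanity check. This assumes the loss for one example is $-\\log \\sigma(s_0) - \\sum_i \\log \\sigma(-s_i)$ (our reading of the `forward` method, with $s_0$ the first score and $\\sigma$ the sigmoid). When every score is zero, all four terms are equal, and the same happens for `[1, -1, -1, -1]` because the negative scores are negated before the sigmoid:\n",
    "\n",
    "$$[0, 0, 0, 0]: \\quad 4 \\times \\bigl(-\\log \\sigma(0)\\bigr) = 4 \\log 2 \\approx 2.7726$$\n",
    "\n",
    "$$[1, -1, -1, -1]: \\quad 4 \\times \\bigl(-\\log \\sigma(1)\\bigr) = 4 \\log\\bigl(1 + e^{-1}\\bigr) \\approx 1.2530$$\n",
    "\n",
    "Both values agree with the table."
   ]
  },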
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With that in the bag, let's look at creating training data. The general idea is to create a set of examples where each example has:\n",
    "\n",
    "- doc id\n",
    "- sample ids - a collection of the target token and some noise tokens\n",
    "- context ids - tokens before and after the target token\n",
    "\n",
    "e.g. If our context size was 2, the first example from the above dataset would be:\n",
    "\n",
    "```\n",
    "{\"doc_id\": 0,\n",
    " \"sample_ids\": [word2idx[x] for x in [\"week\", \"random-word-from-vocab\", \"random-word-from-vocab\"...],\n",
    " \"context_ids\": [word2idx[x] for x in [\"in\", \"the\", \"before\", \"their\"]]}\n",
    "```\n",
    "\n",
    "The random words are chosen according to a probability distribution:\n",
    "\n",
    " > a unigram distribution raised to the 3/4rd power, as proposed by T. Mikolov et al. in Distributed Representations of Words and Phrases and their Compositionality\n",
    "\n",
    "This has the effect of slightly increasing the relative probability of rare words (look at the graph of `y=x^0.75` below and see how the lower end is raised above `y=x`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.vegalite.v2+json": {
       "$schema": "https://vega.github.io/schema/vega-lite/v2.6.0.json",
       "config": {
        "view": {
         "height": 300,
         "width": 400
        }
       },
       "data": {
        "name": "data-afc3dc0951a9e12875119a5af86e52b5"
       },
       "datasets": {
        "data-afc3dc0951a9e12875119a5af86e52b5": [
         {
          "x": 0,
          "y": 0
         },
         {
          "x": 0.01,
          "y": 0.03162277660168379
         },
         {
          "x": 0.02,
          "y": 0.053182958969449884
         },
         {
          "x": 0.03,
          "y": 0.07208434242404263
         },
         {
          "x": 0.04,
          "y": 0.08944271909999159
         },
         {
          "x": 0.05,
          "y": 0.10573712634405642
         },
         {
          "x": 0.06,
          "y": 0.12123093028059741
         },
         {
          "x": 0.07,
          "y": 0.13608915892697748
         },
         {
          "x": 0.08,
          "y": 0.15042412372345573
         },
         {
          "x": 0.09,
          "y": 0.16431676725154984
         },
         {
          "x": 0.1,
          "y": 0.1778279410038923
         },
         {
          "x": 0.11,
          "y": 0.19100490227716513
         },
         {
          "x": 0.12,
          "y": 0.2038853093816547
         },
         {
          "x": 0.13,
          "y": 0.2164998073464082
         },
         {
          "x": 0.14,
          "y": 0.22887377179317683
         },
         {
          "x": 0.15,
          "y": 0.2410285256833955
         },
         {
          "x": 0.16,
          "y": 0.25298221281347033
         },
         {
          "x": 0.17,
          "y": 0.26475044029330763
         },
         {
          "x": 0.18,
          "y": 0.2763467610958144
         },
         {
          "x": 0.19,
          "y": 0.28778304315451386
         },
         {
          "x": 0.2,
          "y": 0.29906975624424414
         },
         {
          "x": 0.21,
          "y": 0.3102161981490854
         },
         {
          "x": 0.22,
          "y": 0.32123067524150845
         },
         {
          "x": 0.23,
          "y": 0.33212064831351956
         },
         {
          "x": 0.24,
          "y": 0.34289285156385596
         },
         {
          "x": 0.25,
          "y": 0.3535533905932738
         },
         {
          "x": 0.26,
          "y": 0.3641078238014289
         },
         {
          "x": 0.27,
          "y": 0.37456123052590357
         },
         {
          "x": 0.28,
          "y": 0.38491826849295824
         },
         {
          "x": 0.29,
          "y": 0.39518322257770583
         },
         {
          "x": 0.3,
          "y": 0.4053600464421103
         },
         {
          "x": 0.31,
          "y": 0.41545239829339137
         },
         {
          "x": 0.32,
          "y": 0.42546367175559907
         },
         {
          "x": 0.33,
          "y": 0.43539702265375557
         },
         {
          "x": 0.34,
          "y": 0.4452553923589699
         },
         {
          "x": 0.35000000000000003,
          "y": 0.45504152822405847
         },
         {
          "x": 0.36,
          "y": 0.46475800154489
         },
         {
          "x": 0.37,
          "y": 0.47440722340731084
         },
         {
          "x": 0.38,
          "y": 0.4839914587188715
         },
         {
          "x": 0.39,
          "y": 0.4935128386754873
         },
         {
          "x": 0.4,
          "y": 0.5029733718731741
         },
         {
          "x": 0.41000000000000003,
          "y": 0.5123749542422491
         },
         {
          "x": 0.42,
          "y": 0.5217193779544038
         },
         {
          "x": 0.43,
          "y": 0.5310083394307343
         },
         {
          "x": 0.44,
          "y": 0.5402434465602292
         },
         {
          "x": 0.45,
          "y": 0.549426225222706
         },
         {
          "x": 0.46,
          "y": 0.5585581251971565
         },
         {
          "x": 0.47000000000000003,
          "y": 0.5676405255254853
         },
         {
          "x": 0.48,
          "y": 0.576674739392341
         },
         {
          "x": 0.49,
          "y": 0.5856620185738529
         },
         {
          "x": 0.5,
          "y": 0.5946035575013605
         },
         {
          "x": 0.51,
          "y": 0.6035004969804791
         },
         {
          "x": 0.52,
          "y": 0.6123539276009055
         },
         {
          "x": 0.53,
          "y": 0.6211648928681236
         },
         {
          "x": 0.54,
          "y": 0.629934392084505
         },
         {
          "x": 0.55,
          "y": 0.6386633830041155
         },
         {
          "x": 0.56,
          "y": 0.6473527842827909
         },
         {
          "x": 0.5700000000000001,
          "y": 0.6560034777426358
         },
         {
          "x": 0.58,
          "y": 0.6646163104680073
         },
         {
          "x": 0.59,
          "y": 0.6731920967482075
         },
         {
          "x": 0.6,
          "y": 0.6817316198804996
         },
         {
          "x": 0.61,
          "y": 0.6902356338456498
         },
         {
          "x": 0.62,
          "y": 0.6987048648669424
         },
         {
          "x": 0.63,
          "y": 0.7071400128625219
         },
         {
          "x": 0.64,
          "y": 0.7155417527999327
         },
         {
          "x": 0.65,
          "y": 0.7239107359608682
         },
         {
          "x": 0.66,
          "y": 0.7322475911233668
         },
         {
          "x": 0.67,
          "y": 0.7405529256680135
         },
         {
          "x": 0.68,
          "y": 0.7488273266140879
         },
         {
          "x": 0.6900000000000001,
          "y": 0.7570713615910638
         },
         {
          "x": 0.7000000000000001,
          "y": 0.7652855797503655
         },
         {
          "x": 0.71,
          "y": 0.7734705126218591
         },
         {
          "x": 0.72,
          "y": 0.7816266749191567
         },
         {
          "x": 0.73,
          "y": 0.7897545652974598
         },
         {
          "x": 0.74,
          "y": 0.7978546670673515
         },
         {
          "x": 0.75,
          "y": 0.8059274488676564
         },
         {
          "x": 0.76,
          "y": 0.8139733653002305
         },
         {
          "x": 0.77,
          "y": 0.8219928575293057
         },
         {
          "x": 0.78,
          "y": 0.829986353847804
         },
         {
          "x": 0.79,
          "y": 0.837954270212839
         },
         {
          "x": 0.8,
          "y": 0.8458970107524514
         },
         {
          "x": 0.81,
          "y": 0.8538149682454624
         },
         {
          "x": 0.8200000000000001,
          "y": 0.8617085245761865
         },
         {
          "x": 0.8300000000000001,
          "y": 0.869578051165608
         },
         {
          "x": 0.84,
          "y": 0.8774239093805121
         },
         {
          "x": 0.85,
          "y": 0.8852464509219427
         },
         {
          "x": 0.86,
          "y": 0.8930460181942644
         },
         {
          "x": 0.87,
          "y": 0.9008229446560111
         },
         {
          "x": 0.88,
          "y": 0.9085775551536168
         },
         {
          "x": 0.89,
          "y": 0.9163101662390513
         },
         {
          "x": 0.9,
          "y": 0.9240210864723069
         },
         {
          "x": 0.91,
          "y": 0.9317106167096201
         },
         {
          "x": 0.92,
          "y": 0.9393790503782488
         },
         {
          "x": 0.93,
          "y": 0.9470266737385726
         },
         {
          "x": 0.9400000000000001,
          "y": 0.9546537661342305
         },
         {
          "x": 0.9500000000000001,
          "y": 0.9622606002309622
         },
         {
          "x": 0.96,
          "y": 0.9698474422447793
         },
         {
          "x": 0.97,
          "y": 0.9774145521600454
         },
         {
          "x": 0.98,
          "y": 0.9849621839380145
         },
         {
          "x": 0.99,
          "y": 0.992490585716335
         }
        ]
       },
       "encoding": {
        "x": {
         "field": "x",
         "type": "quantitative"
        },
        "y": {
         "field": "y",
         "type": "quantitative"
        }
       },
       "mark": "line",
       "title": "x^0.75"
      },
      "image/png": "",
      "text/plain": [
       "<VegaLite 2 object>\n",
       "\n",
       "If you see this message, it means the renderer has not been properly enabled\n",
       "for the frontend that you are using. For more information, see\n",
       "https://altair-viz.github.io/user_guide/troubleshooting.html\n"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import altair as alt\n",
    "import numpy as np\n",
    "\n",
    "data = pd.DataFrame(zip(np.arange(0,1,0.01), np.power(np.arange(0,1,0.01), 0.75)), columns=[\"x\", \"y\"])\n",
    "alt.Chart(data, title=\"x^0.75\").mark_line().encode(x=\"x\", y=\"y\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "class NoiseDistribution:\n",
    "    def __init__(self, vocab):\n",
    "        self.probs = np.array([vocab.freqs[w] for w in vocab.words])\n",
    "        self.probs = np.power(self.probs, 0.75)\n",
    "        self.probs /= np.sum(self.probs)\n",
    "    def sample(self, n):\n",
    "        \"Returns the indices of n words randomly sampled from the vocabulary.\"\n",
    "        return np.random.choice(a=self.probs.shape[0], size=n, p=self.probs)\n",
    "        \n",
    "noise = NoiseDistribution(vocab)"
   ]
  },
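  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the 0.75 exponent does in practice, consider a hypothetical two-word vocabulary (the counts are made up for illustration) in which one word occurs 100 times and the other only once. Writing $f(w)$ for the count of word $w$, the probabilities computed by `NoiseDistribution` are:\n",
    "\n",
    "$$P(w) = \\frac{f(w)^{0.75}}{\\sum_{w'} f(w')^{0.75}}$$\n",
    "\n",
    "The rare word has a raw relative frequency of $1/101 \\approx 0.0099$, but after the transformation it is sampled with probability\n",
    "\n",
    "$$\\frac{1^{0.75}}{100^{0.75} + 1^{0.75}} \\approx \\frac{1}{31.62 + 1} \\approx 0.0307$$\n",
    "\n",
    "That is roughly three times more often than its raw frequency would suggest, while the common word still dominates."
   ]
  },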
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With this distribution, we advance through the documents creating examples. Note that we are always putting the positive sample first in the samples vector, following the convention the loss function expects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def example_generator(df, context_size, noise, n_negative_samples, vocab):\n",
    "    for doc_id, doc in df.iterrows():\n",
    "        for i in range(context_size, len(doc.clean_tokens) - context_size):\n",
    "            positive_sample = vocab.word2idx[doc.clean_tokens[i]]\n",
    "            sample_ids = noise.sample(n_negative_samples).tolist()\n",
    "            # Fix a wee bug - ensure negative samples don't accidentally include the positive\n",
    "            sample_ids = [sample_id if sample_id != positive_sample else -1 for sample_id in sample_ids]\n",
    "            sample_ids.insert(0, positive_sample)                \n",
    "            context = doc.clean_tokens[i - context_size:i] + doc.clean_tokens[i + 1:i + context_size + 1]\n",
    "            context_ids = [vocab.word2idx[w] for w in context]\n",
    "            yield {\"doc_ids\": torch.tensor(doc_id),  # we use plural here because it will be batched\n",
    "                   \"sample_ids\": torch.tensor(sample_ids), \n",
    "                   \"context_ids\": torch.tensor(context_ids)}\n",
    "            \n",
    "examples = example_generator(example_df, context_size=5, noise=noise, n_negative_samples=5, vocab=vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we package this up as a good old PyTorch dataset and dataloader."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "class NCEDataset(Dataset):\n",
    "    def __init__(self, examples):\n",
    "        self.examples = list(examples)  # just naively evaluate the whole damn thing - suboptimal!\n",
    "    def __len__(self):\n",
    "        return len(self.examples)\n",
    "    def __getitem__(self, index):\n",
    "        return self.examples[index]\n",
    "    \n",
    "dataset = NCEDataset(examples)\n",
    "dataloader = DataLoader(dataset, batch_size=2, drop_last=True, shuffle=True)  # TODO bigger batch size when not dummy data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It will also be useful to have a way to convert batches back into a readable form for debugging, so we add a helper function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'doc_id': tensor(2),\n",
       "  'context': 'was allowed a moment to ____ in at him where he',\n",
       "  'context_ids': tensor([ 99,   5,   0,  61,  93,  52,  11,  48, 103,  47]),\n",
       "  'samples': ['peer', 'moment', 'atreides', 'a', 'caladan', 'night'],\n",
       "  'sample_ids': tensor([71, 61, 12,  0, 20, 65])},\n",
       " {'doc_id': tensor(3),\n",
       "  'context': 'mother the old woman was ____ witch shadow hair like matted',\n",
       "  'context_ids': tensor([ 62,  91,  67, 105,  99, 104,  79,  44,  59,  60]),\n",
       "  'samples': ['a', 'where', 'in', 'cooled', 'an', 'woman'],\n",
       "  'sample_ids': tensor([  0, 103,  52,  24,   6,  -1])}]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def describe_batch(batch, vocab):\n",
    "    results = []\n",
    "    for doc_id, context_ids, sample_ids in zip(batch[\"doc_ids\"], batch[\"context_ids\"], batch[\"sample_ids\"]):\n",
    "        context = [vocab.words[i] for i in context_ids]\n",
    "        context.insert(len(context_ids) // 2, \"____\")\n",
    "        samples = [vocab.words[i] for i in sample_ids]\n",
    "        result = {\"doc_id\": doc_id,\n",
    "                  \"context\": \" \".join(context), \n",
    "                  \"context_ids\": context_ids, \n",
    "                  \"samples\": samples, \n",
    "                  \"sample_ids\": sample_ids}\n",
    "        results.append(result)\n",
    "    return results\n",
    "\n",
    "describe_batch(next(iter(dataloader)), vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's jump into creating the model itself. There isn't much to it - we look up the paragraph vector and the context word vectors, combine them, and multiply the result by the output layer. Here the paragraph and word vectors are combined by summing, but they could also be concatenated. The original paper actually found concatenation works better, perhaps because summing loses word order information."
   ]
  },
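  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In symbols, the forward pass in the next cell can be sketched as follows. The notation is illustrative rather than taken from the papers: $D$ stands for `paragraph_matrix`, $W$ for `word_matrix`, $O$ for `outputs`, $d_b$ for the document id of batch item $b$, $c_{b,i}$ for its context word ids and $s_{b,j}$ for its sample ids.\n",
    "\n",
    "$$h_b = D_{d_b} + \\sum_i W_{c_{b,i}}, \\qquad z_{b,j} = h_b \\cdot O_{:,\\,s_{b,j}}$$\n",
    "\n",
    "Each logit $z_{b,j}$ is the dot product of the combined input vector with the output column for one sample, with $j = 0$ being the positive sample."
   ]
  },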
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "class DistributedMemory(nn.Module):\n",
    "    def __init__(self, vec_dim, n_docs, n_words):\n",
    "        super(DistributedMemory, self).__init__()\n",
    "        self.paragraph_matrix = nn.Parameter(torch.randn(n_docs, vec_dim))\n",
    "        self.word_matrix = nn.Parameter(torch.randn(n_words, vec_dim))\n",
    "        self.outputs = nn.Parameter(torch.zeros(vec_dim, n_words))\n",
    "    \n",
    "    def forward(self, doc_ids, context_ids, sample_ids):\n",
    "                                                                               # first add each paragraph vector to its summed context word vectors to make the inputs\n",
    "        inputs = torch.add(self.paragraph_matrix[doc_ids,:],                   # (batch_size, vec_dim)\n",
    "                           torch.sum(self.word_matrix[context_ids,:], dim=1))  # (batch_size, 2x context, vec_dim) -> sum to (batch_size, vec_dim)\n",
    "                                                                               #\n",
    "                                                                               # select the subset of the output layer for the NCE test\n",
    "        outputs = self.outputs[:,sample_ids]                                   # (vec_dim, batch_size, n_negative_samples + 1)\n",
    "                                                                               #\n",
    "        return torch.bmm(inputs.unsqueeze(dim=1),                              # then multiply with some munging to make the tensor shapes line up \n",
    "                         outputs.permute(1, 0, 2)).squeeze(dim=1)              # -> (batch_size, n_negative_samples + 1); dim=1 keeps the batch dim even if batch_size is 1\n",
    "\n",
    "model = DistributedMemory(vec_dim=50,\n",
    "                          n_docs=len(example_df),\n",
    "                          n_words=len(vocab.words))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take it for a spin!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0., 0.]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    logits = model.forward(**next(iter(dataloader)))\n",
    "logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Oh yeah, the output layer was initialized with zeros, so every logit starts out at zero. Time to bash out a standard issue PyTorch training loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm, trange\n",
    "from torch.optim import Adam  # ilenic uses Adam, but gensim uses plain SGD\n",
    "import numpy as np\n",
    "\n",
    "def train(model, dataloader, epochs=40, lr=1e-3):\n",
    "    optimizer = Adam(model.parameters(), lr=lr)\n",
    "    training_losses = []\n",
    "    try:\n",
    "        for epoch in trange(epochs, desc=\"Epochs\"):\n",
    "            epoch_losses = []\n",
    "            for batch in dataloader:\n",
    "                model.zero_grad()\n",
    "                logits = model.forward(**batch)\n",
    "                batch_loss = loss(logits)\n",
    "                epoch_losses.append(batch_loss.item())\n",
    "                batch_loss.backward()\n",
    "                optimizer.step()\n",
    "            training_losses.append(np.mean(epoch_losses))\n",
    "    except KeyboardInterrupt:\n",
    "        print(f\"Interrupted on epoch {epoch}!\")\n",
    "    return training_losses  # a return inside `finally` would silently swallow any other exception"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we'll sanity check by overfitting the example data. Training loss should fall from its untrained value to something close to the minimum possible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs: 100%|██████████| 40/40 [00:02<00:00, 21.83it/s]\n"
     ]
    }
   ],
   "source": [
    "training_losses = train(model, dataloader, epochs=40, lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.vegalite.v2+json": {
       "$schema": "https://vega.github.io/schema/vega-lite/v2.6.0.json",
       "config": {
        "view": {
         "height": 300,
         "width": 400
        }
       },
       "data": {
        "name": "data-b4211a284b320247988d74ee339a35c7"
       },
       "datasets": {
        "data-b4211a284b320247988d74ee339a35c7": [
         {
          "epoch": 0,
          "training_loss": 3.873234655897496
         },
         {
          "epoch": 1,
          "training_loss": 2.4964269541077693
         },
         {
          "epoch": 2,
          "training_loss": 1.8195302142935283
         },
         {
          "epoch": 3,
          "training_loss": 1.41878628124625
         },
         {
          "epoch": 4,
          "training_loss": 1.1586913345223766
         },
         {
          "epoch": 5,
          "training_loss": 0.9543892801818201
         },
         {
          "epoch": 6,
          "training_loss": 0.8089911760920185
         },
         {
          "epoch": 7,
          "training_loss": 0.690002828836441
         },
         {
          "epoch": 8,
          "training_loss": 0.6072199561838376
         },
         {
          "epoch": 9,
          "training_loss": 0.5273456401744131
         },
         {
          "epoch": 10,
          "training_loss": 0.4652211883310544
         },
         {
          "epoch": 11,
          "training_loss": 0.4105674152152013
         },
         {
          "epoch": 12,
          "training_loss": 0.36498250900688817
         },
         {
          "epoch": 13,
          "training_loss": 0.3294624433679096
         },
         {
          "epoch": 14,
          "training_loss": 0.29689174738980956
         },
         {
          "epoch": 15,
          "training_loss": 0.2668858468532562
         },
         {
          "epoch": 16,
          "training_loss": 0.24272207208609176
         },
         {
          "epoch": 17,
          "training_loss": 0.21813622003389618
         },
         {
          "epoch": 18,
          "training_loss": 0.1999257556715254
         },
         {
          "epoch": 19,
          "training_loss": 0.18043559644434412
         },
         {
          "epoch": 20,
          "training_loss": 0.16454930458281
         },
         {
          "epoch": 21,
          "training_loss": 0.14988945260391398
         },
         {
          "epoch": 22,
          "training_loss": 0.1378944904496104
         },
         {
          "epoch": 23,
          "training_loss": 0.12689736037183616
         },
         {
          "epoch": 24,
          "training_loss": 0.11709884914048647
         },
         {
          "epoch": 25,
          "training_loss": 0.10717140561190702
         },
         {
          "epoch": 26,
          "training_loss": 0.09873407258320663
         },
         {
          "epoch": 27,
          "training_loss": 0.09167827893111666
         },
         {
          "epoch": 28,
          "training_loss": 0.08547388257111534
         },
         {
          "epoch": 29,
          "training_loss": 0.0792784820610689
         },
         {
          "epoch": 30,
          "training_loss": 0.07406510791536104
         },
         {
          "epoch": 31,
          "training_loss": 0.06805220107405872
         },
         {
          "epoch": 32,
          "training_loss": 0.064012377917514
         },
         {
          "epoch": 33,
          "training_loss": 0.06023278630385965
         },
         {
          "epoch": 34,
          "training_loss": 0.05617725763911918
         },
         {
          "epoch": 35,
          "training_loss": 0.05255285915681871
         },
         {
          "epoch": 36,
          "training_loss": 0.04945897159434981
         },
         {
          "epoch": 37,
          "training_loss": 0.047309281130842235
         },
         {
          "epoch": 38,
          "training_loss": 0.04321649504857043
         },
         {
          "epoch": 39,
          "training_loss": 0.040696044799761244
         }
        ]
       },
       "encoding": {
        "x": {
         "field": "epoch",
         "type": "quantitative"
        },
        "y": {
         "field": "training_loss",
         "scale": {
          "type": "log"
         },
         "type": "quantitative"
        }
       },
       "mark": "bar"
      },
      "image/png": "",
      "text/plain": [
       "<VegaLite 2 object>\n",
       "\n",
       "If you see this message, it means the renderer has not been properly enabled\n",
       "for the frontend that you are using. For more information, see\n",
       "https://altair-viz.github.io/user_guide/troubleshooting.html\n"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import altair as alt\n",
    "\n",
    "df_loss = pd.DataFrame(enumerate(training_losses), columns=[\"epoch\", \"training_loss\"])\n",
    "alt.Chart(df_loss).mark_bar().encode(alt.X(\"epoch\"), alt.Y(\"training_loss\", scale=alt.Scale(type=\"log\")))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And because we're paranoid types, let's check a prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 4.0063, -7.1900, -6.8099, -5.7869, -6.4321, -3.4320],\n",
       "        [ 4.9544, -6.3152, -7.6040, -6.6827, -7.8661, -5.0460]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    logits = model.forward(**next(iter(dataloader)))\n",
    "logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The positive sample gets a positive score and the negatives get negative scores. Super."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We should be able to get the paragraph vectors for the documents and do things like check these for similarity to one another."
   ]
  },
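  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The similarity used in the next cell is cosine similarity: the helper L2-normalises every row of the paragraph matrix and then takes dot products. For two paragraph vectors $a$ and $b$ (illustrative notation) that is:\n",
    "\n",
    "$$\\text{sim}(a, b) = \\frac{a \\cdot b}{\\lVert a \\rVert_2 \\, \\lVert b \\rVert_2}$$\n",
    "\n",
    "It ranges from -1 to 1, and a document always scores 1 against itself, which is why the query document comes first in the results."
   ]
  },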
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>doc_id</th>\n",
       "      <th>similarity</th>\n",
       "      <th>text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>It was a warm night at Castle Caladan, and the ancient pile of stone that had served the Atreide...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.177416</td>\n",
       "      <td>In the week before their departure to Arrakis, when all the final scurrying about had reached a ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.081760</td>\n",
       "      <td>By the half-light of a suspensor lamp, dimmed and hanging near the floor, the awakened boy could...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>-0.044768</td>\n",
       "      <td>The old woman was let in by the side door down the vaulted passage by Paul's room and she was al...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   doc_id  similarity  \\\n",
       "1       1    1.000000   \n",
       "0       0    0.177416   \n",
       "3       3    0.081760   \n",
       "2       2   -0.044768   \n",
       "\n",
       "                                                                                                  text  \n",
       "1  It was a warm night at Castle Caladan, and the ancient pile of stone that had served the Atreide...  \n",
       "0  In the week before their departure to Arrakis, when all the final scurrying about had reached a ...  \n",
       "3  By the half-light of a suspensor lamp, dimmed and hanging near the floor, the awakened boy could...  \n",
       "2  The old woman was let in by the side door down the vaulted passage by Paul's room and she was al...  "
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.preprocessing import normalize\n",
    "\n",
    "def most_similar(paragraph_matrix, docs_df, index, n=None):\n",
    "    pm = normalize(paragraph_matrix, norm=\"l2\")  # in a smarter implementation we would cache this somewhere\n",
    "    sims = np.dot(pm, pm[index,:])\n",
    "    df = pd.DataFrame(enumerate(sims), columns=[\"doc_id\", \"similarity\"])\n",
    "    n = n if n is not None else len(sims)\n",
    "    return df.merge(docs_df[[\"text\"]].reset_index(drop=True), left_index=True, right_index=True).sort_values(by=\"similarity\", ascending=False)[:n]\n",
    "\n",
    "most_similar(model.paragraph_matrix.data, example_df, 1, n=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's not particularly illuminating for our tiny set of dummy data though. We can also use PCA to reduce our n-dimensional paragraph vectors to 2 dimensions and see if they are clustered nicely."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2-component PCA, explains 48.18% of variance\n"
     ]
    },
    {
     "data": {
      "application/vnd.vegalite.v2+json": {
       "$schema": "https://vega.github.io/schema/vega-lite/v2.6.0.json",
       "config": {
        "view": {
         "height": 300,
         "width": 400
        }
       },
       "data": {
        "name": "data-0b1e0474027b1da095c3ed8d1f5f386e"
       },
       "datasets": {
        "data-0b1e0474027b1da095c3ed8d1f5f386e": [
         {
          "group": "0",
          "x": -4.767539597377173,
          "y": -4.756888669464686
         },
         {
          "group": "1",
          "x": -3.202457816084362,
          "y": 6.363203538611249
         },
         {
          "group": "2",
          "x": 6.650774519184007,
          "y": -0.034075790016020964
         },
         {
          "group": "3",
          "x": 1.3192228942775253,
          "y": -1.572239079130542
         }
        ]
       },
       "encoding": {
        "color": {
         "field": "group",
         "type": "nominal"
        },
        "x": {
         "field": "x",
         "type": "quantitative"
        },
        "y": {
         "field": "y",
         "type": "quantitative"
        }
       },
       "mark": "point"
      },
      "image/png": "",
      "text/plain": [
       "<VegaLite 2 object>\n",
       "\n",
       "If you see this message, it means the renderer has not been properly enabled\n",
       "for the frontend that you are using. For more information, see\n",
       "https://altair-viz.github.io/user_guide/troubleshooting.html\n"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "\n",
    "def pca_2d(paragraph_matrix, groups):\n",
    "    pca = PCA(n_components=2)\n",
    "    reduced_dims = pca.fit_transform(paragraph_matrix)\n",
    "    print(f\"2-component PCA, explains {pca.explained_variance_ratio_.sum():.2%} of variance\")\n",
    "    df = pd.DataFrame(reduced_dims, columns=[\"x\", \"y\"])\n",
    "    df[\"group\"] = groups\n",
    "    return df\n",
    "\n",
    "example_2d = pca_2d(model.paragraph_matrix.data, [\"0\",\"1\",\"2\",\"3\"])\n",
    "alt.Chart(example_2d).mark_point().encode(x=\"x\", y=\"y\", color=\"group\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not much to see on such a tiny dataset without any labelled groups."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Running this on some bigger data\n",
    "--------------------------------\n",
    "\n",
    "We'll use the BBC's dataset. The dataset was created by Derek Greene at UCD and all articles are copyright Auntie. I've munged it into a file per topic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>tokens</th>\n",
       "      <th>group</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Claxton hunting first major medal  British hurdler Sarah Claxton is confident she can win her fi...</td>\n",
       "      <td>[claxton, hunting, first, major, medal, british, hurdler, sarah, claxton, is, confident, she, ca...</td>\n",
       "      <td>sport</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>O'Sullivan could run in Worlds  Sonia O'Sullivan has indicated that she would like to participat...</td>\n",
       "      <td>[could, run, in, worlds, sonia, has, indicated, that, she, would, like, to, participate, in, nex...</td>\n",
       "      <td>sport</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Greene sets sights on world title  Maurice Greene aims to wipe out the pain of losing his Olympi...</td>\n",
       "      <td>[greene, sets, sights, on, world, title, maurice, greene, aims, to, wipe, out, the, pain, of, lo...</td>\n",
       "      <td>sport</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>IAAF launches fight against drugs  The IAAF - athletics' world governing body - has met anti-dop...</td>\n",
       "      <td>[iaaf, launches, fight, against, drugs, the, iaaf, athletics, world, governing, body, has, met, ...</td>\n",
       "      <td>sport</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                  text  \\\n",
       "0  Claxton hunting first major medal  British hurdler Sarah Claxton is confident she can win her fi...   \n",
       "1  O'Sullivan could run in Worlds  Sonia O'Sullivan has indicated that she would like to participat...   \n",
       "2  Greene sets sights on world title  Maurice Greene aims to wipe out the pain of losing his Olympi...   \n",
       "3  IAAF launches fight against drugs  The IAAF - athletics' world governing body - has met anti-dop...   \n",
       "\n",
       "                                                                                                tokens  \\\n",
       "0  [claxton, hunting, first, major, medal, british, hurdler, sarah, claxton, is, confident, she, ca...   \n",
       "1  [could, run, in, worlds, sonia, has, indicated, that, she, would, like, to, participate, in, nex...   \n",
       "2  [greene, sets, sights, on, world, title, maurice, greene, aims, to, wipe, out, the, pain, of, lo...   \n",
       "3  [iaaf, launches, fight, against, drugs, the, iaaf, athletics, world, governing, body, has, met, ...   \n",
       "\n",
       "   group  \n",
       "0  sport  \n",
       "1  sport  \n",
       "2  sport  \n",
       "3  sport  "
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dfs = []\n",
    "for document_set in (\"sport\",\n",
    "                     \"business\",\n",
    "                     \"politics\", \n",
    "                     \"tech\", \n",
    "                     \"entertainment\"):\n",
    "    df_ = pd.read_csv(f\"data/bbc/{document_set}.csv.bz2\", encoding=\"latin1\")\n",
    "    df_ = tokenize_text(df_)\n",
    "    df_[\"group\"] = document_set\n",
    "    dfs.append(df_)\n",
    "\n",
    "bbc_df = pd.concat(dfs, ignore_index=True)  # unique row index, since example_generator uses it as the doc id\n",
    "bbc_df[:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset comprises 2225 documents and 19063 unique words\n"
     ]
    }
   ],
   "source": [
    "bbc_vocab = Vocab([tok for tokens in bbc_df.tokens for tok in tokens])\n",
    "\n",
    "bbc_df = clean_tokens(bbc_df, bbc_vocab)\n",
    "\n",
    "print(f\"Dataset comprises {len(bbc_df)} documents and {len(bbc_vocab.words)} unique words\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "bbc_noise = NoiseDistribution(bbc_vocab)\n",
    "bbc_examples = list(example_generator(bbc_df, context_size=5, noise=bbc_noise, n_negative_samples=5, vocab=bbc_vocab))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "bbc_dataset = NCEDataset(bbc_examples)\n",
    "bbc_dataloader = DataLoader(bbc_dataset, batch_size=1024, drop_last=True, shuffle=True)  # TODO could tolerate a larger batch size\n",
    "\n",
    "bbc_model = DistributedMemory(vec_dim=50,\n",
    "                              n_docs=len(bbc_df),\n",
    "                              n_words=len(bbc_vocab.words))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs: 100%|██████████| 80/80 [36:14<00:00, 26.78s/it]\n"
     ]
    }
   ],
   "source": [
    "bbc_training_losses = train(bbc_model, bbc_dataloader, epochs=80, lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.vegalite.v2+json": {
       "$schema": "https://vega.github.io/schema/vega-lite/v2.6.0.json",
       "config": {
        "view": {
         "height": 300,
         "width": 400
        }
       },
       "data": {
        "name": "data-1dabe6fce07510d2b0d23042635cb568"
       },
       "datasets": {
        "data-1dabe6fce07510d2b0d23042635cb568": [
         {
          "epoch": 0,
          "training_loss": 2.642270294331616
         },
         {
          "epoch": 1,
          "training_loss": 2.1742215334258463
         },
         {
          "epoch": 2,
          "training_loss": 1.990165346749821
         },
         {
          "epoch": 3,
          "training_loss": 1.873253081452032
         },
         {
          "epoch": 4,
          "training_loss": 1.789763425596012
         },
         {
          "epoch": 5,
          "training_loss": 1.7244694537997987
         },
         {
          "epoch": 6,
          "training_loss": 1.670895414915144
         },
         {
          "epoch": 7,
          "training_loss": 1.6244636086943727
         },
         {
          "epoch": 8,
          "training_loss": 1.5836919797873645
         },
         {
          "epoch": 9,
          "training_loss": 1.5470218249729701
         },
         {
          "epoch": 10,
          "training_loss": 1.5139503951398483
         },
         {
          "epoch": 11,
          "training_loss": 1.4832230216968134
         },
         {
          "epoch": 12,
          "training_loss": 1.4547421681954993
         },
         {
          "epoch": 13,
          "training_loss": 1.4282586559745836
         },
         {
          "epoch": 14,
          "training_loss": 1.403701515079285
         },
         {
          "epoch": 15,
          "training_loss": 1.3798937556166087
         },
         {
          "epoch": 16,
          "training_loss": 1.3577345677784511
         },
         {
          "epoch": 17,
          "training_loss": 1.3366852766238384
         },
         {
          "epoch": 18,
          "training_loss": 1.3162777810363295
         },
         {
          "epoch": 19,
          "training_loss": 1.2970764010589315
         },
         {
          "epoch": 20,
          "training_loss": 1.2788146187800058
         },
         {
          "epoch": 21,
          "training_loss": 1.2611135737496133
         },
         {
          "epoch": 22,
          "training_loss": 1.2440896054972772
         },
         {
          "epoch": 23,
          "training_loss": 1.2277287255162779
         },
         {
          "epoch": 24,
          "training_loss": 1.211776340229911
         },
         {
          "epoch": 25,
          "training_loss": 1.1965603542624053
         },
         {
          "epoch": 26,
          "training_loss": 1.1817405841365365
         },
         {
          "epoch": 27,
          "training_loss": 1.1673207131972223
         },
         {
          "epoch": 28,
          "training_loss": 1.1533892758884785
         },
         {
          "epoch": 29,
          "training_loss": 1.1397850001080436
         },
         {
          "epoch": 30,
          "training_loss": 1.12681249217217
         },
         {
          "epoch": 31,
          "training_loss": 1.1138663835407043
         },
         {
          "epoch": 32,
          "training_loss": 1.1015954906896035
         },
         {
          "epoch": 33,
          "training_loss": 1.089359670544263
         },
         {
          "epoch": 34,
          "training_loss": 1.0776376641314962
         },
         {
          "epoch": 35,
          "training_loss": 1.0659556754627584
         },
         {
          "epoch": 36,
          "training_loss": 1.0547422724480955
         },
         {
          "epoch": 37,
          "training_loss": 1.0438089944561075
         },
         {
          "epoch": 38,
          "training_loss": 1.0329133888209088
         },
         {
          "epoch": 39,
          "training_loss": 1.0227227165832282
         },
         {
          "epoch": 40,
          "training_loss": 1.0123565251042383
         },
         {
          "epoch": 41,
          "training_loss": 1.0023637906364773
         },
         {
          "epoch": 42,
          "training_loss": 0.9926695086200785
         },
         {
          "epoch": 43,
          "training_loss": 0.982954163951163
         },
         {
          "epoch": 44,
          "training_loss": 0.9737152805239517
         },
         {
          "epoch": 45,
          "training_loss": 0.9645086004867317
         },
         {
          "epoch": 46,
          "training_loss": 0.9555849193786242
         },
         {
          "epoch": 47,
          "training_loss": 0.9466561066437952
         },
         {
          "epoch": 48,
          "training_loss": 0.9380779170101474
         },
         {
          "epoch": 49,
          "training_loss": 0.929785754813911
         },
         {
          "epoch": 50,
          "training_loss": 0.9214358320147354
         },
         {
          "epoch": 51,
          "training_loss": 0.9133923515770006
         },
         {
          "epoch": 52,
          "training_loss": 0.905320764328382
         },
         {
          "epoch": 53,
          "training_loss": 0.8976051105475574
         },
         {
          "epoch": 54,
          "training_loss": 0.8899229808623746
         },
         {
          "epoch": 55,
          "training_loss": 0.8823469078318673
         },
         {
          "epoch": 56,
          "training_loss": 0.8750101461173585
         },
         {
          "epoch": 57,
          "training_loss": 0.867823110456052
         },
         {
          "epoch": 58,
          "training_loss": 0.8607556018029681
         },
         {
          "epoch": 59,
          "training_loss": 0.8537966694891083
         },
         {
          "epoch": 60,
          "training_loss": 0.8470302788367182
         },
         {
          "epoch": 61,
          "training_loss": 0.8402785332306572
         },
         {
          "epoch": 62,
          "training_loss": 0.8336316384884142
         },
         {
          "epoch": 63,
          "training_loss": 0.8271638499283642
         },
         {
          "epoch": 64,
          "training_loss": 0.8209897966118332
         },
         {
          "epoch": 65,
          "training_loss": 0.8146745335981712
         },
         {
          "epoch": 66,
          "training_loss": 0.8085972012940401
         },
         {
          "epoch": 67,
          "training_loss": 0.8024606412982348
         },
         {
          "epoch": 68,
          "training_loss": 0.7965273064856203
         },
         {
          "epoch": 69,
          "training_loss": 0.7905283134916554
         },
         {
          "epoch": 70,
          "training_loss": 0.785075603212629
         },
         {
          "epoch": 71,
          "training_loss": 0.7794511792822654
         },
         {
          "epoch": 72,
          "training_loss": 0.7739284674573389
         },
         {
          "epoch": 73,
          "training_loss": 0.7684736896745907
         },
         {
          "epoch": 74,
          "training_loss": 0.7630798336141598
         },
         {
          "epoch": 75,
          "training_loss": 0.7578280769519924
         },
         {
          "epoch": 76,
          "training_loss": 0.7526835779966035
         },
         {
          "epoch": 77,
          "training_loss": 0.7476536778189381
         },
         {
          "epoch": 78,
          "training_loss": 0.7426093139263413
         },
         {
          "epoch": 79,
          "training_loss": 0.7375516610856382
         }
        ]
       },
       "encoding": {
        "x": {
         "field": "epoch",
         "type": "quantitative"
        },
        "y": {
         "field": "training_loss",
         "type": "quantitative"
        }
       },
       "mark": "bar"
      },
      "image/png": "",
      "text/plain": [
       "<VegaLite 2 object>\n",
       "\n",
       "If you see this message, it means the renderer has not been properly enabled\n",
       "for the frontend that you are using. For more information, see\n",
       "https://altair-viz.github.io/user_guide/troubleshooting.html\n"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "alt.Chart(pd.DataFrame(enumerate(bbc_training_losses), columns=[\"epoch\", \"training_loss\"])).mark_bar().encode(x=\"epoch\", y=\"training_loss\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look at the paragraph vectors after reducing them to two dimensions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bbc_2d = pca_2d(bbc_model.paragraph_matrix.data, bbc_df.group.to_numpy())\n",
    "chart = alt.Chart(bbc_2d).mark_point().encode(x=\"x\", y=\"y\", color=\"group\")\n",
    "# Uncomment to print chart inline, but beware it will inflate the notebook size\n",
    "# chart"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`2-component PCA, explains 2.65% of variance`\n",
    "\n",
    "![](./img/bbc_pca_all_topics.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These results aren't great (the two components explain only 2.65% of the variance), but we can see the beginnings of separation. If we look at just two topics, it becomes more obvious."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chart = alt.Chart(bbc_2d[bbc_2d[\"group\"].isin([\"sport\", \"business\"])]).mark_point().encode(x=\"x\", y=\"y\", color=\"group\")\n",
    "# Uncomment to print chart inline, but beware it will inflate the notebook size\n",
    "# chart"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./img/bbc_pca_business_sport.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
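    "Likewise, sorting by similarity to the first document (an athletics story) produces reasonable, but not ideal, results: most of its nearest neighbours are sports articles, but stories about desktop search and student inequality rank third and fourth."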
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>doc_id</th>\n",
       "      <th>similarity</th>\n",
       "      <th>text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>Claxton hunting first major medal  British hurdler Sarah Claxton is confident she can win her fi...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>37</td>\n",
       "      <td>0.504319</td>\n",
       "      <td>Radcliffe proves doubters wrong  This won't go down as one of the greatest marathons of Paula's ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41</th>\n",
       "      <td>41</td>\n",
       "      <td>0.499603</td>\n",
       "      <td>Radcliffe enjoys winning comeback  Paula Radcliffe made a triumphant return to competitive runni...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1545</th>\n",
       "      <td>1545</td>\n",
       "      <td>0.499484</td>\n",
       "      <td>Search wars hit desktop PCs  Another front in the on-going battle between Microsoft and Google i...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1266</th>\n",
       "      <td>1266</td>\n",
       "      <td>0.490500</td>\n",
       "      <td>Student 'inequality' exposed  Teenagers from well-off backgrounds are six times more likely to g...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>19</td>\n",
       "      <td>0.442955</td>\n",
       "      <td>Edwards tips Idowu for Euro gold  World outdoor triple jump record holder and BBC pundit Jonatha...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>348</th>\n",
       "      <td>348</td>\n",
       "      <td>0.430447</td>\n",
       "      <td>Italy aim to rattle England  Italy coach John Kirwan believes his side can upset England as the ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>251</th>\n",
       "      <td>251</td>\n",
       "      <td>0.429918</td>\n",
       "      <td>Ferguson rues failure to cut gap  Boss Sir Alex Ferguson was left ruing Manchester United's fail...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>24</td>\n",
       "      <td>0.429485</td>\n",
       "      <td>El Guerrouj targets cross country  Double Olympic champion Hicham El Guerrouj is set to make a r...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>464</th>\n",
       "      <td>464</td>\n",
       "      <td>0.412518</td>\n",
       "      <td>Henin-Hardenne beaten on comeback  Justine Henin-Hardenne lost to Elena Dementieva in a comeback...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      doc_id  similarity  \\\n",
       "0          0    1.000000   \n",
       "37        37    0.504319   \n",
       "41        41    0.499603   \n",
       "1545    1545    0.499484   \n",
       "1266    1266    0.490500   \n",
       "19        19    0.442955   \n",
       "348      348    0.430447   \n",
       "251      251    0.429918   \n",
       "24        24    0.429485   \n",
       "464      464    0.412518   \n",
       "\n",
       "                                                                                                     text  \n",
       "0     Claxton hunting first major medal  British hurdler Sarah Claxton is confident she can win her fi...  \n",
       "37    Radcliffe proves doubters wrong  This won't go down as one of the greatest marathons of Paula's ...  \n",
       "41    Radcliffe enjoys winning comeback  Paula Radcliffe made a triumphant return to competitive runni...  \n",
       "1545  Search wars hit desktop PCs  Another front in the on-going battle between Microsoft and Google i...  \n",
       "1266  Student 'inequality' exposed  Teenagers from well-off backgrounds are six times more likely to g...  \n",
       "19    Edwards tips Idowu for Euro gold  World outdoor triple jump record holder and BBC pundit Jonatha...  \n",
       "348   Italy aim to rattle England  Italy coach John Kirwan believes his side can upset England as the ...  \n",
       "251   Ferguson rues failure to cut gap  Boss Sir Alex Ferguson was left ruing Manchester United's fail...  \n",
       "24    El Guerrouj targets cross country  Double Olympic champion Hicham El Guerrouj is set to make a r...  \n",
       "464   Henin-Hardenne beaten on comeback  Justine Henin-Hardenne lost to Elena Dementieva in a comeback...  "
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "most_similar(bbc_model.paragraph_matrix.data, bbc_df, 0, n=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next steps\n",
    "----------\n",
    "\n",
    "That's all for now! I honestly hope that was fun and educational (it was for me, anyway).\n",
    "\n",
    "But data science projects are notorious for never being finished. To carry this on, we could:\n",
    "\n",
    "- look for better hyperparameters, since the training loss remains quite high (about 0.74 after 80 epochs, and still falling)\n",
    "- benchmark against `gensim` and Ilenic's PyTorch implementation; this implementation should behave very similarly to the latter\n",
    "- implement the inference step for new documents, which freezes the word and output matrices and learns a new vector in the paragraph matrix for each unseen document\n",
    "- use inferred paragraph vectors as the input for a topic classifier; looking at the business/sport plot above it could be quite successful\n",
    "- try visualization with a better dimensionality reduction algorithm than PCA (I've used [LargeVis](https://arxiv.org/abs/1602.00370) in the past)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}