{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Gensim `Doc2Vec` Tutorial on the IMDB Sentiment Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction\n", "\n", "In this tutorial, we will learn how to apply Doc2Vec using gensim by recreating the results of Le and Mikolov (2014). \n", "\n", "### Bag-of-words Model\n", "Early state-of-the-art document representations were based on the bag-of-words model, which represents each input document as a fixed-length vector. For example, borrowing from the Wikipedia article, the two documents \n", "(1) `John likes to watch movies. Mary likes movies too.` \n", "(2) `John also likes to watch football games.` \n", "are used to construct a length-10 list of words \n", "`[\"John\", \"likes\", \"to\", \"watch\", \"movies\", \"Mary\", \"too\", \"also\", \"football\", \"games\"]` \n", "so we can represent the two documents as fixed-length vectors whose elements are the frequencies of the corresponding words in our list \n", "(1) `[1, 2, 1, 1, 2, 1, 1, 0, 0, 0]` \n", "(2) `[1, 1, 1, 1, 0, 0, 0, 1, 1, 1]` \n", "Bag-of-words models are surprisingly effective, but they lose all information about word order. Bag-of-n-grams models, which treat word phrases of length n as features, capture local word order but suffer from data sparsity and high dimensionality.\n", "\n", "### `Word2Vec`\n", "`Word2Vec` is a more recent model that embeds words in a lower-dimensional vector space using a shallow neural network. The result is a set of word-vectors in which vectors close together in vector space have similar meanings based on context, and word-vectors distant from each other have differing meanings. For example, `strong` and `powerful` would be close together, while `strong` and `Paris` would be relatively far apart. 
There are two versions of this model, based on skip-grams (SG) and continuous-bag-of-words (CBOW), both implemented by the gensim `Word2Vec` class.\n", "\n", "\n", "#### `Word2Vec` - Skip-gram Model\n", "The skip-gram word2vec model, for example, takes in pairs (word1, word2) generated by moving a window across text data, and trains a 1-hidden-layer neural network on the synthetic task of predicting, from a single input word, a probability distribution over the nearby words. A virtual one-hot encoding of words goes through a 'projection layer' to the hidden layer; these projection weights are later interpreted as the word embeddings. So if the hidden layer has 300 neurons, this network will give us 300-dimensional word embeddings.\n", "\n", "#### `Word2Vec` - Continuous-bag-of-words Model\n", "Continuous-bag-of-words Word2vec is very similar to the skip-gram model. It is also a 1-hidden-layer neural network. The synthetic training task now uses the average of multiple input context words, rather than a single word as in skip-gram, to predict the center word. Again, the projection weights that turn one-hot words into averageable vectors, of the same width as the hidden layer, are interpreted as the word embeddings. \n", "\n", "But Word2Vec alone doesn't give us fixed-size vectors for longer texts.\n", "\n", "\n", "### Paragraph Vector, aka gensim `Doc2Vec`\n", "The straightforward approach of averaging each of a text's words' word-vectors creates a quick and crude document-vector that can often be useful. However, Le and Mikolov in 2014 introduced the Paragraph Vector, which usually outperforms such simple averaging.\n", "\n", "The basic idea is to act as if each document has an extra floating word-like vector, a doc-vector, which contributes to all training predictions for that document and is updated like the other word-vectors. Gensim's `Doc2Vec` class implements this algorithm. 
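As a rough illustration of the two approaches above, here is a minimal pure-Python sketch. It is not gensim's implementation: the 3-dimensional vectors are made-up toy values rather than learned ones, and the helper names `average_word_vectors` and `pv_dm_input` are hypothetical. The first helper builds the quick-and-crude averaged document-vector; the second builds a PV-DM-style input, in which the document's own doc-vector is averaged together with the context word-vectors, so the doc-vector takes part in every prediction made for that document.

```python
# Hypothetical toy sketch: these 3-d vectors are made up, not learned values,
# and neither helper is part of gensim's API.
WORD_VECS = {
    'great': [0.9, 0.1, 0.0],
    'movie': [0.2, 0.8, 0.1],
    'boring': [-0.8, 0.2, 0.1],
}


def mean_of(vectors):
    # Element-wise mean of a non-empty list of equal-length vectors.
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def average_word_vectors(words, word_vecs):
    # Crude document-vector: average the vectors of known words, skipping unknown ones.
    known = [word_vecs[w] for w in words if w in word_vecs]
    if not known:
        raise ValueError('no known words in document')
    return mean_of(known)


def pv_dm_input(doc_vec, context_words, word_vecs):
    # PV-DM-style input: the doc-vector is averaged in with the context word-vectors.
    return mean_of([doc_vec] + [word_vecs[w] for w in context_words])


print(average_word_vectors(['great', 'movie', 'unseen'], WORD_VECS))
print(pv_dm_input([0.0, 0.0, 0.3], ['great', 'boring'], WORD_VECS))
```

In actual training, that combined vector is used to predict the center word, and the resulting error gradient updates the doc-vector just like the word-vectors; the sketch shows only the forward combination step.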
\n", "\n", "#### Paragraph Vector - Distributed Memory (PV-DM)\n", "This is the Paragraph Vector model analogous to Word2Vec CBOW. The doc-vectors are obtained by training a neural network on the synthetic task of predicting a center word based on an average of both the context word-vectors and the full document's doc-vector.\n", "\n", "#### Paragraph Vector - Distributed Bag of Words (PV-DBOW)\n", "This is the Paragraph Vector model analogous to Word2Vec SG. The doc-vectors are obtained by training a neural network on the synthetic task of predicting a target word just from the full document's doc-vector. (It is also common to combine this with skip-gram training, in which both the doc-vector and nearby word-vectors are used to predict a single target word, but only one at a time.)\n", "\n", "### Requirements\n", "The following Python modules are dependencies for this tutorial:\n", "* testfixtures ( `pip install testfixtures` )\n", "* statsmodels ( `pip install statsmodels` )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load corpus" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's download the IMDB archive if it is not already downloaded (84 MB). This will be our text data for this tutorial. \n", "The data can be found here: http://ai.stanford.edu/~amaas/data/sentiment/\n", "\n", "This cell will only reattempt steps (such as downloading the compressed data) if their output isn't already present, so it is safe to re-run until it completes successfully. 
" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "IMDB archive directory already available without download.\n", "Cleaning up dataset...\n", " train/pos: 12500 files\n", " train/neg: 12500 files\n", " test/pos: 12500 files\n", " test/neg: 12500 files\n", " train/unsup: 50000 files\n", "Success, alldata-id.txt is available for next steps.\n", "CPU times: user 17.3 s, sys: 14.1 s, total: 31.3 s\n", "Wall time: 1min 2s\n" ] } ], "source": [ "%%time \n", "\n", "import locale\n", "import glob\n", "import os.path\n", "import requests\n", "import tarfile\n", "import sys\n", "import codecs\n", "from smart_open import smart_open\n", "import re\n", "\n", "dirname = 'aclImdb'\n", "filename = 'aclImdb_v1.tar.gz'\n", "locale.setlocale(locale.LC_ALL, 'C')\n", "all_lines = []\n", "\n", "if sys.version > '3':\n", "    control_chars = [chr(0x85)]\n", "else:\n", "    control_chars = [unichr(0x85)]\n", "\n", "# Convert text to lower-case and pad punctuation/symbols with spaces, so they become separate tokens\n", "def normalize_text(text):\n", "    norm_text = text.lower()\n", "    # Replace HTML line breaks with spaces\n", "    norm_text = norm_text.replace('<br />', ' ')\n", "    # Pad punctuation with spaces on both sides\n", "    norm_text = re.sub(r\"([\\.\\\",\\(\\)!\\?;:])\", \" \\\\1 \", norm_text)\n", "    return norm_text\n", "\n", "if not os.path.isfile('aclImdb/alldata-id.txt'):\n", "    if not os.path.isdir(dirname):\n", "        if not os.path.isfile(filename):\n", "            # Download IMDB archive\n", "            print(\"Downloading IMDB archive...\")\n", "            url = u'http://ai.stanford.edu/~amaas/data/sentiment/' + filename\n", "            r = requests.get(url)\n", "            with smart_open(filename, 'wb') as f:\n", "                f.write(r.content)\n", "        # if error here, try `tar xfz aclImdb_v1.tar.gz` outside notebook, then re-run this cell\n", "        tar = tarfile.open(filename, mode='r')\n", "        tar.extractall()\n", "        tar.close()\n", "    else:\n", "        print(\"IMDB archive directory already available without download.\")\n", "\n", "    # Collect & normalize test/train data\n", "    print(\"Cleaning up dataset...\")\n", "    folders = ['train/pos', 'train/neg', 'test/pos', 'test/neg', 'train/unsup']\n", "    for fol in folders:\n", "        temp = u''\n", "        newline = \"\\n\".encode(\"utf-8\")\n", "        output = fol.replace('/', '-') + '.txt'\n", "        # Is there a better pattern to use?\n", "        txt_files = glob.glob(os.path.join(dirname, fol, '*.txt'))\n", "        print(\" %s: %i files\" % (fol, len(txt_files)))\n", "        with smart_open(os.path.join(dirname, output), \"wb\") as n:\n", "            for i, txt in enumerate(txt_files):\n", "                with smart_open(txt, \"rb\") as t:\n", "                    one_text = t.read().decode(\"utf-8\")\n", "                    for c in control_chars:\n", "                        one_text = one_text.replace(c, ' ')\n", "                    one_text = normalize_text(one_text)\n", "                    all_lines.append(one_text)\n", "                    n.write(one_text.encode(\"utf-8\"))\n", "                    n.write(newline)\n", "\n", "    # Save to disk for instant re-use on any future runs\n", "    with smart_open(os.path.join(dirname, 'alldata-id.txt'), 'wb') as f:\n", "        for idx, line in enumerate(all_lines):\n", "            num_line = u\"_*{0} {1}\\n\".format(idx, line)\n", "            f.write(num_line.encode(\"utf-8\"))\n", "\n", "assert 
os.path.isfile(\"aclImdb/alldata-id.txt\"), \"alldata-id.txt unavailable\"\n", "print(\"Success, alldata-id.txt is available for next steps.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The text data is small enough to be read into memory. " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100000 docs: 25000 train-sentiment, 25000 test-sentiment\n", "CPU times: user 5.3 s, sys: 1.25 s, total: 6.55 s\n", "Wall time: 6.74 s\n" ] } ], "source": [ "%%time\n", "\n", "import gensim\n", "from gensim.models.doc2vec import TaggedDocument\n", "from collections import namedtuple\n", "\n", "# this data object class suffices as a `TaggedDocument` (with `words` and `tags`) \n", "# plus adds other state helpful for our later evaluation/reporting\n", "SentimentDocument = namedtuple('SentimentDocument', 'words tags split sentiment')\n", "\n", "alldocs = []\n", "with smart_open('aclImdb/alldata-id.txt', 'rb', encoding='utf-8') as alldata:\n", "    for line_no, line in enumerate(alldata):\n", "        tokens = gensim.utils.to_unicode(line).split()\n", "        words = tokens[1:]\n", "        tags = [line_no]  # 'tags = [tokens[0]]' would also work at extra memory cost\n", "        split = ['train', 'test', 'extra', 'extra'][line_no//25000]  # 25k train, 25k test, 50k extra\n", "        sentiment = [1.0, 0.0, 1.0, 0.0, None, None, None, None][line_no//12500]  # [12.5K pos, 12.5K neg]*2 then unknown\n", "        alldocs.append(SentimentDocument(words, tags, split, sentiment))\n", "\n", "train_docs = [doc for doc in alldocs if doc.split == 'train']\n", "test_docs = [doc for doc in alldocs if doc.split == 'test']\n", "\n", "print('%d docs: %d train-sentiment, %d test-sentiment' % (len(alldocs), len(train_docs), len(test_docs)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Because the native document order has similar-sentiment documents in large clumps – which is suboptimal for training – we work with a once-shuffled copy of the full dataset (all 100,000 docs)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from random import shuffle\n", "doc_list = alldocs[:] \n", "shuffle(doc_list)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set-up Doc2Vec Training & Evaluation Models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We approximate the experiment of Le & Mikolov [\"Distributed Representations of Sentences and Documents\"](http://cs.stanford.edu/~quocle/paragraph_vector.pdf) with guidance from Mikolov's [example go.sh](https://groups.google.com/d/msg/word2vec-toolkit/Q49FIrNOQRo/J6KG8mUj45sJ):\n", "\n", "`./word2vec -train ../alldata-id.txt -output vectors.txt -cbow 0 -size 100 -window 10 -negative 5 -hs 0 -sample 1e-4 -threads 40 -binary 0 -iter 20 -min-count 1 -sentence-vectors 1`\n", "\n", "We vary the following parameter choices:\n", "* 100-dimensional vectors, as the 400-d vectors of the paper take a lot of memory and, in our tests of this task, don't seem to offer much benefit\n", "* Similarly, frequent-word subsampling seems to decrease sentiment-prediction accuracy, so it's left out\n", "* `cbow=0` means skip-gram, which is equivalent to the paper's 'PV-DBOW' mode, matched in gensim with `dm=0`\n", "* Added to that DBOW model are two DM models, one which averages context vectors (`dm_mean`) and one which concatenates them (`dm_concat`, resulting in a much larger, slower, more data-hungry model)\n", "* A `min_count=2` saves quite a bit of model memory, discarding only words that appear just once in the whole corpus (such a word occurs in a single doc, so it is no more expressive than that doc's own unique vector)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Doc2Vec(dbow,d100,n5,mc2,t4) vocabulary scanned & state initialized\n", "Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4) vocabulary scanned & state initialized\n", 
"Doc2Vec(dm/c,d100,n5,w5,mc2,t4) vocabulary scanned & state initialized\n", "CPU times: user 28.7 s, sys: 414 ms, total: 29.1 s\n", "Wall time: 29.1 s\n" ] } ], "source": [ "%%time\n", "from gensim.models import Doc2Vec\n", "import gensim.models.doc2vec\n", "from collections import OrderedDict\n", "import multiprocessing\n", "\n", "cores = multiprocessing.cpu_count()\n", "assert gensim.models.doc2vec.FAST_VERSION > -1, \"This will be painfully slow otherwise\"\n", "\n", "simple_models = [\n", "    # PV-DBOW plain\n", "    Doc2Vec(dm=0, vector_size=100, negative=5, hs=0, min_count=2, sample=0,\n", "            epochs=20, workers=cores),\n", "    # PV-DM w/ default averaging; a higher starting alpha may improve CBOW/PV-DM modes\n", "    Doc2Vec(dm=1, vector_size=100, window=10, negative=5, hs=0, min_count=2, sample=0,\n", "            epochs=20, workers=cores, alpha=0.05, comment='alpha=0.05'),\n", "    # PV-DM w/ concatenation - big, slow, experimental mode\n", "    # window=5 (both sides) approximates paper's apparent 10-word total window size\n", "    Doc2Vec(dm=1, dm_concat=1, vector_size=100, window=5, negative=5, hs=0, min_count=2, sample=0,\n", "            epochs=20, workers=cores),\n", "]\n", "\n", "for model in simple_models:\n", "    model.build_vocab(alldocs)\n", "    print(\"%s vocabulary scanned & state initialized\" % model)\n", "\n", "models_by_name = OrderedDict((str(model), model) for model in simple_models)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Le and Mikolov note that combining a paragraph vector from Distributed Bag of Words (DBOW) and Distributed Memory (DM) improves performance. We follow suit, pairing the models together for evaluation. Here, we concatenate the paragraph vectors obtained from each model with the help of a thin wrapper class included in a gensim test module. 
(Note that this is a separate, later concatenation of output vectors, unlike the input-window concatenation enabled by the `dm_concat=1` mode above.)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from gensim.test.test_doc2vec import ConcatenatedDoc2Vec\n", "models_by_name['dbow+dmm'] = ConcatenatedDoc2Vec([simple_models[0], simple_models[1]])\n", "models_by_name['dbow+dmc'] = ConcatenatedDoc2Vec([simple_models[0], simple_models[2]])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Predictive Evaluation Methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's define some helper methods for evaluating the performance of our Doc2Vec models. We will classify document sentiments using a logistic regression model trained on the paragraph (doc-vector) embeddings, then compare the error rates achieved with the doc-vectors from each of our Doc2Vec models." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import statsmodels.api as sm\n", "from random import sample\n", " \n", "def logistic_predictor_from_data(train_targets, train_regressors):\n", "    \"\"\"Fit a statsmodels logistic predictor on supplied data\"\"\"\n", "    logit = sm.Logit(train_targets, train_regressors)\n", "    predictor = logit.fit(disp=0)\n", "    # print(predictor.summary())\n", "    return predictor\n", "\n", "def error_rate_for_model(test_model, train_set, test_set, \n", "                         reinfer_train=False, reinfer_test=False, \n", "                         infer_steps=None, infer_alpha=None, infer_subsample=0.2):\n", "    \"\"\"Report error rate on test_set sentiments, using supplied model and train_set\"\"\"\n", "\n", "    train_targets = [doc.sentiment for doc in train_set]\n", "    if reinfer_train:\n", "        train_regressors = [test_model.infer_vector(doc.words, steps=infer_steps, alpha=infer_alpha) for doc in train_set]\n", "    else:\n", "        train_regressors = [test_model.docvecs[doc.tags[0]] for doc in train_set]\n", "    train_regressors = sm.add_constant(train_regressors)\n", "    predictor = logistic_predictor_from_data(train_targets, train_regressors)\n", "\n", "    test_data = test_set\n", "    if reinfer_test:\n", "        if infer_subsample < 1.0:\n", "            test_data = sample(test_data, int(infer_subsample * len(test_data)))\n", "        test_regressors = [test_model.infer_vector(doc.words, steps=infer_steps, alpha=infer_alpha) for doc in test_data]\n", "    else:\n", "        test_regressors = [test_model.docvecs[doc.tags[0]] for doc in test_data]\n", "    test_regressors = sm.add_constant(test_regressors)\n", "    \n", "    # Predict & evaluate\n", "    test_predictions = predictor.predict(test_regressors)\n", "    corrects = sum(np.rint(test_predictions) == [doc.sentiment for doc in test_data])\n", "    errors = len(test_predictions) - corrects\n", "    error_rate = float(errors) / len(test_predictions)\n", "    return (error_rate, errors, len(test_predictions), predictor)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Bulk Training & Per-Model Evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that doc-vector training is occurring on *all* documents of the dataset, which includes all TRAIN, TEST, and unlabeled EXTRA docs.\n", "\n", "We evaluate each model's sentiment-predictive power by its error rate on the test set. 
\n", "\n", "(On a 4-core 2.6GHz Intel Core i7, these 20 passes of training and evaluating the 3 main models take about an hour.)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from collections import defaultdict\n", "error_rates = defaultdict(lambda: 1.0)  # To selectively print only best errors achieved" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training Doc2Vec(dbow,d100,n5,mc2,t4)\n", "CPU times: user 18min 41s, sys: 59.7 s, total: 19min 41s\n", "Wall time: 6min 49s\n", "\n", "Evaluating Doc2Vec(dbow,d100,n5,mc2,t4)\n", "CPU times: user 1.85 s, sys: 226 ms, total: 2.07 s\n", "Wall time: 673 ms\n", "\n", "0.102600 Doc2Vec(dbow,d100,n5,mc2,t4)\n", "\n", "Training Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4)\n", "CPU times: user 28min 21s, sys: 1min 30s, total: 29min 52s\n", "Wall time: 9min 22s\n", "\n", "Evaluating Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4)\n", "CPU times: user 1.71 s, sys: 175 ms, total: 1.88 s\n", "Wall time: 605 ms\n", "\n", "0.154280 Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4)\n", "\n", "Training Doc2Vec(dm/c,d100,n5,w5,mc2,t4)\n", "CPU times: user 55min 8s, sys: 36.5 s, total: 55min 44s\n", "Wall time: 14min 43s\n", "\n", "Evaluating Doc2Vec(dm/c,d100,n5,w5,mc2,t4)\n", "CPU times: user 1.47 s, sys: 110 ms, total: 1.58 s\n", "Wall time: 533 ms\n", "\n", "0.225760 Doc2Vec(dm/c,d100,n5,w5,mc2,t4)\n", "\n" ] } ], "source": [ "for model in simple_models: \n", "    print(\"Training %s\" % model)\n", "    %time model.train(doc_list, total_examples=len(doc_list), epochs=model.epochs)\n", "    \n", "    print(\"\\nEvaluating %s\" % model)\n", "    %time err_rate, err_count, test_count, predictor = error_rate_for_model(model, train_docs, test_docs)\n", "    error_rates[str(model)] = err_rate\n", "    print(\"\\n%f %s\\n\" % (err_rate, model))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "\n", "Evaluating Doc2Vec(dbow,d100,n5,mc2,t4)+Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4)\n", "CPU times: user 4.13 s, sys: 459 ms, total: 4.59 s\n", "Wall time: 1.72 s\n", "\n", "0.103360 Doc2Vec(dbow,d100,n5,mc2,t4)+Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4)\n", "\n", "\n", "Evaluating Doc2Vec(dbow,d100,n5,mc2,t4)+Doc2Vec(dm/c,d100,n5,w5,mc2,t4)\n", "CPU times: user 4.03 s, sys: 351 ms, total: 4.38 s\n", "Wall time: 1.38 s\n", "\n", "0.105080 Doc2Vec(dbow,d100,n5,mc2,t4)+Doc2Vec(dm/c,d100,n5,w5,mc2,t4)\n", "\n" ] } ], "source": [ "for model in [models_by_name['dbow+dmm'], models_by_name['dbow+dmc']]: \n", "    print(\"\\nEvaluating %s\" % model)\n", "    %time err_rate, err_count, test_count, predictor = error_rate_for_model(model, train_docs, test_docs)\n", "    error_rates[str(model)] = err_rate\n", "    print(\"\\n%f %s\\n\" % (err_rate, model))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Achieved Sentiment-Prediction Accuracy" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Err_rate Model\n", "0.102600 Doc2Vec(dbow,d100,n5,mc2,t4)\n", "0.103360 Doc2Vec(dbow,d100,n5,mc2,t4)+Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4)\n", "0.105080 Doc2Vec(dbow,d100,n5,mc2,t4)+Doc2Vec(dm/c,d100,n5,w5,mc2,t4)\n", "0.154280 Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4)\n", "0.225760 Doc2Vec(dm/c,d100,n5,w5,mc2,t4)\n" ] } ], "source": [ "# Compare error rates achieved, best-to-worst\n", "print(\"Err_rate Model\")\n", "for rate, name in sorted((rate, name) for name, rate in error_rates.items()):\n", "    print(\"%f %s\" % (rate, name))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In our testing, contrary to the paper's results, PV-DBOW alone performs as well as anything else on this problem. 
Concatenating vectors from different models only sometimes offers a tiny predictive improvement – and stays generally close to the best-performing solo model included. \n", "\n", "The best results achieved here are just around 10% error rate, still a long way from the paper's reported 7.42% error rate. \n", "\n", "(Other trials not shown, with larger vectors and other changes, also don't come close to the paper's reported value. Others around the net have reported a similar inability to reproduce the paper's best numbers. The PV-DM/C mode improves a bit with many more training epochs – but doesn't reach parity with PV-DBOW.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Examining Results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Are inferred vectors close to the precalculated ones?" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "for doc 66229...\n", "Doc2Vec(dbow,d100,n5,mc2,t4):\n", " [(66229, 0.9756568670272827), (66223, 0.5901858806610107), (81851, 0.5678753852844238)]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/neuscratch/Dev/gensim/gensim/matutils.py:737: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. In future, it will be treated as `np.int64 == np.dtype(int).type`.\n", " if np.issubdtype(vec.dtype, np.int):\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4):\n", " [(66229, 0.9355567097663879), (71883, 0.49743932485580444), (74232, 0.49549904465675354)]\n", "Doc2Vec(dm/c,d100,n5,w5,mc2,t4):\n", " [(66229, 0.9248996376991272), (97306, 0.4372865557670593), (99824, 0.40370166301727295)]\n" ] } ], "source": [ "doc_id = np.random.randint(simple_models[0].docvecs.count) # Pick random doc; re-run cell for more examples\n", "print('for doc %d...' 
% doc_id)\n", "for model in simple_models:\n", "    inferred_docvec = model.infer_vector(alldocs[doc_id].words)\n", "    print('%s:\\n %s' % (model, model.docvecs.most_similar([inferred_docvec], topn=3)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(Yes, here the stored vector from 20 epochs of training is usually one of the closest to a freshly-inferred vector for the same words. The inference defaults may benefit from tuning for each dataset or set of model parameters.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Do close documents seem more related than distant ones?" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "scrolled": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/neuscratch/Dev/gensim/gensim/matutils.py:737: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. In future, it will be treated as `np.int64 == np.dtype(int).type`.\n", " if np.issubdtype(vec.dtype, np.int):\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "TARGET (34105): «even a decade after \" frontline \" aired on the abc , near as i can tell , \" current affairs \" programmes are still using the same tricks over and over . time after time , \" today tonight \" and \" a current affair \" are seen to be hiding behind the facade of journalistic professionalism , and yet they feed us nothing but tired stories about weight-loss and dodgy tradesmen , shameless network promotions and pointless celebrity puff-pieces . having often been subjected to that entertainment-less void between 'the simpsons' at 6 : 00 pm and 'sale of the century' ( or 'temptation' ) at 7 : 00 pm , i was all too aware of the little tricks that these shows would use to attract ratings . 
fortunately , four rising comedians – rob sitch , jane kennedy , santo cilauro and tom gleisner – were also all too aware of all this , and they crafted their frustrations into one of the most wickedly-hilarious media satires you'll ever see on television . the four entertainers had already met with comedic success , their previous most memorable television stint being on 'the late show , ' the brilliant saturday night variety show which ran for two seasons from 1992-1993 , and also featured fellow comedians mick molloy , tony martin , jason stephens and judith lucy . \" frontline \" boasts an ensemble of colourful characters , each with their own distinct and quirky personality . the current-affairs show is headed by nicely-groomed mike moore ( rob sitch ) , an ambitious , pretentious , dim-witted narcissist . mike works under the delusion that the show is serving a vital role for society – he is always adamant that they \" maintain their journalistic integrity \" – and his executive producers have excelled into getting him to believe just that . mike is basically a puppet to bring the news to the people ; occasionally he gets the inkling that he is being led along by the nose , but usually this thought is stamped out via appeals to his vanity or promises of a promotion . brooke vandenberg ( jane kennedy ) is the senior female reporter on the show . she is constantly concerned about her looks and public profile , and , if the rumours are to be believed , she has had a romantic liaison with just about every male celebrity in existence . another equally amoral reporter , marty di stasio , is portrayed by tiriel mora , who memorably played inept solicitor dennis denuto in the australian comedy classic , 'the castle . ' emma ward ( alison whyte ) is the line producer on the show , and the single shining beacon of morality on the \" frontline \" set . then there's the highly-amusing weatherman , geoffrey salter ( santo cilauro ) , mike's best friend and confidant . 
geoff makes a living out of always agreeing with mike's opinion , and of laughing uproariously at his jokes before admitting that he doesn't get them . for each of the shows three seasons , we are treated to a different ep , executive producer . brian thompson ( bruno lawrence ) , who unfortunately passed away in 1995 , runs the programme during season 1 . he has a decent set of morals , and is always civil to his employees , and yet is more-than-willing to cast these aside in favour of high ratings . sam murphy ( kevin j . wilson ) arrives on set in season 2 , a hard-nosed , smooth-talking producer who knows exactly how to string mike along ; the last episode of the second season , when mike finally gets the better of him , is a classic moment . graeme \" prowsey \" prowse ( steve bisley ) , ep for the third season , is crude , unpleasant and unashamedly sexist . it's , therefore , remarkable that you eventually come to like him . with its cast of distinctive , exaggerated characters , \" frontline \" has a lot of fun satirising current-affairs programmes and their dubious methods for winning ratings . many of the episodes were shot quickly and cheaply , often implementing many plot ideas from recent real-life situations , but this never really detracts from the show's topicality ten years on . celebrity cameos come in abundance , with some of the most memorable appearances including pauline hanson , don burke and jon english . watch out for harry shearer's hilarious appearance in the season 2 episode \" changing the face of current affairs , \" playing larry hadges , an american hired by the network to reform the show . particularly in the third season , i noticed that \" frontline \" boasted an extremely gritty form of black humour , uncharacteristic for such a light-hearted comedy show . 
genuinely funny moments are born from brooke being surreptitiously bribed into having an abortion , murder by a crazed gunman and mike treacherously betraying his best friend's hopes and dreams , only to be told that he is a good friend . the series' final minute – minus an added-scene during the credits , which was probably added just in case a fourth season was to be produced – was probably the greatest , blackest ending to a comedy series that i've yet seen . below is listed a very tentative list of my top five favourite \" frontline \" episodes , but , make no mistake , every single half-hour is absolutely hilarious and hard-hitting satire . 1 ) \" the siege \" ( season 1 ) 2 ) \" give 'em enough rope \" ( season 2 ) 3 ) \" addicted to fame \" ( season 3 ) 4 ) \" basic instincts \" ( season 2 ) 5 ) \" add sex and stir \" ( season 1 )»\n", "\n", "SIMILAR/DISSIMILAR DOCS PER MODEL Doc2Vec(dbow,d100,n5,mc2,t4):\n", "\n", "MOST (34106, 0.6284705996513367): «the sad thing about frontline is that once you watch three or four episodes of it you really begin to understand that it is not far away from what happens in real life . what is really sad is that it also makes extremely funny . the frontline team in series one consists of brian thompson ( bruno lawrence ) - a man who truly lives and dies merely by the ratings his show gets . occasionally his stunts to achieve these ratings see him run in with his line producer emma thompson ( alison whyte ) ; a woman who hasn't lost all her journalistic integrity and is prepared to defend moral scruples on occasions . the same cannot be said of reporter brooke vandenberg ( jane kennedy ) - a reporter who has had all the substance sucked out of her- so much so that when interviewing ben elton she needs to be instructed to laugh . her reports usually consist of interviewing celebrities ( with whom she has or hasn't 'crossed paths' with before ) or scandalous unethical reports that usually backfire . 
martin de stasio ( tiriel mora ) is the reporter with whom the team relies on for gravitas and dignity , as he has the smarts of 21 years of journalism behind him . his doesn't have principles so much as a nous of what makes a good journalistic story , though he does draw the occasional line . parading over this chaos ( in name ) is mike moore ( rob sitch ) an egotistical , naive reporter who can't see that he's only a pretty face for the grubby journalism . he often finds his morals being compromised simply because brian appeals to his vanity and allows his stupidity to do the rest . frontline is the sort of show that there needs to be more of , because it shows that while in modern times happiness , safety and deep political insight are interesting things ; it's much easier to rate with scandal , fear and tabloid celebrities .»\n", "\n", "MEDIAN (35245, 0.2309201955795288): «\" hell to pay \" bills itself as the rebirth of the classic western . . . it succeeds as a western genre movie that the entire family could see and not unlike the films baby-boomers experienced decades ago . the good guys are good and the bad guys are really bad ! . bo svenson , stella stevens , lee majors , andrew prine ( excellent in this film ) tim thomerson and james drury are all great and it's fun to see them again . james drury really shines in this one , maybe even better than his days as \" the virginian . \" in a way , \" hell to pay \" reminds me of those movies in the 60's where actors you know from so many shows make an appearance . if you're of a certain age , buck taylor , peter brown and denny miller and william smith provide a \" wow \" factor because we seldom get to see these icons these days . \" hell to pay \" features screen legends along with newer names in hollywood . 
most notable in the cast of \" newbies \" is rachel kimsey ( rebekah ) , who i've seen lately on \" the young and the restless \" and kevin kazakoff , who plays the angst-ridden kirby , a war-weary man who's torn between wanting to live and let live or stepping in to \" do the right thing . \" william gregory lee is excellent as chance , kirby's mischievous and womanizing brother . katie keane plays rachel , rebekah's sister , a woman who did what was necessary to stay alive but giving up her pride in the process . in a small but memorable role , jeff davis plays mean joe , a former confederate with a rather nasty mean streak . i think we'll be seeing more of these fine actors in the future . \" hell to pay \" is a fun movie with a great story to tell grab the popcorn , we're headin' west ! .»\n", "\n", "LEAST (261, -0.09666291624307632): «an unusual film from ringo lam and one that's strangely under-appreciated . the mix of fantasy kung-fu with a more realistic depiction of swords and spears being driven thru bodies is startling especially during the first ten minutes . a horseback rider get chopped in two and his waist and legs keep riding the horse . several horses get chopped up . it's very unexpected . the story is very simple , fong and his shaolin brothers are captured by a crazed maniac general and imprisoned in the red lotus temple which seems to be more of a torture chamber then a temple . the general has a similarity to kurtz in apocalypse now as he spouts warped philosophy and makes frightening paintings with human blood . the production is very impressive and the setting is bleak . blood is everywhere . the action is very well done and mostly coherent unlike many hk action scenes from the time . sometimes the movie veers into absurdity or the effects are cheesy but it's never bad enough to ruin the film . find this one , it's one of the best hk kung fu films from the early nineties . 
just remember it's not child friendly .»\n", "\n" ] } ], "source": [ "import random\n", "\n", "doc_id = np.random.randint(simple_models[0].docvecs.count) # pick random doc, re-run cell for more examples\n", "model = random.choice(simple_models) # and a random model\n", "sims = model.docvecs.most_similar(doc_id, topn=model.docvecs.count) # get *all* similar documents\n", "print(u'TARGET (%d): «%s»\\n' % (doc_id, ' '.join(alldocs[doc_id].words)))\n", "print(u'SIMILAR/DISSIMILAR DOCS PER MODEL %s:\\n' % model)\n", "for label, index in [('MOST', 0), ('MEDIAN', len(sims)//2), ('LEAST', len(sims) - 1)]:\n", "    print(u'%s %s: «%s»\\n' % (label, sims[index], ' '.join(alldocs[sims[index][0]].words)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To some degree (reviewer tone, movie genre, and so on), the MOST cosine-similar docs usually seem more like the TARGET than the MEDIAN or LEAST docs do, especially when the MOST doc has a cosine-similarity above 0.5. Re-run the cell to try another random target document." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Do the word vectors show useful similarities?" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "word_models = simple_models[:]" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "most similar words for 'spoilt' (97 occurences)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/neuscratch/Dev/gensim/gensim/matutils.py:737: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. In future, it will be treated as `np.int64 == np.dtype(int).type`.\n", " if np.issubdtype(vec.dtype, np.int):\n" ] }, { "data": { "text/html": [ "
Doc2Vec(dbow,d100,n5,mc2,t4)Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4)Doc2Vec(dm/c,d100,n5,w5,mc2,t4)
[(\"wives'\", 0.4262964725494385),
\n", "('horrificaly', 0.4177134335041046),
\n", "(\"snit'\", 0.4037289619445801),
\n", "('improf', 0.40169233083724976),
\n", "('humiliatingly', 0.3946930170059204),
\n", "('heart-pounding', 0.3938479423522949),
\n", "(\"'jo'\", 0.38460421562194824),
\n", "('kieron', 0.37991276383399963),
\n", "('linguistic', 0.3727714419364929),
\n", "('rothery', 0.3719364404678345),
\n", "('zellwegger', 0.370682954788208),
\n", "('never-released', 0.36564797163009644),
\n", "('coffeeshop', 0.36534833908081055),
\n", "('slater--these', 0.3643302917480469),
\n", "('over-plotted', 0.36348140239715576),
\n", "('synchronism', 0.36320072412490845),
\n", "('exploitations', 0.3631579875946045),
\n", "(\"donor's\", 0.36226314306259155),
\n", "('neend', 0.3619685769081116),
\n", "('renaud', 0.3611547350883484)]
[('spoiled', 0.6693772077560425),
\n", "('ruined', 0.5701743960380554),
\n", "('dominated', 0.554553747177124),
\n", "('marred', 0.5456377267837524),
\n", "('undermined', 0.5353708267211914),
\n", "('unencumbered', 0.5345744490623474),
\n", "('dwarfed', 0.5331343412399292),
\n", "('followed', 0.5186703205108643),
\n", "('entranced', 0.513541042804718),
\n", "('emboldened', 0.5100494623184204),
\n", "('shunned', 0.5044804215431213),
\n", "('disgusted', 0.5000460743904114),
\n", "('overestimated', 0.49955034255981445),
\n", "('bolstered', 0.4971669018268585),
\n", "('replaced', 0.4966174364089966),
\n", "('bookended', 0.49495506286621094),
\n", "('blowout', 0.49287083745002747),
\n", "('overshadowed', 0.48964253067970276),
\n", "('played', 0.48709338903427124),
\n", "('accompanied', 0.47834640741348267)]
[('spoiled', 0.6672338247299194),
\n", "('troubled', 0.520033597946167),
\n", "('bankrupted', 0.509053647518158),
\n", "('ruined', 0.4965386986732483),
\n", "('misguided', 0.4900725483894348),
\n", "('devoured', 0.48988765478134155),
\n", "('ravaged', 0.4861036539077759),
\n", "('frustrated', 0.4841104745864868),
\n", "('suffocated', 0.4828023314476013),
\n", "('investigated', 0.47958582639694214),
\n", "('tormented', 0.4791877865791321),
\n", "('traumatized', 0.4785040616989136),
\n", "('shaken', 0.4784379005432129),
\n", "('persecuted', 0.4774147868156433),
\n", "('crippled', 0.4771782457828522),
\n", "('torpedoed', 0.4764551818370819),
\n", "('plagued', 0.47006863355636597),
\n", "('drowned', 0.4688340723514557),
\n", "('prompted', 0.4678872525691986),
\n", "('abandoned', 0.4652657210826874)]
" ], "text/plain": [ "" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import random\n", "from IPython.display import HTML\n", "# pick a random word with a suitable number of occurrences\n", "while True:\n", "    word = random.choice(word_models[0].wv.index2word)\n", "    if word_models[0].wv.vocab[word].count > 10:\n", "        break\n", "# or uncomment the line below to just pick a word from the relevant domain:\n", "#word = 'comedy/drama'\n", "similars_per_model = [str(model.wv.most_similar(word, topn=20)).replace('), ','),<br>\\n') for model in word_models]\n", "similar_table = (\"<table><tr><th>\" +\n", "    \"</th><th>\".join([str(model) for model in word_models]) + \n", "    \"</th></tr><tr><td>\" +\n", "    \"</td><td>\".join(similars_per_model) +\n", "    \"</td></tr></table>\")\n", "print(\"most similar words for '%s' (%d occurrences)\" % (word, simple_models[0].wv.vocab[word].count))\n", "HTML(similar_table)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Do the DBOW words look meaningless? That's because the gensim DBOW model doesn't train word vectors (they remain at their randomly initialized values) unless you ask for it with the `dbow_words=1` initialization parameter. Concurrent word-training slows DBOW mode significantly and offers little improvement (sometimes a slight worsening) in the error rate on this IMDB sentiment-prediction task, but it may be appropriate for other tasks, or if you also need word-vectors. \n", "\n", "Words from DM models tend to show meaningfully similar words when there are many examples in the training data (as with 'plot' or 'actor'). (All DM modes inherently train word-vectors concurrently with doc-vectors.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Are the word vectors from this dataset any good at analogies?" 
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Success, questions-words.txt is available for next steps.\n" ] } ], "source": [ "# grab the file if not already local\n", "questions_filename = 'questions-words.txt'\n", "if not os.path.isfile(questions_filename):\n", "    # Download the word-analogy questions file\n", "    print(\"Downloading analogy questions file...\")\n", "    url = u'https://raw.githubusercontent.com/tmikolov/word2vec/master/questions-words.txt'\n", "    r = requests.get(url)\n", "    r.raise_for_status()  # fail loudly rather than saving an HTTP error page\n", "    with smart_open(questions_filename, 'wb') as f:\n", "        f.write(r.content)\n", "assert os.path.isfile(questions_filename), \"questions-words.txt unavailable\"\n", "print(\"Success, questions-words.txt is available for next steps.\")" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/neuscratch/Dev/gensim/gensim/matutils.py:737: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. 
In future, it will be treated as `np.int64 == np.dtype(int).type`.\n", " if np.issubdtype(vec.dtype, np.int):\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Doc2Vec(dbow,d100,n5,mc2,t4): 0.00% correct (0 of 14657)\n", "Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4): 17.37% correct (2546 of 14657)\n", "Doc2Vec(dm/c,d100,n5,w5,mc2,t4): 19.20% correct (2814 of 14657)\n" ] } ], "source": [ "# Note: this analysis takes many minutes\n", "for model in word_models:\n", "    score, sections = model.wv.evaluate_word_analogies(questions_filename)\n", "    correct, incorrect = len(sections[-1]['correct']), len(sections[-1]['incorrect'])\n", "    print('%s: %0.2f%% correct (%d of %d)' % (model, float(correct*100)/(correct+incorrect), correct, correct+incorrect))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Even though this is a tiny, domain-specific dataset, it shows some meager capability on general word analogies, at least for the DM/mean and DM/concat models, which actually train word vectors. (The untrained, randomly initialized word vectors of the DBOW model, of course, fail miserably.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Slop" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "This cell left intentionally erroneous." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Advanced technique: re-inferring doc-vectors" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Because the bulk-trained vectors received much of their training early, while the model itself was still settling, it is *sometimes* the case that new vectors re-inferred from the model's final state serve better than the bulk-trained vectors as input/test data for downstream tasks. \n", "\n", "Our `error_rate_for_model()` function already has a non-default option to re-infer vectors before training/testing the classifier, so here we test that option. 
(This takes as long as, or longer than, the initial bulk training, because inference is single-threaded.)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Evaluating Doc2Vec(dbow,d100,n5,mc2,t4) re-inferred\n", "CPU times: user 7min 9s, sys: 1.55 s, total: 7min 11s\n", "Wall time: 7min 10s\n", "\n", "0.102240 Doc2Vec(dbow,d100,n5,mc2,t4)_reinferred\n", "\n", "Evaluating Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4) re-inferred\n", "CPU times: user 9min 48s, sys: 1.53 s, total: 9min 49s\n", "Wall time: 9min 48s\n", "\n", "0.146200 Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4)_reinferred\n", "\n", "Evaluating Doc2Vec(dm/c,d100,n5,w5,mc2,t4) re-inferred\n", "CPU times: user 16min 13s, sys: 1.32 s, total: 16min 14s\n", "Wall time: 16min 13s\n", "\n", "0.218120 Doc2Vec(dm/c,d100,n5,w5,mc2,t4)_reinferred\n", "\n", "Evaluating Doc2Vec(dbow,d100,n5,mc2,t4)+Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4) re-inferred\n", "CPU times: user 15min 50s, sys: 1.63 s, total: 15min 52s\n", "Wall time: 15min 49s\n", "\n", "0.102120 Doc2Vec(dbow,d100,n5,mc2,t4)+Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4)_reinferred\n", "\n", "Evaluating Doc2Vec(dbow,d100,n5,mc2,t4)+Doc2Vec(dm/c,d100,n5,w5,mc2,t4) re-inferred\n", "CPU times: user 22min 53s, sys: 1.81 s, total: 22min 55s\n", "Wall time: 22min 52s\n", "\n", "0.104320 Doc2Vec(dbow,d100,n5,mc2,t4)+Doc2Vec(dm/c,d100,n5,w5,mc2,t4)_reinferred\n", "\n" ] } ], "source": [ "for model in simple_models + [models_by_name['dbow+dmm'], models_by_name['dbow+dmc']]: \n", "    print(\"Evaluating %s re-inferred\" % str(model))\n", "    pseudomodel_name = str(model)+\"_reinferred\"\n", "    %time err_rate, err_count, test_count, predictor = error_rate_for_model(model, train_docs, test_docs, reinfer_train=True, reinfer_test=True, infer_subsample=1.0)\n", "    error_rates[pseudomodel_name] = err_rate\n", "    print(\"\\n%f %s\\n\" % (err_rate, pseudomodel_name))" ] }, { "cell_type": 
"code", "execution_count": 25, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Err_rate Model\n", "0.102120 Doc2Vec(dbow,d100,n5,mc2,t4)+Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4)_reinferred\n", "0.102240 Doc2Vec(dbow,d100,n5,mc2,t4)_reinferred\n", "0.102600 Doc2Vec(dbow,d100,n5,mc2,t4)\n", "0.103360 Doc2Vec(dbow,d100,n5,mc2,t4)+Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4)\n", "0.104320 Doc2Vec(dbow,d100,n5,mc2,t4)+Doc2Vec(dm/c,d100,n5,w5,mc2,t4)_reinferred\n", "0.105080 Doc2Vec(dbow,d100,n5,mc2,t4)+Doc2Vec(dm/c,d100,n5,w5,mc2,t4)\n", "0.146200 Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4)_reinferred\n", "0.154280 Doc2Vec(\"alpha=0.05\",dm/m,d100,n5,w10,mc2,t4)\n", "0.218120 Doc2Vec(dm/c,d100,n5,w5,mc2,t4)_reinferred\n", "0.225760 Doc2Vec(dm/c,d100,n5,w5,mc2,t4)\n" ] } ], "source": [ "# Compare error rates achieved, best-to-worst\n", "print(\"Err_rate Model\")\n", "for rate, name in sorted((rate, name) for name, rate in error_rates.items()):\n", "    print(\"%f %s\" % (rate, name))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here, we do *not* see much benefit from re-inference: the re-inferred vectors lower every model's error rate, but by less than one percentage point. It's more likely to help if the initial training used fewer epochs (10 is also a common value in the literature for larger datasets), or perhaps with larger datasets. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### To get copious logging output from above steps..." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import logging\n", "logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)\n", "rootLogger = logging.getLogger()\n", "rootLogger.setLevel(logging.INFO)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### To auto-reload python code while developing..." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 1 }