{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Comparing two LDA models and visualizing the difference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## This notebook shows how to compare a model with itself and with another model, and why you would want to."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## First, clean up the 20 newsgroups dataset. We will use it to fit LDA."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from string import punctuation\n",
"from nltk import RegexpTokenizer\n",
"from nltk.stem.porter import PorterStemmer\n",
"from nltk.corpus import stopwords\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"\n",
"\n",
"newsgroups = fetch_20newsgroups()\n",
"eng_stopwords = set(stopwords.words('english'))  # needs nltk.download('stopwords') on first use\n",
"\n",
"tokenizer = RegexpTokenizer(r'\\s+', gaps=True)  # split on runs of whitespace\n",
"stemmer = PorterStemmer()\n",
"translate_tab = {ord(p): u\" \" for p in punctuation}  # map every punctuation character to a space\n",
"\n",
"def text2tokens(raw_text):\n",
"    \"\"\"\n",
"    Convert raw text to a list of stemmed tokens.\n",
"    \"\"\"\n",
"    clean_text = raw_text.lower().translate(translate_tab)\n",
"    tokens = [token.strip() for token in tokenizer.tokenize(clean_text)]\n",
"    tokens = [token for token in tokens if token not in eng_stopwords]\n",
"    stemmed_tokens = [stemmer.stem(token) for token in tokens]\n",
"\n",
"    return [token for token in stemmed_tokens if len(token) > 2]  # skip short tokens\n",
"\n",
"dataset = [text2tokens(txt) for txt in newsgroups['data']]  # convert each document to a list of tokens"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from gensim.corpora import Dictionary\n",
"dictionary = Dictionary(documents=dataset, prune_at=None)\n",
"dictionary.filter_extremes(no_below=5, no_above=0.3, keep_n=None)  # use Dictionary to remove irrelevant tokens\n",
"dictionary.compactify()\n",
"\n",
"d2b_dataset = [dictionary.doc2bow(doc) for doc in dataset]  # convert each list of tokens to a bag-of-words representation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Second, fit two LDA models."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4min 17s, sys: 22.2 s, total: 4min 39s\n",
"Wall time: 5min 13s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"from gensim.models import LdaMulticore\n",
"num_topics = 15\n",
"\n",
"lda_fst = LdaMulticore(\n",
" corpus=d2b_dataset, num_topics=num_topics, id2word=dictionary,\n",
" workers=4, eval_every=None, passes=10, batch=True\n",
")\n",
"\n",
"lda_snd = LdaMulticore(\n",
" corpus=d2b_dataset, num_topics=num_topics, id2word=dictionary,\n",
" workers=4, eval_every=None, passes=20, batch=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Now for the cases, with visualisation. Yay!"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/vnd.plotly.v1+html": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import plotly.offline as py\n",
"import plotly.graph_objs as go\n",
"\n",
"py.init_notebook_mode()\n",
"\n",
"def plot_difference(mdiff, title=\"\", annotation=None):\n",
" \"\"\"\n",
" Helper function for plot difference between models\n",
" \"\"\"\n",
" annotation_html = None\n",
" if annotation is not None:\n",
" annotation_html = [\n",
" [\n",
" \"+++ {}
--- {}\".format(\", \".join(int_tokens), \", \".join(diff_tokens)) \n",
" for (int_tokens, diff_tokens) in row\n",
" ] \n",
" for row in annotation\n",
" ]\n",
" \n",
" data = go.Heatmap(z=mdiff, colorscale='RdBu', text=annotation_html)\n",
" layout = go.Layout(width=950, height=950, title=title, xaxis=dict(title=\"topic\"), yaxis=dict(title=\"topic\"))\n",
" py.iplot(dict(data=[data], layout=layout))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In gensim, you can visualise topic different with matrix and annotation. For this purposes, you can use method `diff` from LdaModel.\n",
"\n",
"This function return matrix with distances mdiff and matrix with annotations annotation. Read the docstring for more detailed info.\n",
"\n",
"In cells mdiff[i][j] we can see a distance between topic_i from the first model and topic_j from the second model.\n",
"\n",
"In cells annotation[i][j] we can see [tokens from intersection, tokens from difference] between topic_i from first model and topic_j from the second model."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"LdaMulticore.diff?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Case 1: How topics in ONE model correlate with each other."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Short description:\n",
"- x-axis - topic;\n",
"- y-axis - topic;\n",
"- almost red cell - strongly decorrelated topics;\n",
"- almost blue cell - strongly correlated topics.\n",
"\n",
"In an ideal world, we would like to see different topics decorrelated between themselves. In this case, our matrix would look like this:\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.plotly.v1+json": {
"data": [
{
"colorscale": "RdBu",
"text": null,
"type": "heatmap",
"z": [
[
0,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1
],
[
1,
0,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1
],
[
1,
1,
0,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1
],
[
1,
1,
1,
0,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1
],
[
1,
1,
1,
1,
0,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1
],
[
1,
1,
1,
1,
1,
0,
1,
1,
1,
1,
1,
1,
1,
1,
1
],
[
1,
1,
1,
1,
1,
1,
0,
1,
1,
1,
1,
1,
1,
1,
1
],
[
1,
1,
1,
1,
1,
1,
1,
0,
1,
1,
1,
1,
1,
1,
1
],
[
1,
1,
1,
1,
1,
1,
1,
1,
0,
1,
1,
1,
1,
1,
1
],
[
1,
1,
1,
1,
1,
1,
1,
1,
1,
0,
1,
1,
1,
1,
1
],
[
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
0,
1,
1,
1,
1
],
[
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
0,
1,
1,
1
],
[
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
0,
1,
1
],
[
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
0,
1
],
[
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
0
]
]
}
],
"layout": {
"height": 950,
"title": "Topic difference (one model) in ideal world",
"width": 950,
"xaxis": {
"title": "topic"
},
"yaxis": {
"title": "topic"
}
}
},
"text/html": [
"