{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Topic Networks\n", "\n", "In this notebook, we will learn how to visualize topic model using network graphs. Networks can be a great way to explore topic models. We can use it to navigate that how topics belonging to one context may relate to some topics in other context and discover common factors between them. We can use them to find communities of similar topics and pinpoint the most influential topic that has large no. of connections or perform any number of other workflows designed for network analysis." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mYou are using pip version 19.0.1, however version 19.1 is available.\r\n", "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\r\n" ] } ], "source": [ "!pip install plotly>=2.0.16 # 2.0.16 need for support 'hovertext' argument from create_dendrogram function" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from gensim.models.ldamodel import LdaModel\n", "from gensim.corpora import Dictionary\n", "import pandas as pd\n", "import re\n", "from gensim.parsing.preprocessing import remove_stopwords, strip_punctuation\n", "\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train Model\n", "\n", "We'll use the [fake news dataset](https://www.kaggle.com/mrisdal/fake-news) from kaggle for this notebook. First step is to preprocess the data and train our topic model using LDA. You can refer to this [notebook](https://github.com/RaRe-Technologies/gensim/blob/develop/docs/notebooks/lda_training_tips.ipynb) also for tips and suggestions of pre-processing the text data, and how to train LDA model for getting good results." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2019-05-12 18:54:33-- https://www.kaggle.com/mrisdal/fake-news/downloads/fake-news.zip/1\n", "Resolving www.kaggle.com (www.kaggle.com)... 35.244.233.98\n", "Connecting to www.kaggle.com (www.kaggle.com)|35.244.233.98|:443... connected.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: /account/login?returnUrl=%2Fmrisdal%2Ffake-news%2Fversion%2F1 [following]\n", "--2019-05-12 18:54:35-- https://www.kaggle.com/account/login?returnUrl=%2Fmrisdal%2Ffake-news%2Fversion%2F1\n", "Reusing existing connection to www.kaggle.com:443.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: unspecified [text/html]\n", "Saving to: ‘fake.news.zip’\n", "\n", "fake.news.zip [ <=> ] 8.46K --.-KB/s in 0.01s \n", "\n", "2019-05-12 18:54:36 (640 KB/s) - ‘fake.news.zip’ saved [8668]\n", "\n" ] } ], "source": [ "!wget https://www.kaggle.com/mrisdal/fake-news/downloads/fake-news.zip/1 -O fake.news.zip" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Archive: fake.news.zip\r\n", " End-of-central-directory signature not found. Either this file is not\r\n", " a zipfile, or it constitutes one disk of a multi-part archive. 
In the\r\n", " latter case the central directory and zipfile comment will be found on\r\n", " the last disk(s) of this archive.\r\n", "unzip: cannot find zipfile directory in one of fake.news.zip or\r\n", " fake.news.zip.zip, and cannot find fake.news.zip.ZIP, period.\r\n" ] } ], "source": [ "!unzip fake.news.zip" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "ename": "FileNotFoundError", "evalue": "[Errno 2] File b'fake.csv' does not exist: b'fake.csv'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdf_fake\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_csv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'fake.csv'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mdf_fake\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'title'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'text'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'language'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhead\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mdf_fake\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdf_fake\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnotnull\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdf_fake\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m&\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdf_fake\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlanguage\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m'english'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;31m# remove stopwords and punctuations\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/envs/gensim/lib/python3.7/site-packages/pandas/io/parsers.py\u001b[0m in \u001b[0;36mparser_f\u001b[0;34m(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, squeeze, prefix, mangle_dupe_cols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, skipfooter, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, dayfirst, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, doublequote, escapechar, comment, encoding, dialect, tupleize_cols, error_bad_lines, warn_bad_lines, delim_whitespace, low_memory, memory_map, float_precision)\u001b[0m\n\u001b[1;32m 700\u001b[0m skip_blank_lines=skip_blank_lines)\n\u001b[1;32m 701\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 702\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_read\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 703\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 704\u001b[0m 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lda_fake = LdaModel(corpus=corpus_fake, id2word=dictionary, num_topics=35, chunksize=1500, iterations=200, alpha='auto')\n", "lda_fake.save('lda_35')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lda_fake = LdaModel.load('lda_35')" ] },
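{ "cell_type": "markdown", "metadata": {}, "source": [ "Before visualizing, it can help to sanity-check a few topics. A minimal sketch using show_topic (looking at three topics and eight words each is an arbitrary choice):" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# peek at the top words of the first few topics\n", "for t in range(3):\n", "    print(t, [w for w, _ in lda_fake.show_topic(t, topn=8)])" ] },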
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize topic network\n", "\n", "First, a distance matrix is calculated to store the distance between every topic pair. The nodes of the network graph will represent topics, and the edges between them will be created based on the distance between the two topics they connect." ] },
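{ "cell_type": "markdown", "metadata": {}, "source": [ "For intuition about the metric we'll use, here is a small sketch with toy distributions, added for illustration: gensim's `jensen_shannon` returns larger values for more dissimilar distributions, and 0 for identical ones." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from gensim.matutils import jensen_shannon\n", "\n", "p = np.array([0.7, 0.2, 0.1])\n", "q = np.array([0.1, 0.2, 0.7])\n", "print(jensen_shannon(p, q))  # noticeably greater than 0 for dissimilar distributions\n", "print(jensen_shannon(p, p))  # 0.0 for identical distributions" ] },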
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# get topic distributions\n", "topic_dist = lda_fake.state.get_lambda()\n", "\n", "# get topic terms\n", "num_words = 50\n", "topic_terms = [{w for (w, _) in lda_fake.show_topic(topic, topn=num_words)} for topic in range(topic_dist.shape[0])]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To draw the edges, we can use different types of distance metrics available in gensim for calculating the distance between every topic pair. Next, we'd have to define a threshold of distance value such that the topic-pairs with distance above that does not get connected. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from scipy.spatial.distance import pdist, squareform\n", "from gensim.matutils import jensen_shannon\n", "import networkx as nx\n", "import itertools as itt\n", "\n", "# calculate distance matrix using the input distance metric\n", "def distance(X, dist_metric):\n", " return squareform(pdist(X, lambda u, v: dist_metric(u, v)))\n", "\n", "topic_distance = distance(topic_dist, jensen_shannon)\n", "\n", "# store edges b/w every topic pair along with their distance\n", "edges = [(i, j, {'weight': topic_distance[i, j]})\n", " for i, j in itt.combinations(range(topic_dist.shape[0]), 2)]\n", "\n", "# keep edges with distance below the threshold value\n", "k = np.percentile(np.array([e[2]['weight'] for e in edges]), 20)\n", "edges = [e for e in edges if e[2]['weight'] < k]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have our edges, let's plot the annotated network graph. On hovering over the nodes, we'll see the topic_id along with it's top words and on hovering over the edges, we'll see the intersecting/different words of the two topics that it connects. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import plotly.offline as py\n", "from plotly.graph_objs import *\n", "import plotly.figure_factory as ff\n", "\n", "py.init_notebook_mode()\n", "\n", "# add nodes and edges to graph layout\n", "G = nx.Graph()\n", "G.add_nodes_from(range(topic_dist.shape[0]))\n", "G.add_edges_from(edges)\n", "\n", "graph_pos = nx.spring_layout(G)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "# initialize traces for drawing nodes and edges \n", "node_trace = Scatter(\n", " x=[],\n", " y=[],\n", " text=[],\n", " mode='markers',\n", " hoverinfo='text',\n", " marker=Marker(\n", " showscale=True,\n", " colorscale='YIGnBu',\n", " reversescale=True,\n", " color=[],\n", " size=10,\n", " colorbar=dict(\n", " thickness=15,\n", " xanchor='left'\n", " ),\n", " line=dict(width=2)))\n", "\n", "edge_trace = Scatter(\n", " x=[],\n", " y=[],\n", " text=[],\n", " line=Line(width=0.5, color='#888'),\n", " hoverinfo='text',\n", " mode='lines')\n", "\n", "\n", "# no. of terms to display in annotation\n", "n_ann_terms = 10\n", "\n", "# add edge trace with annotations\n", "for edge in G.edges():\n", " x0, y0 = graph_pos[edge[0]]\n", " x1, y1 = graph_pos[edge[1]]\n", " \n", " pos_tokens = topic_terms[edge[0]] & topic_terms[edge[1]]\n", " neg_tokens = topic_terms[edge[0]].symmetric_difference(topic_terms[edge[1]])\n", " pos_tokens = list(pos_tokens)[:min(len(pos_tokens), n_ann_terms)]\n", " neg_tokens = list(neg_tokens)[:min(len(neg_tokens), n_ann_terms)]\n", " annotation = \"
\".join((\": \".join((\"+++\", str(pos_tokens))), \": \".join((\"---\", str(neg_tokens)))))\n", " \n", " x_trace = list(np.linspace(x0, x1, 10))\n", " y_trace = list(np.linspace(y0, y1, 10))\n", " text_annotation = [annotation] * 10\n", " x_trace.append(None)\n", " y_trace.append(None)\n", " text_annotation.append(None)\n", " \n", " edge_trace['x'] += x_trace\n", " edge_trace['y'] += y_trace\n", " edge_trace['text'] += text_annotation\n", "\n", "# add node trace with annotations\n", "for node in G.nodes():\n", " x, y = graph_pos[node]\n", " node_trace['x'].append(x)\n", " node_trace['y'].append(y)\n", " node_info = ''.join((str(node+1), ': ', str(list(topic_terms[node])[:n_ann_terms])))\n", " node_trace['text'].append(node_info)\n", " \n", "# color node according to no. of connections\n", "for node, adjacencies in enumerate(G.adjacency()):\n", " node_trace['marker']['color'].append(len(adjacencies))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig = Figure(data=Data([edge_trace, node_trace]),\n", " layout=Layout(showlegend=False,\n", " hovermode='closest',\n", " xaxis=XAxis(showgrid=True, zeroline=False, showticklabels=True),\n", " yaxis=YAxis(showgrid=True, zeroline=False, showticklabels=True)))\n", "\n", "py.iplot(fig)" ] }, { "cell_type": "markdown", "metadata": { "scrolled": false }, "source": [ "For the above graph, we just used the 20th percentile of all the distance values. But we can experiment with few different values also such that the graph doesn’t become too crowded or too sparse and we could get an optimum amount of information about similar topics or any interesting relations b/w different topics.\n", "\n", "Or we can also get an idea of threshold from the dendrogram (with ‘single’ linkage function). You can refer to [this notebook](http://nbviewer.jupyter.org/github/parulsethi/gensim/blob/b9e7ab54dde98438b0e4f766ee764b81af704367/docs/notebooks/Topic_dendrogram.ipynb) for more details on topic dendrogram visualization. The y-values in the dendrogram represent the metric distances and if we choose a certain y-value then only those topics which are clustered below it would be connected. So let's plot the dendrogram now to see the sequential clustering process with increasing distance values." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from gensim.matutils import jensen_shannon\n", "import scipy as scp\n", "from scipy.cluster import hierarchy as sch\n", "from scipy import spatial as scs\n", "\n", "# get topic distributions\n", "topic_dist = lda_fake.state.get_lambda()\n", "\n", "# get topic terms\n", "num_words = 300\n", "topic_terms = [{w for (w, _) in lda_fake.show_topic(topic, topn=num_words)} for topic in range(topic_dist.shape[0])]\n", "\n", "# no. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "We can also get an idea of the threshold from a dendrogram (built with the 'single' linkage function). You can refer to [this notebook](http://nbviewer.jupyter.org/github/parulsethi/gensim/blob/b9e7ab54dde98438b0e4f766ee764b81af704367/docs/notebooks/Topic_dendrogram.ipynb) for more details on topic dendrogram visualization. The y-values in the dendrogram represent the metric distances; if we choose a certain y-value, then only the topics that are clustered below it will be connected. So let's plot the dendrogram now to see the sequential clustering process with increasing distance values." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from gensim.matutils import jensen_shannon\n", "import scipy as scp\n", "from scipy.cluster import hierarchy as sch\n", "from scipy import spatial as scs\n", "\n", "# get topic distributions\n", "topic_dist = lda_fake.state.get_lambda()\n", "\n", "# get topic terms\n", "num_words = 300\n", "topic_terms = [{w for (w, _) in lda_fake.show_topic(topic, topn=num_words)} for topic in range(topic_dist.shape[0])]\n", "\n", "# no. of terms to display in annotation\n", "n_ann_terms = 10\n", "\n", "# use Jensen-Shannon distance metric in dendrogram\n", "def js_dist(X):\n", "    return pdist(X, lambda u, v: jensen_shannon(u, v))\n", "\n", "# define method for distance calculation in clusters\n", "linkagefun = lambda x: sch.linkage(x, 'single')\n", "\n", "# calculate text annotations\n", "def text_annotation(topic_dist, topic_terms, n_ann_terms, linkagefun):\n", "    # get dendrogram hierarchy data\n", "    d = js_dist(topic_dist)\n", "    Z = linkagefun(d)\n", "    P = sch.dendrogram(Z, orientation=\"bottom\", no_plot=True)\n", "\n", "    # store topic no.(leaves) corresponding to the x-ticks in dendrogram\n", "    x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10)\n", "    x_topic = dict(zip(P['leaves'], x_ticks))\n", "\n", "    # store {topic no.: topic terms}\n", "    topic_vals = dict()\n", "    for key, val in x_topic.items():\n", "        topic_vals[val] = (topic_terms[key], topic_terms[key])\n", "\n", "    text_annotations = []\n", "    # loop through every trace (scatter plot) in dendrogram\n", "    for trace in P['icoord']:\n", "        fst_topic = topic_vals[trace[0]]\n", "        scnd_topic = topic_vals[trace[2]]\n", "\n", "        # annotation for two ends of current trace\n", "        pos_tokens_t1 = list(fst_topic[0])[:min(len(fst_topic[0]), n_ann_terms)]\n", "        neg_tokens_t1 = list(fst_topic[1])[:min(len(fst_topic[1]), n_ann_terms)]\n", "\n", "        pos_tokens_t4 = list(scnd_topic[0])[:min(len(scnd_topic[0]), n_ann_terms)]\n", "        neg_tokens_t4 = list(scnd_topic[1])[:min(len(scnd_topic[1]), n_ann_terms)]\n", "\n", "        t1 = \"<br>\".join((\": \".join((\"+++\", str(pos_tokens_t1))), \": \".join((\"---\", str(neg_tokens_t1)))))\n", "        t2 = t3 = ()\n", "        t4 = \"<br>\".join((\": \".join((\"+++\", str(pos_tokens_t4))), \": \".join((\"---\", str(neg_tokens_t4)))))\n", "\n", "        # show topic terms in leaves\n", "        if trace[0] in x_ticks:\n", "            t1 = str(list(topic_vals[trace[0]][0])[:n_ann_terms])\n", "        if trace[2] in x_ticks:\n", "            t4 = str(list(topic_vals[trace[2]][0])[:n_ann_terms])\n", "\n", "        text_annotations.append([t1, t2, t3, t4])\n", "\n", "        # calculate intersecting/diff for upper level\n", "        intersecting = fst_topic[0] & scnd_topic[0]\n", "        different = fst_topic[0].symmetric_difference(scnd_topic[0])\n", "\n", "        center = (trace[0] + trace[2]) / 2\n", "        topic_vals[center] = (intersecting, different)\n", "\n", "        # remove trace value after it is annotated\n", "        topic_vals.pop(trace[0], None)\n", "        topic_vals.pop(trace[2], None)\n", "\n", "    return text_annotations\n", "\n", "# get text annotations\n", "annotation = text_annotation(topic_dist, topic_terms, n_ann_terms, linkagefun)\n", "\n", "# Plot dendrogram\n", "dendro = ff.create_dendrogram(topic_dist, distfun=js_dist, labels=range(1, 36), linkagefun=linkagefun, hovertext=annotation)\n", "dendro['layout'].update({'width': 1000, 'height': 600})\n", "py.iplot(dendro)" ] },
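{ "cell_type": "markdown", "metadata": {}, "source": [ "To make the y-value threshold concrete, here is a minimal sketch added for illustration: it cuts the same single-linkage hierarchy at a trial Jensen-Shannon distance (0.35 is just an assumed trial value) and prints which cluster each topic falls into. Topics sharing a cluster id at this cut would end up connected in the network graph." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# cut the single-linkage hierarchy at a trial distance threshold\n", "Z = sch.linkage(js_dist(topic_dist), 'single')\n", "clusters = sch.fcluster(Z, t=0.35, criterion='distance')\n", "# one cluster id per topic (topics numbered 1..35 as in the dendrogram)\n", "print(list(zip(range(1, 36), clusters)))" ] },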
\".join((\": \".join((\"+++\", str(pos_tokens_t1))), \": \".join((\"---\", str(neg_tokens_t1)))))\n", " t2 = t3 = ()\n", " t4 = \"
\".join((\": \".join((\"+++\", str(pos_tokens_t4))), \": \".join((\"---\", str(neg_tokens_t4)))))\n", "\n", " # show topic terms in leaves\n", " if trace[0] in x_ticks:\n", " t1 = str(list(topic_vals[trace[0]][0])[:n_ann_terms])\n", " if trace[2] in x_ticks:\n", " t4 = str(list(topic_vals[trace[2]][0])[:n_ann_terms])\n", "\n", " text_annotations.append([t1, t2, t3, t4])\n", "\n", " # calculate intersecting/diff for upper level\n", " intersecting = fst_topic[0] & scnd_topic[0]\n", " different = fst_topic[0].symmetric_difference(scnd_topic[0])\n", "\n", " center = (trace[0] + trace[2]) / 2\n", " topic_vals[center] = (intersecting, different)\n", "\n", " # remove trace value after it is annotated\n", " topic_vals.pop(trace[0], None)\n", " topic_vals.pop(trace[2], None) \n", " \n", " return text_annotations\n", "\n", "# get text annotations\n", "annotation = text_annotation(topic_dist, topic_terms, n_ann_terms, linkagefun)\n", "\n", "# Plot dendrogram\n", "dendro = ff.create_dendrogram(topic_dist, distfun=js_dist, labels=range(1, 36), linkagefun=linkagefun, hovertext=annotation)\n", "dendro['layout'].update({'width': 1000, 'height': 600})\n", "py.iplot(dendro)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From observing this dendrogram, we can try the threshold values between 0.3 to 0.35 for network graph, as the topics are clustered in distinct groups below them and this could plot separate clusters of related topics in the network graph.\n", "\n", "But then why do we need to use network graph if the dendrogram already shows the topic clusters with a clear sequence of how topics joined one after the other. The problem is that we can't see the direct relation of any topic with another topic except if they are directly paired at the first hierarchy level. The network graph let's us explore the inter-topic distances and at the same time observe clusters of closely related topics." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" } }, "nbformat": 4, "nbformat_minor": 2 }