{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# `pyLDAvis.lda_model`\n", "\n", "pyLDAvis also supports topic models fitted with scikit-learn's `LatentDirichletAllocation`. Let's look at this in more detail, using the 20 newsgroups dataset that is available through scikit-learn." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [] }, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings('ignore', category=DeprecationWarning)\n", "warnings.filterwarnings('ignore', category=FutureWarning)\n", "warnings.filterwarnings('ignore', category=UserWarning)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "import pyLDAvis\n", "import pyLDAvis.lda_model\n", "pyLDAvis.enable_notebook()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [] }, "outputs": [], "source": [ "from sklearn.datasets import fetch_20newsgroups\n", "from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer\n", "from sklearn.decomposition import LatentDirichletAllocation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the 20 newsgroups dataset\n", "\n", "First, the 20 newsgroups dataset is loaded with scikit-learn's `fetch_20newsgroups`. As usual, headers, footers and quoted replies are removed, so that the topics are learned from the message bodies rather than from metadata." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "11314\n" ] } ], "source": [ "newsgroups = fetch_20newsgroups(remove=('headers', 'footers', 'quotes'))\n", "docs_raw = newsgroups.data\n", "print(len(docs_raw))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convert to a document-term matrix\n", "\n", "Next, the raw documents are converted into a document-term matrix (DTM). Two versions are built with the same settings: one with raw term counts (`CountVectorizer`) and one with TF-IDF weights (`TfidfVectorizer`)."
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(11314, 9144)\n" ] } ], "source": [ "tf_vectorizer = CountVectorizer(strip_accents='unicode',\n", "                                stop_words='english',\n", "                                lowercase=True,\n", "                                token_pattern=r'\\b[a-zA-Z]{3,}\\b',\n", "                                max_df=0.5,\n", "                                min_df=10)\n", "dtm_tf = tf_vectorizer.fit_transform(docs_raw)\n", "print(dtm_tf.shape)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(11314, 9144)\n" ] } ], "source": [ "tfidf_vectorizer = TfidfVectorizer(**tf_vectorizer.get_params())\n", "dtm_tfidf = tfidf_vectorizer.fit_transform(docs_raw)\n", "print(dtm_tfidf.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fit Latent Dirichlet Allocation models\n", "\n", "Finally, an LDA model with 20 topics is fitted to each of the two DTMs." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "
LatentDirichletAllocation(n_components=20, random_state=0)
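" ], "text/plain": [ "LatentDirichletAllocation(n_components=20, random_state=0)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# LDA on the raw term-count (TF) DTM\n", "lda_tf = LatentDirichletAllocation(n_components=20, random_state=0)\n", "lda_tf.fit(dtm_tf)\n", "# LDA on the TF-IDF DTM\n", "lda_tfidf = LatentDirichletAllocation(n_components=20, random_state=0)\n", "lda_tfidf.fit(dtm_tfidf)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualizing the models with pyLDAvis\n", "\n", "`pyLDAvis.lda_model.prepare` takes the fitted LDA model, the document-term matrix it was fitted on, and the vectorizer that produced that matrix (which supplies the vocabulary). Because `pyLDAvis.enable_notebook()` was called earlier, the returned `PreparedData` is displayed as an interactive visualization." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "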
\n", "" ], "text/plain": [ "PreparedData(topic_coordinates= x y topics cluster Freq\n", "topic \n", "6 0.080525 0.084967 1 1 10.698914\n", "14 0.196207 0.062185 2 1 10.077951\n", "0 -0.102869 -0.170439 3 1 9.104501\n", "18 0.054788 -0.103025 4 1 7.036540\n", "13 0.125707 0.045466 5 1 6.281324\n", "8 -0.128318 -0.186213 6 1 5.411451\n", "19 -0.012272 -0.152626 7 1 5.387479\n", "12 0.094636 0.071910 8 1 5.377476\n", "9 -0.124044 0.010609 9 1 4.547461\n", "7 -0.030445 -0.011852 10 1 4.498512\n", "5 0.142672 0.071664 11 1 4.342297\n", "1 -0.061581 -0.055212 12 1 4.146615\n", "15 0.069750 -0.060217 13 1 4.078443\n", "2 0.110436 -0.003365 14 1 3.673889\n", "11 -0.138638 -0.110874 15 1 3.391038\n", "3 -0.176262 -0.024068 16 1 3.240042\n", "17 0.115595 0.110588 17 1 2.977799\n", "4 0.038802 0.044731 18 1 2.548118\n", "10 0.079317 0.034016 19 1 2.216235\n", "16 -0.334006 0.341754 20 1 0.963914, topic_info= Term Freq Total Category logprob loglift\n", "5016 max 4601.000000 4601.000000 Default 30.0000 30.0000\n", "2653 edu 2438.000000 2438.000000 Default 29.0000 29.0000\n", "3523 god 1945.000000 1945.000000 Default 28.0000 28.0000\n", "4497 key 1211.000000 1211.000000 Default 27.0000 27.0000\n", "7677 space 1250.000000 1250.000000 Default 26.0000 26.0000\n", "... ... ... ... ... ... ...\n", "8805 virtual 15.110336 134.053142 Topic20 -6.2580 2.4591\n", "5140 mil 13.665991 113.450206 Topic20 -6.3585 2.5255\n", "6642 reality 15.259194 189.037514 Topic20 -6.2482 2.1252\n", "7221 scientific 15.126340 248.731356 Topic20 -6.2570 1.8420\n", "2749 end 13.682255 852.822371 Topic20 -6.3573 0.5095\n", "\n", "[1419 rows x 6 columns], token_table= Topic Freq Term\n", "term \n", "20 11 0.929034 absolutes\n", "20 17 0.038710 absolutes\n", "37 12 0.032754 accelerators\n", "37 13 0.032754 accelerators\n", "37 14 0.884362 accelerators\n", "... ... ... 
...\n", "9113 19 0.022005 years\n", "9114 14 0.972649 yeast\n", "9132 9 0.041934 yzerman\n", "9132 18 0.922552 yzerman\n", "9134 6 0.949301 zenith\n", "\n", "[5975 rows x 3 columns], R=30, lambda_step=0.01, plot_opts={'xlab': 'PC1', 'ylab': 'PC2'}, topic_order=[7, 15, 1, 19, 14, 9, 20, 13, 10, 8, 6, 2, 16, 3, 12, 4, 18, 5, 11, 17])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pyLDAvis.lda_model.prepare(lda_tf, dtm_tf, tf_vectorizer)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "
\n", "" ], "text/plain": [ "PreparedData(topic_coordinates= x y topics cluster Freq\n", "topic \n", "15 0.195576 0.043593 1 1 34.043427\n", "6 0.120526 0.138568 2 1 20.066897\n", "12 0.299475 -0.132850 3 1 12.540161\n", "5 0.017137 0.188807 4 1 5.209774\n", "10 0.104041 0.180372 5 1 4.859943\n", "18 0.139450 -0.149233 6 1 3.561402\n", "4 -0.120754 -0.015554 7 1 2.472662\n", "7 -0.052732 -0.030381 8 1 2.044849\n", "11 -0.029780 -0.073130 9 1 1.982630\n", "13 -0.087406 0.005970 10 1 1.504950\n", "9 -0.074859 -0.028866 11 1 1.408547\n", "3 -0.025622 -0.053174 12 1 1.304860\n", "14 -0.064001 -0.011279 13 1 1.287910\n", "0 -0.072994 -0.015142 14 1 1.273234\n", "1 -0.068393 -0.004103 15 1 1.262125\n", "8 -0.064321 -0.018974 16 1 1.201762\n", "17 -0.058346 -0.009243 17 1 1.042810\n", "16 -0.054621 -0.008681 18 1 0.992643\n", "19 -0.053633 -0.008068 19 1 0.970807\n", "2 -0.048743 0.001367 20 1 0.968607, topic_info= Term Freq Total Category logprob loglift\n", "4497 key 52.000000 52.000000 Default 30.0000 30.0000\n", "3523 god 82.000000 82.000000 Default 29.0000 29.0000\n", "2653 edu 90.000000 90.000000 Default 28.0000 28.0000\n", "8252 thanks 112.000000 112.000000 Default 27.0000 27.0000\n", "1322 chip 42.000000 42.000000 Default 26.0000 26.0000\n", "... ... ... ... ... ... ...\n", "8267 theory 1.719096 16.841984 Topic20 -5.7826 2.3550\n", "8906 wave 1.228214 8.026150 Topic20 -6.1188 2.7599\n", "1002 bring 1.138414 14.067291 Topic20 -6.1947 2.1229\n", "912 book 1.206459 37.538219 Topic20 -6.1367 1.1994\n", "913 books 1.079081 20.675158 Topic20 -6.2483 1.6842\n", "\n", "[1011 rows x 6 columns], token_table= Topic Freq Term\n", "term \n", "13 10 0.696130 abraham\n", "24 19 0.444972 absurdity\n", "49 2 0.821916 accident\n", "51 14 0.746203 accidentally\n", "88 2 0.249232 acs\n", "... ... ... 
...\n", "9131 18 0.288959 yup\n", "9140 1 0.140220 zip\n", "9140 6 0.070110 zip\n", "9140 8 0.070110 zip\n", "9140 12 0.701101 zip\n", "\n", "[1994 rows x 3 columns], R=30, lambda_step=0.01, plot_opts={'xlab': 'PC1', 'ylab': 'PC2'}, topic_order=[16, 7, 13, 6, 11, 19, 5, 8, 12, 14, 10, 4, 15, 1, 2, 9, 18, 17, 20, 3])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pyLDAvis.lda_model.prepare(lda_tfidf, dtm_tfidf, tfidf_vectorizer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using different MDS functions\n", "\n", "By default, the topic coordinates are computed with principal coordinate analysis (PCoA). If that layout is not satisfactory and `sklearn` is installed, another method can be selected through the `mds` argument: metric multidimensional scaling (`mds='mmds'`) or t-SNE (`mds='tsne'`)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "
\n", "" ], "text/plain": [ "PreparedData(topic_coordinates= x y topics cluster Freq\n", "topic \n", "6 -0.099949 -0.213273 1 1 10.698914\n", "14 -0.061287 -0.120387 2 1 10.077951\n", "0 -0.204806 0.220991 3 1 9.104501\n", "18 -0.248929 -0.049583 4 1 7.036540\n", "13 0.130947 0.032302 5 1 6.281324\n", "8 -0.286411 0.172089 6 1 5.411451\n", "19 -0.357969 -0.001664 7 1 5.387479\n", "12 -0.200551 -0.291682 8 1 5.377476\n", "9 0.002478 -0.374588 9 1 4.547461\n", "7 0.092395 -0.255058 10 1 4.498512\n", "5 0.108063 -0.102476 11 1 4.342297\n", "1 0.057300 0.198581 12 1 4.146615\n", "15 -0.138635 0.048003 13 1 4.078443\n", "2 -0.024170 0.006722 14 1 3.673889\n", "11 -0.111985 0.340480 15 1 3.391038\n", "3 0.082585 0.355001 16 1 3.240042\n", "17 0.244131 -0.230228 17 1 2.977799\n", "4 0.291942 -0.110922 18 1 2.548118\n", "10 0.272353 0.065793 19 1 2.216235\n", "16 0.452499 0.309896 20 1 0.963914, topic_info= Term Freq Total Category logprob loglift\n", "5016 max 4601.000000 4601.000000 Default 30.0000 30.0000\n", "2653 edu 2438.000000 2438.000000 Default 29.0000 29.0000\n", "3523 god 1945.000000 1945.000000 Default 28.0000 28.0000\n", "4497 key 1211.000000 1211.000000 Default 27.0000 27.0000\n", "7677 space 1250.000000 1250.000000 Default 26.0000 26.0000\n", "... ... ... ... ... ... ...\n", "8805 virtual 15.110336 134.053142 Topic20 -6.2580 2.4591\n", "5140 mil 13.665991 113.450206 Topic20 -6.3585 2.5255\n", "6642 reality 15.259194 189.037514 Topic20 -6.2482 2.1252\n", "7221 scientific 15.126340 248.731356 Topic20 -6.2570 1.8420\n", "2749 end 13.682255 852.822371 Topic20 -6.3573 0.5095\n", "\n", "[1419 rows x 6 columns], token_table= Topic Freq Term\n", "term \n", "20 11 0.929034 absolutes\n", "20 17 0.038710 absolutes\n", "37 12 0.032754 accelerators\n", "37 13 0.032754 accelerators\n", "37 14 0.884362 accelerators\n", "... ... ... 
...\n", "9113 19 0.022005 years\n", "9114 14 0.972649 yeast\n", "9132 9 0.041934 yzerman\n", "9132 18 0.922552 yzerman\n", "9134 6 0.949301 zenith\n", "\n", "[5975 rows x 3 columns], R=30, lambda_step=0.01, plot_opts={'xlab': 'PC1', 'ylab': 'PC2'}, topic_order=[7, 15, 1, 19, 14, 9, 20, 13, 10, 8, 6, 2, 16, 3, 12, 4, 18, 5, 11, 17])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pyLDAvis.lda_model.prepare(lda_tf, dtm_tf, tf_vectorizer, mds='mmds')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "
\n", "" ], "text/plain": [ "PreparedData(topic_coordinates= x y topics cluster Freq\n", "topic \n", "6 -14.853498 -45.800789 1 1 10.698914\n", "14 -96.739922 30.589661 2 1 10.077951\n", "0 15.219661 60.205353 3 1 9.104501\n", "18 -8.863257 22.755390 4 1 7.036540\n", "13 -37.461292 -11.603455 5 1 6.281324\n", "8 -7.041614 -89.441154 6 1 5.411451\n", "19 86.750648 11.611804 7 1 5.387479\n", "12 38.504055 97.494286 8 1 5.377476\n", "9 -52.937565 25.837904 9 1 4.547461\n", "7 8.071974 -12.579908 10 1 4.498512\n", "5 -27.465855 60.084621 11 1 4.342297\n", "1 35.104191 24.467779 12 1 4.146615\n", "15 -68.970993 76.778412 13 1 4.078443\n", "2 78.914879 -46.868855 14 1 3.673889\n", "11 68.260010 56.726215 15 1 3.391038\n", "3 -58.801720 -60.212524 16 1 3.240042\n", "17 30.836576 -58.077229 17 1 2.977799\n", "4 -83.038345 -17.733229 18 1 2.548118\n", "10 -15.602712 103.489136 19 1 2.216235\n", "16 48.532097 -14.421998 20 1 0.963914, topic_info= Term Freq Total Category logprob loglift\n", "5016 max 4601.000000 4601.000000 Default 30.0000 30.0000\n", "2653 edu 2438.000000 2438.000000 Default 29.0000 29.0000\n", "3523 god 1945.000000 1945.000000 Default 28.0000 28.0000\n", "4497 key 1211.000000 1211.000000 Default 27.0000 27.0000\n", "7677 space 1250.000000 1250.000000 Default 26.0000 26.0000\n", "... ... ... ... ... ... ...\n", "8805 virtual 15.110336 134.053142 Topic20 -6.2580 2.4591\n", "5140 mil 13.665991 113.450206 Topic20 -6.3585 2.5255\n", "6642 reality 15.259194 189.037514 Topic20 -6.2482 2.1252\n", "7221 scientific 15.126340 248.731356 Topic20 -6.2570 1.8420\n", "2749 end 13.682255 852.822371 Topic20 -6.3573 0.5095\n", "\n", "[1419 rows x 6 columns], token_table= Topic Freq Term\n", "term \n", "20 11 0.929034 absolutes\n", "20 17 0.038710 absolutes\n", "37 12 0.032754 accelerators\n", "37 13 0.032754 accelerators\n", "37 14 0.884362 accelerators\n", "... ... ... 
...\n", "9113 19 0.022005 years\n", "9114 14 0.972649 yeast\n", "9132 9 0.041934 yzerman\n", "9132 18 0.922552 yzerman\n", "9134 6 0.949301 zenith\n", "\n", "[5975 rows x 3 columns], R=30, lambda_step=0.01, plot_opts={'xlab': 'PC1', 'ylab': 'PC2'}, topic_order=[7, 15, 1, 19, 14, 9, 20, 13, 10, 8, 6, 2, 16, 3, 12, 4, 18, 5, 11, 17])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pyLDAvis.lda_model.prepare(lda_tf, dtm_tf, tf_vectorizer, mds='tsne')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.2" } }, "nbformat": 4, "nbformat_minor": 4 }