{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "We'll take a dataset of documents from several different categories and find topics (groups of words) for them. Knowing the actual categories helps us judge whether the topics we find make sense.\n", "\n", "We will try this with two different matrix factorizations: **Singular Value Decomposition (SVD)** and **Non-negative Matrix Factorization (NMF)**." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.datasets import fetch_20newsgroups\n", "from sklearn import decomposition\n", "from scipy import linalg\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": true }, "outputs": [], "source": [ "%matplotlib inline\n", "np.set_printoptions(suppress=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Scikit-learn comes with a number of built-in datasets, as well as utilities to download several standard external datasets. This is a great resource: the datasets include face images, forest cover types, diabetes, breast cancer, and more. We will use the 20 Newsgroups dataset.\n", "\n", "Newsgroups are discussion groups on Usenet, which was popular in the 1980s and 1990s before the web really took off. This dataset includes about 18,000 newsgroup posts on 20 topics."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']\n", "remove = ('headers', 'footers', 'quotes')\n", "newsgroups_train = fetch_20newsgroups(subset='train', categories=categories, remove=remove)\n", "newsgroups_test = fetch_20newsgroups(subset='test', categories=categories, remove=remove)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['/Users/princegrover/scikit_learn_data/20news_home/20news-bydate-train/comp.graphics/38816',\n", " '/Users/princegrover/scikit_learn_data/20news_home/20news-bydate-train/talk.religion.misc/83741',\n", " '/Users/princegrover/scikit_learn_data/20news_home/20news-bydate-train/sci.space/61092'],\n", " dtype='In article <1993Apr19.020359.26996@sq.sq.com>, msb@sq.sq.com (Mark Brader) \n", "\n", "MB> So the\n", "MB> 1970 figure seems unlikely to actually be anything but a perijove.\n", "\n", "JG>Sorry, _perijoves_...I'm not used to talking this language.\n", "\n", "Couldn't we just say periapsis or apoapsis?\n", "\n", " \n" ] } ], "source": [ "# a post from sci.space\n", "# a *perijove* is the point in the orbit of a satellite of Jupiter nearest the planet's center\n", "\n", "print(newsgroups_train.data[2])" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['alt.atheism', 'comp.graphics', 'sci.space', 'talk.religion.misc']" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "newsgroups_train.target_names" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['comp.graphics', 'talk.religion.misc', 'sci.space'], dtype='\n", "\n", "(source: [Facebook Research: Fast Randomized SVD](https://research.fb.com/fast-randomized-svd/))" ] }, { "cell_type": "code", 
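"execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A minimal sketch, not part of the original analysis: before running the full SVD\n", "# below, check on a small random matrix what linalg.svd(..., full_matrices=False)\n", "# returns. A_toy is a hypothetical stand-in for the (documents x words) matrix `vectors`.\n", "import numpy as np\n", "from scipy import linalg\n", "\n", "rng = np.random.default_rng(0)\n", "A_toy = rng.random((6, 4))\n", "U_toy, s_toy, Vh_toy = linalg.svd(A_toy, full_matrices=False)\n", "\n", "# reduced SVD: k = min(6, 4) = 4, so U is 6x4, s has 4 entries, Vh is 4x4\n", "print(U_toy.shape, s_toy.shape, Vh_toy.shape)  # (6, 4) (4,) (4, 4)\n", "# singular values are non-negative and sorted in decreasing order\n", "print(np.all(s_toy >= 0) and np.all(np.diff(s_toy) <= 0))\n", "# the factors reconstruct the matrix, and U's columns / Vh's rows are orthonormal\n", "print(np.allclose(U_toy @ np.diag(s_toy) @ Vh_toy, A_toy))\n", "print(np.allclose(U_toy.T @ U_toy, np.eye(4)), np.allclose(Vh_toy @ Vh_toy.T, np.eye(4)))\n", "\n", "# keeping only the top r singular triplets gives the best rank-r approximation;\n", "# its Frobenius error equals the square root of the sum of the dropped s**2\n", "r = 2\n", "A_r = U_toy[:, :r] @ np.diag(s_toy[:r]) @ Vh_toy[:r]\n", "print(np.isclose(np.linalg.norm(A_toy - A_r), np.sqrt(np.sum(s_toy[r:] ** 2))))" ] }, { "cell_type": "code", 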
"execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1min 8s, sys: 2.99 s, total: 1min 11s\n", "Wall time: 47.3 s\n" ] } ], "source": [ "%time U, s, Vh = linalg.svd(vectors, full_matrices=False)" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(2034, 2034) (2034,) (2034, 26576)\n" ] } ], "source": [ "print(U.shape, s.shape, Vh.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`s` holds the singular values, in decreasing order. Because we passed `full_matrices=False`, this is the reduced SVD: `U` has shape (documents, k), `s` has length k, and `Vh` has shape (k, words), with k = min(2034, 26576) = 2034." ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(s);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Topics" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "num_top_words = 8\n", "\n", "def show_topics(a):\n", "    # for each row t of a (one topic), take the indices of its num_top_words\n", "    # largest entries, largest first, and look up the matching words in vocab\n", "    top_words = lambda t: [vocab[i] for i in np.argsort(t)[:-num_top_words-1:-1]]\n", "    topic_words = [top_words(t) for t in a]\n", "    return [' '.join(t) for t in topic_words]" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "['critus ditto propagandist surname galacticentric kindergarten surreal imaginative',\n", " 'jpeg gif file color quality image jfif format',\n", " 'graphics edu pub mail 128 3d ray ftp',\n", " 'jesus god matthew people atheists atheism does graphics',\n", " 'image data processing analysis software available tools display',\n", " 'god atheists atheism religious believe religion argument true',\n", " 'space nasa lunar mars probe moon missions probes',\n", " 'image probe surface lunar mars probes moon orbit',\n", " 'argument fallacy conclusion example true ad argumentum premises',\n", " 'space larson image theory universe physical nasa material']" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_topics(Vh[:10])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that the topics pick up real themes. Numbering them from 1 in the order printed above, for example:\n", "\n", "- topic 2 is about image file formats and quality,\n", "- topic 3 is about computer graphics,\n", "- topic 10 is about physics and space.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Non-negative Matrix Factorization (NMF)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*[Figure: NMF]*\n", "(source: [NMF Tutorial](http://perso.telecom-paristech.fr/~essid/teach/NMF_tutorial_ICME-2014.pdf))" ] }, { "cell_type": "markdown", "metadata": { "code_folding": [] }, "source": [ "#### NMF in scikit-learn" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "m, n = vectors.shape\n", "d = 5  # number of topics (a hyperparameter we choose)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "clf = decomposition.NMF(n_components=d, random_state=1)\n", "\n", "W1 = clf.fit_transform(vectors)  # document-topic weights, shape (m, d)\n", "H1 = clf.components_  # topic-word weights, shape (d, n)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "code_folding": [] }, "outputs": [], "source": [ "def display_topics(model, n_top):\n", "    # print the n_top highest-weighted words of each topic (row of model.components_)\n", "    for topic_idx, topic in enumerate(model.components_):\n", "        print('Topic %d:' % topic_idx)\n", "        print(' '.join(vocab[i] for i in topic.argsort()[:-n_top - 1:-1]), '\\n')" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# no_top_words = 10\n", "# display_topics(clf, no_top_words)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "array(['00', '000', '0000', ..., 'zware', 'zwarte', 'zyxel'], dtype='