{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Independent Distribution\n", "\n", "> In this post, we will look at the independent distribution, which bridges univariate and multivariate distributions. This is a summary of the lecture \"Probabilistic Deep Learning with Tensorflow 2\" from Imperial College London.\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Coursera, Tensorflow_probability, ICL]\n", "- image: images/independent_from_bivariate.png" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow_probability as tfp\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "tfd = tfp.distributions\n", "plt.rcParams['figure.figsize'] = (10, 6)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensorflow Version: 2.5.0\n", "Tensorflow Probability Version: 0.13.0\n" ] } ], "source": [ "print(\"Tensorflow Version: \", tf.__version__)\n", "print(\"Tensorflow Probability Version: \", tfp.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Independent Distribution" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Strictly speaking, `Independent` is not the name of a particular probability distribution. In TensorFlow Probability, `Independent` is a wrapper that converts a batch of univariate distributions into a single multivariate distribution. As we saw in the previous notebook, a batch of univariate distributions can have the same overall shape as a multivariate distribution; the difference is whether those dimensions are batch dimensions (separate, independently evaluated distributions) or event dimensions (components of one random vector). By specifying how many batch dimensions should be reinterpreted as event dimensions, `Independent` performs this conversion, and the resulting joint density is the product of the component densities." 
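] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check (a sketch added here, not part of the original lecture), `Independent` with `reinterpreted_batch_ndims=1` should assign a vector the sum of the per-component log-densities, because the log of a product of densities is the sum of their logs. The `loc`, `scale`, and `x` values below are arbitrary illustration values." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n",
"import tensorflow_probability as tfp\n",
"\n",
"tfd = tfp.distributions\n",
"\n",
"# A batch of two univariate normals: batch_shape=[2], event_shape=[]\n",
"batch = tfd.Normal(loc=[-1., 1.], scale=[0.5, 1.])\n",
"# Move the single batch dimension into the event: batch_shape=[], event_shape=[2]\n",
"joint = tfd.Independent(batch, reinterpreted_batch_ndims=1)\n",
"\n",
"x = np.array([0.3, -0.2], dtype=np.float32)\n",
"# Under independence, log p(x1, x2) = log p(x1) + log p(x2)\n",
"joint_log_prob = joint.log_prob(x).numpy()\n",
"summed_log_prob = batch.log_prob(x).numpy().sum()\n",
"\n",
"print(batch.batch_shape, batch.event_shape)\n",
"print(joint.batch_shape, joint.event_shape)\n",
"print(np.isclose(joint_log_prob, summed_log_prob))" 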
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Start by defining a batch of two univariate Gaussians, then\n", "# combine them into a bivariate Gaussian with independent components\n", "\n", "locs = [-1., 1.]\n", "scales = [0.5, 1.]\n", "\n", "batched_normal = tfd.Normal(loc=locs, scale=scales)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Univariate distribution\n", "\n", "t = np.linspace(-4, 4, 10000)\n", "# each column is a vector of densities for one distribution\n", "densities = batched_normal.prob(np.repeat(t[:, np.newaxis], 2, axis=1))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "sns.lineplot(x=t, y=densities[:, 0], label='loc={}, scale={}'.format(locs[0], scales[0]))\n", "sns.lineplot(x=t, y=densities[:, 1], label='loc={}, scale={}'.format(locs[1], scales[1]))\n", "plt.xlabel('Value')\n", "plt.ylabel('Probability Density')\n", "plt.legend(loc='best')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check their batch_shape and event_shape\n", "\n", "batched_normal" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This distribution has `batch_shape=[2]` and an empty `event_shape`: it is a batch of two separate univariate normal distributions. How can we convert it into a single bivariate distribution over vectors of the same shape?" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Use Independent to convert batch shape to the event shape\n", "\n", "bivariate_normal_from_Independent = tfd.Independent(batched_normal, reinterpreted_batch_ndims=1)\n", "bivariate_normal_from_Independent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here, the batch dimension of the original distribution is reinterpreted as an event dimension of the new distribution. As a result, the new distribution has `event_shape=[2]` and an empty `batch_shape`, instead of `batch_shape=[2]`. To visualize it, we can draw samples and display them with a joint plot." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "samples = bivariate_normal_from_Independent.sample(10000)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind='kde', space=0, color='b', xlim=[-4, 4], ylim=[-4, 4], fill=True)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Is this the same as a bivariate normal distribution defined directly as a multivariate distribution? Let's check." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Use MultivariateNormalDiag to create the equivalent distribution\n", "# Note that diagonal covariance matrix => no correlation => independence (for the multivariate normal distribution)\n", "\n", "bivariate_normal_from_Multivariate = tfd.MultivariateNormalDiag(loc=locs, scale_diag=scales)\n", "bivariate_normal_from_Multivariate" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "samples = bivariate_normal_from_Multivariate.sample(10000)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind='kde', color='r', xlim=[-4, 4], ylim=[-4, 4], fill=True)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Shifting batch dimensions to event dimensions using `reinterpreted_batch_ndims`\n", "\n", "The `reinterpreted_batch_ndims` argument controls how many of the rightmost batch dimensions are moved into the event shape, so the resulting distribution depends on its value." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Demonstrate usage of reinterpreted_batch_ndims\n", "# By default, all batch dims except the first are transferred to event dims\n", "\n", "loc_grid = [[-100., -100.],\n", " [100., 100.],\n", " [0., 0.]]\n", "\n", "scale_grid = [[1., 10.],\n", " [1., 10.],\n", " [1., 1.]]\n", "\n", "normals_batch_3by2_event_1 = tfd.Normal(loc=loc_grid, scale=scale_grid)\n", "normals_batch_3by2_event_1" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(3, 2)" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.array(loc_grid).shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Wrapping this 3x2 batch in `Independent` with the default `reinterpreted_batch_ndims` moves the last batch dimension into the event, giving a batch of 3 bivariate normal distributions, each parameterized by a row of the original parameter grid." 
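] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see what moving one batch dimension into the event means for `log_prob`, here is an illustrative check (added here, not from the lecture). It sets `reinterpreted_batch_ndims=1` explicitly, which for this rank-2 batch matches the default behavior noted in the code comment above. The `Independent` log-probability should equal the base log-probabilities summed over the last axis only, giving one value per row of the grid. The evaluation point `x` is an arbitrary choice." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n",
"import tensorflow_probability as tfp\n",
"\n",
"tfd = tfp.distributions\n",
"\n",
"loc_grid = [[-100., -100.], [100., 100.], [0., 0.]]\n",
"scale_grid = [[1., 10.], [1., 10.], [1., 1.]]\n",
"base = tfd.Normal(loc=loc_grid, scale=scale_grid)  # batch_shape=[3, 2]\n",
"\n",
"# Move the rightmost batch dimension into the event: batch_shape=[3], event_shape=[2]\n",
"rows = tfd.Independent(base, reinterpreted_batch_ndims=1)\n",
"\n",
"x = np.zeros((3, 2), dtype=np.float32)\n",
"per_row = rows.log_prob(x).numpy()  # shape (3,): one log-density per row\n",
"manual = base.log_prob(x).numpy().sum(axis=-1)  # sum over the event axis by hand\n",
"\n",
"print(rows.batch_shape, rows.event_shape)\n",
"print(per_row.shape, np.allclose(per_row, manual))" 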
] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "normals_batch_3_event_2 = tfd.Independent(normals_batch_3by2_event_1)\n", "normals_batch_3_event_2" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluate log prob\n", "normals_batch_3_event_2.log_prob(loc_grid)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also reinterpret all batch dimensions as event dimensions by setting `reinterpreted_batch_ndims=2`. The result is a single distribution over 3x2 arrays, so `log_prob` returns a single scalar." ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "normals_batch_1_event_3by2 = tfd.Independent(normals_batch_3by2_event_1, reinterpreted_batch_ndims=2)\n", "normals_batch_1_event_3by2" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Take log_probs\n", "\n", "normals_batch_1_event_3by2.log_prob(loc_grid)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using `Independent` to build a Naive Bayes classifier\n", "\n", "### Introduction to the `newsgroup` dataset\n", "\n", "In this tutorial, we load the dataset, fetch the train/test splits, and possibly choose a subset of the data.\n", "\n", "We then construct the class-conditional feature distribution (with `Independent`, using the Naive Bayes assumption) and sample from it. Under the Naive Bayes assumption, the features are conditionally independent given the class, so the class-conditional distribution factorizes into a product of per-feature distributions, which is exactly the structure that `Independent` expresses.\n", "\n", "For now, we simply use the maximum likelihood (ML) estimates of the parameters; in later tutorials, we will learn them." 
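] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before turning to the real data, here is a toy sketch of the same idea on hypothetical data (not from the lecture): estimate the per-class probability that each word is present, wrap the resulting batch of Bernoulli distributions in `Independent`, and classify a new document by the largest value of log p(x | y) + log p(y). The add-one (Laplace) smoothing below is an assumption added here to avoid zero probabilities; it is not necessarily what the lecture uses, and all `toy_` names are illustrative." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n",
"import tensorflow_probability as tfp\n",
"\n",
"tfd = tfp.distributions\n",
"\n",
"# Hypothetical binary bag-of-words data: 6 documents, 4 vocabulary words, 2 classes\n",
"toy_X = np.array([[1, 1, 0, 0],\n",
"                  [1, 0, 0, 0],\n",
"                  [1, 1, 1, 0],\n",
"                  [0, 0, 1, 1],\n",
"                  [0, 1, 1, 1],\n",
"                  [0, 0, 0, 1]], dtype=np.float32)\n",
"toy_y = np.array([0, 0, 0, 1, 1, 1])\n",
"n_toy_classes = 2\n",
"\n",
"# Smoothed ML estimates of P(word present | class), shape (n_classes, n_words)\n",
"toy_probs = np.stack([(toy_X[toy_y == c].sum(axis=0) + 1.) / ((toy_y == c).sum() + 2.)\n",
"                      for c in range(n_toy_classes)]).astype(np.float32)\n",
"# Class priors P(y) estimated from the class frequencies\n",
"toy_priors = np.bincount(toy_y, minlength=n_toy_classes) / len(toy_y)\n",
"\n",
"# Bernoulli batch_shape=[2, 4] -> Independent -> batch_shape=[2], event_shape=[4]\n",
"toy_p_x_given_y = tfd.Independent(tfd.Bernoulli(probs=toy_probs), reinterpreted_batch_ndims=1)\n",
"\n",
"# A new document containing words 0 and 1 (Bernoulli samples are int32 by default)\n",
"new_doc = np.array([1, 1, 0, 0], dtype=np.int32)\n",
"# The event broadcasts against the class batch, giving one log-likelihood per class\n",
"toy_log_joint = toy_p_x_given_y.log_prob(new_doc).numpy() + np.log(toy_priors)\n",
"toy_prediction = int(np.argmax(toy_log_joint))\n",
"print(toy_log_joint, toy_prediction)" 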
] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "# Convenience function for retrieving the 20 newsgroups data set\n", "\n", "# Usenet was a forerunner to modern internet forums\n", "# Users could post and read articles\n", "# Newsgroup corresponded to a topic\n", "# Example topics in this data set: IBM computer hardware, baseball\n", "# Our objective is to use an article's contents to predict its newsgroup,\n", "# a 20-class classification problem.\n", "\n", "# 18000 newsgroups, posts on 20 topics\n", "from sklearn.datasets import fetch_20newsgroups\n", "from sklearn.feature_extraction.text import CountVectorizer" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "# Get the train data\n", "newsgroup_data = fetch_20newsgroups(data_home='./dataset/20_Newsgroup_Data/', subset='train')" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ".. _20newsgroups_dataset:\n", "\n", "The 20 newsgroups text dataset\n", "------------------------------\n", "\n", "The 20 newsgroups dataset comprises around 18000 newsgroups posts on\n", "20 topics split in two subsets: one for training (or development)\n", "and the other one for testing (or for performance evaluation). The split\n", "between the train and test set is based upon a messages posted before\n", "and after a specific date.\n", "\n", "This module contains two loaders. 
The first one,\n", ":func:`sklearn.datasets.fetch_20newsgroups`,\n", "returns a list of the raw texts that can be fed to text feature\n", "extractors such as :class:`~sklearn.feature_extraction.text.CountVectorizer`\n", "with custom parameters so as to extract feature vectors.\n", "The second one, :func:`sklearn.datasets.fetch_20newsgroups_vectorized`,\n", "returns ready-to-use features, i.e., it is not necessary to use a feature\n", "extractor.\n", "\n", "**Data Set Characteristics:**\n", "\n", " ================= ==========\n", " Classes 20\n", " Samples total 18846\n", " Dimensionality 1\n", " Features text\n", " ================= ==========\n", "\n", "Usage\n", "~~~~~\n", "\n", "The :func:`sklearn.datasets.fetch_20newsgroups` function is a data\n", "fetching / caching functions that downloads the data archive from\n", "the original `20 newsgroups website`_, extracts the archive contents\n", "in the ``~/scikit_learn_data/20news_home`` folder and calls the\n", ":func:`sklearn.datasets.load_files` on either the training or\n", "testing set folder, or both of them::\n", "\n", " >>> from sklearn.datasets import fetch_20newsgroups\n", " >>> newsgroups_train = fetch_20newsgroups(subset='train')\n", "\n", " >>> from pprint import pprint\n", " >>> pprint(list(newsgroups_train.target_names))\n", " ['alt.atheism',\n", " 'comp.graphics',\n", " 'comp.os.ms-windows.misc',\n", " 'comp.sys.ibm.pc.hardware',\n", " 'comp.sys.mac.hardware',\n", " 'comp.windows.x',\n", " 'misc.forsale',\n", " 'rec.autos',\n", " 'rec.motorcycles',\n", " 'rec.sport.baseball',\n", " 'rec.sport.hockey',\n", " 'sci.crypt',\n", " 'sci.electronics',\n", " 'sci.med',\n", " 'sci.space',\n", " 'soc.religion.christian',\n", " 'talk.politics.guns',\n", " 'talk.politics.mideast',\n", " 'talk.politics.misc',\n", " 'talk.religion.misc']\n", "\n", "The real data lies in the ``filenames`` and ``target`` attributes. 
The target\n", "attribute is the integer index of the category::\n", "\n", " >>> newsgroups_train.filenames.shape\n", " (11314,)\n", " >>> newsgroups_train.target.shape\n", " (11314,)\n", " >>> newsgroups_train.target[:10]\n", " array([ 7, 4, 4, 1, 14, 16, 13, 3, 2, 4])\n", "\n", "It is possible to load only a sub-selection of the categories by passing the\n", "list of the categories to load to the\n", ":func:`sklearn.datasets.fetch_20newsgroups` function::\n", "\n", " >>> cats = ['alt.atheism', 'sci.space']\n", " >>> newsgroups_train = fetch_20newsgroups(subset='train', categories=cats)\n", "\n", " >>> list(newsgroups_train.target_names)\n", " ['alt.atheism', 'sci.space']\n", " >>> newsgroups_train.filenames.shape\n", " (1073,)\n", " >>> newsgroups_train.target.shape\n", " (1073,)\n", " >>> newsgroups_train.target[:10]\n", " array([0, 1, 1, 1, 0, 1, 1, 0, 0, 0])\n", "\n", "Converting text to vectors\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "In order to feed predictive or clustering models with the text data,\n", "one first need to turn the text into vectors of numerical values suitable\n", "for statistical analysis. This can be achieved with the utilities of the\n", "``sklearn.feature_extraction.text`` as demonstrated in the following\n", "example that extract `TF-IDF`_ vectors of unigram tokens\n", "from a subset of 20news::\n", "\n", " >>> from sklearn.feature_extraction.text import TfidfVectorizer\n", " >>> categories = ['alt.atheism', 'talk.religion.misc',\n", " ... 'comp.graphics', 'sci.space']\n", " >>> newsgroups_train = fetch_20newsgroups(subset='train',\n", " ... 
categories=categories)\n", " >>> vectorizer = TfidfVectorizer()\n", " >>> vectors = vectorizer.fit_transform(newsgroups_train.data)\n", " >>> vectors.shape\n", " (2034, 34118)\n", "\n", "The extracted TF-IDF vectors are very sparse, with an average of 159 non-zero\n", "components by sample in a more than 30000-dimensional space\n", "(less than .5% non-zero features)::\n", "\n", " >>> vectors.nnz / float(vectors.shape[0])\n", " 159.01327...\n", "\n", ":func:`sklearn.datasets.fetch_20newsgroups_vectorized` is a function which \n", "returns ready-to-use token counts features instead of file names.\n", "\n", ".. _`20 newsgroups website`: http://people.csail.mit.edu/jrennie/20Newsgroups/\n", ".. _`TF-IDF`: https://en.wikipedia.org/wiki/Tf-idf\n", "\n", "\n", "Filtering text for more realistic training\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "It is easy for a classifier to overfit on particular things that appear in the\n", "20 Newsgroups data, such as newsgroup headers. Many classifiers achieve very\n", "high F-scores, but their results would not generalize to other documents that\n", "aren't from this window of time.\n", "\n", "For example, let's look at the results of a multinomial Naive Bayes classifier,\n", "which is fast to train and achieves a decent F-score::\n", "\n", " >>> from sklearn.naive_bayes import MultinomialNB\n", " >>> from sklearn import metrics\n", " >>> newsgroups_test = fetch_20newsgroups(subset='test',\n", " ... 
categories=categories)\n", " >>> vectors_test = vectorizer.transform(newsgroups_test.data)\n", " >>> clf = MultinomialNB(alpha=.01)\n", " >>> clf.fit(vectors, newsgroups_train.target)\n", " MultinomialNB(alpha=0.01, class_prior=None, fit_prior=True)\n", "\n", " >>> pred = clf.predict(vectors_test)\n", " >>> metrics.f1_score(newsgroups_test.target, pred, average='macro')\n", " 0.88213...\n", "\n", "(The example :ref:`sphx_glr_auto_examples_text_plot_document_classification_20newsgroups.py` shuffles\n", "the training and test data, instead of segmenting by time, and in that case\n", "multinomial Naive Bayes gets a much higher F-score of 0.88. Are you suspicious\n", "yet of what's going on inside this classifier?)\n", "\n", "Let's take a look at what the most informative features are:\n", "\n", " >>> import numpy as np\n", " >>> def show_top10(classifier, vectorizer, categories):\n", " ... feature_names = np.asarray(vectorizer.get_feature_names())\n", " ... for i, category in enumerate(categories):\n", " ... top10 = np.argsort(classifier.coef_[i])[-10:]\n", " ... 
print(\"%s: %s\" % (category, \" \".join(feature_names[top10])))\n", " ...\n", " >>> show_top10(clf, vectorizer, newsgroups_train.target_names)\n", " alt.atheism: edu it and in you that is of to the\n", " comp.graphics: edu in graphics it is for and of to the\n", " sci.space: edu it that is in and space to of the\n", " talk.religion.misc: not it you in is that and to of the\n", "\n", "\n", "You can now see many things that these features have overfit to:\n", "\n", "- Almost every group is distinguished by whether headers such as\n", " ``NNTP-Posting-Host:`` and ``Distribution:`` appear more or less often.\n", "- Another significant feature involves whether the sender is affiliated with\n", " a university, as indicated either by their headers or their signature.\n", "- The word \"article\" is a significant feature, based on how often people quote\n", " previous posts like this: \"In article [article ID], [name] <[e-mail address]>\n", " wrote:\"\n", "- Other features match the names and e-mail addresses of particular people who\n", " were posting at the time.\n", "\n", "With such an abundance of clues that distinguish newsgroups, the classifiers\n", "barely have to identify topics from text at all, and they all perform at the\n", "same high level.\n", "\n", "For this reason, the functions that load 20 Newsgroups data provide a\n", "parameter called **remove**, telling it what kinds of information to strip out\n", "of each file. **remove** should be a tuple containing any subset of\n", "``('headers', 'footers', 'quotes')``, telling it to remove headers, signature\n", "blocks, and quotation blocks respectively.\n", "\n", " >>> newsgroups_test = fetch_20newsgroups(subset='test',\n", " ... remove=('headers', 'footers', 'quotes'),\n", " ... 
categories=categories)\n", " >>> vectors_test = vectorizer.transform(newsgroups_test.data)\n", " >>> pred = clf.predict(vectors_test)\n", " >>> metrics.f1_score(pred, newsgroups_test.target, average='macro')\n", " 0.77310...\n", "\n", "This classifier lost over a lot of its F-score, just because we removed\n", "metadata that has little to do with topic classification.\n", "It loses even more if we also strip this metadata from the training data:\n", "\n", " >>> newsgroups_train = fetch_20newsgroups(subset='train',\n", " ... remove=('headers', 'footers', 'quotes'),\n", " ... categories=categories)\n", " >>> vectors = vectorizer.fit_transform(newsgroups_train.data)\n", " >>> clf = MultinomialNB(alpha=.01)\n", " >>> clf.fit(vectors, newsgroups_train.target)\n", " MultinomialNB(alpha=0.01, class_prior=None, fit_prior=True)\n", "\n", " >>> vectors_test = vectorizer.transform(newsgroups_test.data)\n", " >>> pred = clf.predict(vectors_test)\n", " >>> metrics.f1_score(newsgroups_test.target, pred, average='macro')\n", " 0.76995...\n", "\n", "Some other classifiers cope better with this harder version of the task. Try\n", "running :ref:`sphx_glr_auto_examples_model_selection_grid_search_text_feature_extraction.py` with and without\n", "the ``--filter`` option to compare the results.\n", "\n", ".. topic:: Recommendation\n", "\n", " When evaluating text classifiers on the 20 Newsgroups data, you\n", " should strip newsgroup-related metadata. In scikit-learn, you can do this by\n", " setting ``remove=('headers', 'footers', 'quotes')``. The F-score will be\n", " lower because it is more realistic.\n", "\n", ".. 
topic:: Examples\n", "\n", " * :ref:`sphx_glr_auto_examples_model_selection_grid_search_text_feature_extraction.py`\n", "\n", " * :ref:`sphx_glr_auto_examples_text_plot_document_classification_20newsgroups.py`\n", "\n" ] } ], "source": [ "print(newsgroup_data['DESCR'])" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "From: lerxst@wam.umd.edu (where's my thing)\n", "Subject: WHAT car is this!?\n", "Nntp-Posting-Host: rac3.wam.umd.edu\n", "Organization: University of Maryland, College Park\n", "Lines: 15\n", "\n", " I was wondering if anyone out there could enlighten me on this car I saw\n", "the other day. It was a 2-door sports car, looked to be from the late 60s/\n", "early 70s. It was called a Bricklin. The doors were really small. In addition,\n", "the front bumper was separate from the rest of the body. This is \n", "all I know. If anyone can tellme a model name, engine specs, years\n", "of production, where this car is made, history, or whatever info you\n", "have on this funky looking car, please e-mail.\n", "\n", "Thanks,\n", "- IL\n", " ---- brought to you by your neighborhood Lerxst ----\n", "\n", "\n", "\n", "\n", "\n" ] } ], "source": [ "# Example article\n", "\n", "print(newsgroup_data['data'][0])" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "7" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Associated label\n", "label = newsgroup_data['target'][0]\n", "label" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'rec.autos'" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Name of the label\n", "newsgroup_data['target_names'][label]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This means the article belongs to the `rec.autos` newsgroup, i.e. it is about cars." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before building the classifier, we preprocess the raw text into binary bag-of-words features: each document becomes a vector indicating which vocabulary words appear in it (`binary=True`). `max_df=0.25` drops words that appear in more than 25% of the documents, and `min_df=1.01 / n_documents` drops words that appear in only one document." ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "n_documents = len(newsgroup_data['data'])\n", "\n", "cv = CountVectorizer(input='content', binary=True, max_df=0.25, min_df=1.01 / n_documents)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "binary_bag_of_words = cv.fit_transform(newsgroup_data['data'])" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(11314, 56365)" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check shape\n", "binary_bag_of_words.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can check the result with `CountVectorizer.inverse_transform`, which maps each binary vector back to the words it contains." ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[array(['lerxst', 'wam', 'umd', 'where', 'thing', 'car', 'rac3',\n", " 'maryland', 'college', 'park', '15', 'wondering', 'anyone',\n", " 'could', 'enlighten', 'saw', 'day', 'door', 'sports', 'looked',\n", " 'late', '60s', 'early', '70s', 'called', 'bricklin', 'doors',\n", " 'were', 'really', 'small', 'addition', 'front', 'bumper',\n", " 'separate', 'rest', 'body', 'tellme', 'model', 'name', 'engine',\n", " 'specs', 'years', 'production', 'made', 'history', 'whatever',\n", " 'info', 'funky', 'looking', 'please', 'mail', 'thanks', 'il',\n", " 'brought', 'neighborhood'], dtype='" ] }, "execution_count": 63, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "# Take a sample of words from each class\n", "samples = p_x_given_y.sample(10)\n", "samples" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'rec.sport.hockey'" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Choose a specific class to test\n", "chosen_class = 10\n", "newsgroup_data['target_names'][chosen_class]" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Indicators for words that appear in the sample\n", "class_sample = samples[:, chosen_class, :]\n", "class_sample" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['1076', '108', '12', '18', '184', '1984', '259', '32', '48', '514',\n", " '54', '97', '_are_', 'admirals', 'advance', 'affinity', 'after',\n", " 'against', 'alberta', 'attitude', 'babych', 'basis', 'before',\n", " 'best', 'blues', 'both', 'cci632', 'ccohen', 'central', 'champs',\n", " 'closed', 'computer', 'consistently', 'couple', 'designated',\n", " 'development', 'devils', 'distribution', 'div', 'does', 'doherty',\n", " 'droopy', 'during', 'effort', 'engineering', 'entity', 'every',\n", " 'everyone', 'expensive', 'final', 'finland', 'first', 'foster',\n", " 'franchise', 'gballent', 'georgia', 'gilhen', 'goals',\n", " 'goaltenders', 'god', 'going', 'good', 'guy', 'had', 'haha',\n", " 'happened', 'he', 'head', 'home', 'however', 'kick', 'looks',\n", " 'maine', 'models', 'mom', 'need', 'nne', 'ny', 'off', 'ot',\n", " 'penguins', 'pittsburgh', 'play', 'played', 'rangers', 'really',\n", " 'record', 'rex', 'right', 'san', 'stadium', 'still', 'streak',\n", " 'style', 'talking', 'terry_yake', 'then', 'though', 'tied', 'top',\n", " 'two', 'uci', 'us', 'vancouver', 'waiting', 'wang', 'washington',\n", " 'wirtz', 
'zzzzzz'], dtype='