{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 08 - Naive Bayes\n", "\n", "by [Alejandro Correa Bahnsen](albahnsen.com/)\n", "\n", "version 0.1, Mar 2016\n", "\n", "## Part of the class [Practical Machine Learning](https://github.com/albahnsen/PracticalMachineLearningClass)\n", "\n", "\n", "\n", "This notebook is licensed under a [Creative Commons Attribution-ShareAlike 3.0 Unported License](http://creativecommons.org/licenses/by-sa/3.0/deed.en_US). Special thanks goes to [Kevin Markham](https://github.com/justmarkham), [Sebastian Raschka](http://sebastianraschka.com/) & [Scikit-learn docs](http://scikit-learn.org/)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Naive Bayes\n", "\n", "Naive Bayes methods are a set of supervised learning algorithms\n", "based on applying Bayes' theorem with the \"naive\" assumption of independence\n", "between every pair of features. Given a class variable $y$ and a\n", "dependent feature vector $x_1$ through $x_n$,\n", "Bayes' theorem states the following relationship:\n", "\n", "$$\n", " P(y \\mid x_1, \\dots, x_n) = \\frac{P(y) P(x_1, \\dots x_n \\mid y)}\n", " {P(x_1, \\dots, x_n)}\n", "$$\n", "\n", "Using the naive independence assumption that\n", "\n", "$$\n", " P(x_i | y, x_1, \\dots, x_{i-1}, x_{i+1}, \\dots, x_n) = P(x_i | y),\n", "$$\n", "\n", "for all $i$, this relationship is simplified to\n", "\n", "$$\n", " P(y \\mid x_1, \\dots, x_n) = \\frac{P(y) \\prod_{i=1}^{n} P(x_i \\mid y)}\n", " {P(x_1, \\dots, x_n)}\n", "$$\n", "\n", "Since $P(x_1, \\dots, x_n)$ is constant given the input,\n", "we can use the following classification rule:\n", "\n", "$$ P(y \\mid x_1, \\dots, x_n) \\propto P(y) \\prod_{i=1}^{n} P(x_i \\mid y) $$\n", "\n", "$$ \\Downarrow$$\n", "$$\n", " \\hat{y} = \\arg\\max_y P(y) \\prod_{i=1}^{n} P(x_i \\mid y),\n", "$$\n", "\n", "and we can use Maximum A Posteriori (MAP) estimation to estimate\n", "$P(y)$ and $P(x_i \\mid y)$;\n", "the former is then the relative frequency of class :math:`y`\n", "in the training set.\n", "\n", "The different naive Bayes classifiers differ mainly by the assumptions they\n", "make regarding the distribution of $P(x_i \\mid y)$.\n", "\n", "In spite of their apparently over-simplified assumptions, naive Bayes\n", "classifiers have worked quite well in many real-world situations, famously\n", "document classification and spam filtering. They require a small amount\n", "of training data to estimate the necessary parameters. (For theoretical\n", "reasons why naive Bayes works well, and on which types of data it does, see\n", "the references below.)\n", "\n", "Naive Bayes learners and classifiers can be extremely fast compared to more\n", "sophisticated methods.\n", "The decoupling of the class conditional feature distributions means that each\n", "distribution can be independently estimated as a one dimensional distribution.\n", "This in turn helps to alleviate problems stemming from the curse of\n", "dimensionality.\n", "\n", "On the flip side, although naive Bayes is known as a decent classifier,\n", "it is known to be a bad estimator, so the probability outputs from\n", "``predict_proba`` are not to be taken too seriously." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Gaussian Naive Bayes\n", "\n", "`GaussianNB` implements the Gaussian Naive Bayes algorithm for\n", "classification. 
 The likelihood of the features is assumed to be Gaussian:\n", "\n", "$$ P(x_i \mid y) = \frac{1}{\sqrt{2\pi\sigma^2_y}} \exp\left(-\frac{(x_i - \mu_y)^2}{2\sigma^2_y}\right) $$\n", "\n", "The parameters $\sigma_y$ and $\mu_y$\n", "are estimated using maximum likelihood." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Applying Bayes' theorem to iris classification\n", "\n", "### Preparing the data\n", "\n", "We'll read the iris data into a DataFrame, and **round up** all of the measurements to the next integer:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
sepal_lengthsepal_widthpetal_lengthpetal_widthspecies
05.13.51.40.2Iris-setosa
14.93.01.40.2Iris-setosa
24.73.21.30.2Iris-setosa
34.63.11.50.2Iris-setosa
45.03.61.40.2Iris-setosa
\n", "
" ], "text/plain": [ " sepal_length sepal_width petal_length petal_width species\n", "0 5.1 3.5 1.4 0.2 Iris-setosa\n", "1 4.9 3.0 1.4 0.2 Iris-setosa\n", "2 4.7 3.2 1.3 0.2 Iris-setosa\n", "3 4.6 3.1 1.5 0.2 Iris-setosa\n", "4 5.0 3.6 1.4 0.2 Iris-setosa" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "import numpy as np\n", "# read the iris data into a DataFrame\n", "url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'\n", "col_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']\n", "iris = pd.read_csv(url, header=None, names=col_names)\n", "iris.head()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
sepal_lengthsepal_widthpetal_lengthpetal_widthspecies
06421Iris-setosa
15321Iris-setosa
25421Iris-setosa
35421Iris-setosa
45421Iris-setosa
\n", "
" ], "text/plain": [ " sepal_length sepal_width petal_length petal_width species\n", "0 6 4 2 1 Iris-setosa\n", "1 5 3 2 1 Iris-setosa\n", "2 5 4 2 1 Iris-setosa\n", "3 5 4 2 1 Iris-setosa\n", "4 5 4 2 1 Iris-setosa" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# apply the ceiling function to the numeric columns\n", "iris.loc[:, 'sepal_length':'petal_width'] = iris.loc[:, 'sepal_length':'petal_width'].apply(np.ceil)\n", "iris.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Deciding how to make a prediction\n", "\n", "Let's say that I have an **out-of-sample iris** with the following measurements: **7, 3, 5, 2**. How might I predict the species?" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
sepal_lengthsepal_widthpetal_lengthpetal_widthspecies
547352Iris-versicolor
587352Iris-versicolor
637352Iris-versicolor
687352Iris-versicolor
727352Iris-versicolor
737352Iris-versicolor
747352Iris-versicolor
757352Iris-versicolor
767352Iris-versicolor
777352Iris-versicolor
877352Iris-versicolor
917352Iris-versicolor
977352Iris-versicolor
1237352Iris-virginica
1267352Iris-virginica
1277352Iris-virginica
1467352Iris-virginica
\n", "
" ], "text/plain": [ " sepal_length sepal_width petal_length petal_width species\n", "54 7 3 5 2 Iris-versicolor\n", "58 7 3 5 2 Iris-versicolor\n", "63 7 3 5 2 Iris-versicolor\n", "68 7 3 5 2 Iris-versicolor\n", "72 7 3 5 2 Iris-versicolor\n", "73 7 3 5 2 Iris-versicolor\n", "74 7 3 5 2 Iris-versicolor\n", "75 7 3 5 2 Iris-versicolor\n", "76 7 3 5 2 Iris-versicolor\n", "77 7 3 5 2 Iris-versicolor\n", "87 7 3 5 2 Iris-versicolor\n", "91 7 3 5 2 Iris-versicolor\n", "97 7 3 5 2 Iris-versicolor\n", "123 7 3 5 2 Iris-virginica\n", "126 7 3 5 2 Iris-virginica\n", "127 7 3 5 2 Iris-virginica\n", "146 7 3 5 2 Iris-virginica" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# show all observations with features: 7, 3, 5, 2\n", "iris[(iris.sepal_length==7) & (iris.sepal_width==3) & (iris.petal_length==5) & (iris.petal_width==2)]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "Iris-versicolor 13\n", "Iris-virginica 4\n", "Name: species, dtype: int64" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# count the species for these observations\n", "iris[(iris.sepal_length==7) & (iris.sepal_width==3) & (iris.petal_length==5) & (iris.petal_width==2)].species.value_counts()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "Iris-versicolor 50\n", "Iris-virginica 50\n", "Iris-setosa 50\n", "Name: species, dtype: int64" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# count the species for all observations\n", "iris.species.value_counts()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's frame this as a **conditional probability problem**: What is the probability of some particular species, given the measurements 7, 3, 5, and 2?\n", "\n", "$$P(species \\ | \\ 7352)$$\n", "\n", "We could calculate the conditional probability for **each of the three species**, and then predict the species with the **highest probability**:\n", "\n", "$$P(setosa \\ | \\ 7352)$$\n", "$$P(versicolor \\ | \\ 7352)$$\n", "$$P(virginica \\ | \\ 7352)$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Calculating the probability of each species\n", "\n", "**Bayes' theorem** gives us a way to calculate these conditional probabilities.\n", "\n", "Let's start with **versicolor**:\n", "\n", "$$P(versicolor \\ | \\ 7352) = \\frac {P(7352 \\ | \\ versicolor) \\times P(versicolor)} {P(7352)}$$\n", "\n", "We can calculate each of the terms on the right side of the equation:\n", "\n", "$$P(7352 \\ | \\ versicolor) = \\frac {13} {50} = 0.26$$\n", "\n", "$$P(versicolor) = \\frac {50} {150} = 0.33$$\n", "\n", "$$P(7352) = \\frac {17} {150} = 0.11$$\n", "\n", "Therefore, Bayes' theorem says the **probability of versicolor given these measurements** is:\n", "\n", "$$P(versicolor \\ | \\ 7352) = \\frac {0.26 \\times 0.33} {0.11} = 0.76$$\n", "\n", "Let's repeat this process for **virginica** and **setosa**:\n", "\n", "$$P(virginica \\ | \\ 7352) = \\frac {0.08 \\times 0.33} {0.11} = 0.24$$\n", "\n", "$$P(setosa \\ | \\ 7352) = \\frac {0 \\times 0.33} {0.11} = 0$$\n", "\n", "We predict that the iris is a versicolor, since that species had the **highest conditional probability**." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using sklearn" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.naive_bayes import GaussianNB\n", "gnb = GaussianNB()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "X = iris[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.preprocessing import LabelEncoder\n", "le = LabelEncoder()\n", "y = le.fit_transform(iris['species'])" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of mislabeled points out of a total of 150 points: 10\n" ] } ], "source": [ "gnb.fit(X, y)\n", "y_pred = gnb.predict(X)\n", "print(\"Number of mislabeled points out of a total of %d points: %d\"\n", "      % (iris.shape[0], (y != y_pred).sum()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The intuition behind Bayes' theorem\n", "\n", "Let's make some hypothetical adjustments to the data, to demonstrate how Bayes' theorem makes intuitive sense:\n", "\n", "Pretend that **more of the existing versicolors had measurements of 7352:**\n", "\n", "- $P(7352 \ | \ versicolor)$ would increase, thus increasing the numerator.\n", "- It would make sense that given an iris with measurements of 7352, the probability of it being a versicolor would also increase.\n", "\n", "Pretend that **most of the existing irises were versicolor:**\n", "\n", "- $P(versicolor)$ would increase, thus increasing the numerator.\n", "- It would make sense that the probability of any iris being a versicolor (regardless of measurements) would also increase.\n", "\n", "Pretend that **17 of the setosas had measurements of 7352:**\n", "\n", "- $P(7352)$ would double (34 rather than 17 matching observations), thus doubling the denominator.\n", "- It would make sense that given an iris with measurements of 7352, the probability of it being a versicolor would be cut in half." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Why is the Naive Bayes Classifier naive?\n", "\n", "Let's start by taking a quick look at Bayes' theorem:\n", "\n", "![](https://github.com/rasbt/python-machine-learning-book/raw/4c49547869a1a6b3798c20ea231c0974aa887302/faq/naive-naive-bayes/bayes-theorem-english.png)\n", "\n", "In the context of pattern classification, we can express it as\n", "\n", "![](https://github.com/rasbt/python-machine-learning-book/raw/4c49547869a1a6b3798c20ea231c0974aa887302/faq/naive-naive-bayes/bayes_theorem.png)\n", "\n", "![](https://github.com/rasbt/python-machine-learning-book/raw/4c49547869a1a6b3798c20ea231c0974aa887302/faq/naive-naive-bayes/let.png)\n", "\n", "If we use Bayes' theorem in classification, our goal (or objective function) is to maximize the posterior probability\n", "\n", "![](https://github.com/rasbt/python-machine-learning-book/raw/4c49547869a1a6b3798c20ea231c0974aa887302/faq/naive-naive-bayes/decision_rule.png)\n", "\n", "Now, let's talk a bit more about the individual components. The priors represent our expert (or any other prior) knowledge; in practice, they are often estimated via MLE, i.e., as the relative class frequencies in the training set."
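] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For example, the priors that `GaussianNB` estimated on the iris data above are exactly the relative class frequencies. A minimal check (a sketch we added; `class_prior_` is the fitted prior attribute of `GaussianNB`):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# the fitted priors equal the relative class frequencies (sketch)\n", "print(gnb.class_prior_)\n", "print(np.bincount(y) / len(y))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "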
The evidence term drops out of the decision rule because it is constant across classes.\n", "\n", "Moving on to the \"naive\" part of the Naive Bayes classifier: what makes it \"naive\" is that we compute the class-conditional probability (sometimes also called the likelihood) as the product of the individual probabilities of each feature:\n", "\n", "![](https://github.com/rasbt/python-machine-learning-book/raw/4c49547869a1a6b3798c20ea231c0974aa887302/faq/naive-naive-bayes/likelihood.png)\n", "\n", "Since this assumption (the conditional independence of all features given the class) is almost never met in practice, it is the truly \"naive\" part of naive Bayes." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Working with Text Data and Naive Bayes in scikit-learn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Representing text as data\n", "\n", "From the [scikit-learn documentation](http://scikit-learn.org/stable/modules/feature_extraction.html#text-feature-extraction):\n", "\n", "> Text Analysis is a major application field for machine learning algorithms. However the raw data, a sequence of symbols cannot be fed directly to the algorithms themselves as most of them expect **numerical feature vectors with a fixed size** rather than the **raw text documents with variable length**.\n", "\n", "We will use [CountVectorizer](http://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html) to \"convert text into a matrix of token counts\":" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.feature_extraction.text import CountVectorizer" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# start with a simple example\n", "simple_train = ['call you tonight', 'Call me a cab', 'please call me...
 PLEASE!']" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "['cab', 'call', 'me', 'please', 'tonight', 'you']" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# learn the 'vocabulary' of the training data\n", "vect = CountVectorizer()\n", "vect.fit(simple_train)\n", "vect.get_feature_names()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "<3x6 sparse matrix of type '<class 'numpy.int64'>'\n", "\twith 9 stored elements in Compressed Sparse Row format>" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# transform training data into a 'document-term matrix'\n", "simple_train_dtm = vect.transform(simple_train)\n", "simple_train_dtm" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " (0, 1)\t1\n", " (0, 4)\t1\n", " (0, 5)\t1\n", " (1, 0)\t1\n", " (1, 1)\t1\n", " (1, 2)\t1\n", " (2, 1)\t1\n", " (2, 2)\t1\n", " (2, 3)\t2\n" ] } ], "source": [ "# print the sparse matrix\n", "print(simple_train_dtm)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[0, 1, 0, 0, 1, 1],\n", " [1, 1, 1, 0, 0, 0],\n", " [0, 1, 1, 2, 0, 0]])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# convert sparse matrix to a dense matrix\n", "simple_train_dtm.toarray()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
cabcallmepleasetonightyou
0010011
1111000
2011200
\n", "
" ], "text/plain": [ " cab call me please tonight you\n", "0 0 1 0 0 1 1\n", "1 1 1 1 0 0 0\n", "2 0 1 1 2 0 0" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# examine the vocabulary and document-term matrix together\n", "pd.DataFrame(simple_train_dtm.toarray(), columns=vect.get_feature_names())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From the [scikit-learn documentation](http://scikit-learn.org/stable/modules/feature_extraction.html#text-feature-extraction):\n", "\n", "> In this scheme, features and samples are defined as follows:\n", "\n", "> - Each individual token occurrence frequency (normalized or not) is treated as a **feature**.\n", "> - The vector of all the token frequencies for a given document is considered a multivariate **sample**.\n", "\n", "> A **corpus of documents** can thus be represented by a matrix with **one row per document** and **one column per token** (e.g. word) occurring in the corpus.\n", "\n", "> We call **vectorization** the general process of turning a collection of text documents into numerical feature vectors. This specific strategy (tokenization, counting and normalization) is called the **Bag of Words** or \"Bag of n-grams\" representation. Documents are described by word occurrences while completely ignoring the relative position information of the words in the document." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[0, 1, 1, 1, 0, 0]])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# transform testing data into a document-term matrix (using existing vocabulary)\n", "simple_test = [\"please don't call me\"]\n", "simple_test_dtm = vect.transform(simple_test)\n", "simple_test_dtm.toarray()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
cabcallmepleasetonightyou
0011100
\n", "
" ], "text/plain": [ " cab call me please tonight you\n", "0 0 1 1 1 0 0" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# examine the vocabulary and document-term matrix together\n", "pd.DataFrame(simple_test_dtm.toarray(), columns=vect.get_feature_names())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Summary:**\n", "\n", "- `vect.fit(train)` learns the vocabulary of the training data\n", "- `vect.transform(train)` uses the fitted vocabulary to build a document-term matrix from the training data\n", "- `vect.transform(test)` uses the fitted vocabulary to build a document-term matrix from the testing data (and ignores tokens it hasn't seen before)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reading SMS data\n" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(5572, 2)\n" ] } ], "source": [ "# read tab-separated file\n", "url = 'https://raw.githubusercontent.com/justmarkham/DAT8/master/data/sms.tsv'\n", "col_names = ['label', 'message']\n", "sms = pd.read_table(url, sep='\\t', header=None, names=col_names)\n", "print(sms.shape)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
labelmessage
0hamGo until jurong point, crazy.. Available only ...
1hamOk lar... Joking wif u oni...
2spamFree entry in 2 a wkly comp to win FA Cup fina...
3hamU dun say so early hor... U c already then say...
4hamNah I don't think he goes to usf, he lives aro...
5spamFreeMsg Hey there darling it's been 3 week's n...
6hamEven my brother is not like to speak with me. ...
7hamAs per your request 'Melle Melle (Oru Minnamin...
8spamWINNER!! As a valued network customer you have...
9spamHad your mobile 11 months or more? U R entitle...
10hamI'm gonna be home soon and i don't want to tal...
11spamSIX chances to win CASH! From 100 to 20,000 po...
12spamURGENT! You have won a 1 week FREE membership ...
13hamI've been searching for the right words to tha...
14hamI HAVE A DATE ON SUNDAY WITH WILL!!
15spamXXXMobileMovieClub: To use your credit, click ...
16hamOh k...i'm watching here:)
17hamEh u remember how 2 spell his name... Yes i di...
18hamFine if that’s the way u feel. That’s the way ...
19spamEngland v Macedonia - dont miss the goals/team...
\n", "
" ], "text/plain": [ " label message\n", "0 ham Go until jurong point, crazy.. Available only ...\n", "1 ham Ok lar... Joking wif u oni...\n", "2 spam Free entry in 2 a wkly comp to win FA Cup fina...\n", "3 ham U dun say so early hor... U c already then say...\n", "4 ham Nah I don't think he goes to usf, he lives aro...\n", "5 spam FreeMsg Hey there darling it's been 3 week's n...\n", "6 ham Even my brother is not like to speak with me. ...\n", "7 ham As per your request 'Melle Melle (Oru Minnamin...\n", "8 spam WINNER!! As a valued network customer you have...\n", "9 spam Had your mobile 11 months or more? U R entitle...\n", "10 ham I'm gonna be home soon and i don't want to tal...\n", "11 spam SIX chances to win CASH! From 100 to 20,000 po...\n", "12 spam URGENT! You have won a 1 week FREE membership ...\n", "13 ham I've been searching for the right words to tha...\n", "14 ham I HAVE A DATE ON SUNDAY WITH WILL!!\n", "15 spam XXXMobileMovieClub: To use your credit, click ...\n", "16 ham Oh k...i'm watching here:)\n", "17 ham Eh u remember how 2 spell his name... Yes i di...\n", "18 ham Fine if that’s the way u feel. That’s the way ...\n", "19 spam England v Macedonia - dont miss the goals/team..." ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sms.head(20)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "ham 4825\n", "spam 747\n", "Name: label, dtype: int64" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sms.label.value_counts()" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# convert label to a numeric variable\n", "sms['label'] = sms.label.map({'ham':0, 'spam':1})" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# define X and y\n", "X = sms.message\n", "y = sms.label" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4179,)\n", "(1393,)\n" ] } ], "source": [ "# split into training and testing sets\n", "from sklearn.cross_validation import train_test_split\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)\n", "print(X_train.shape)\n", "print(X_test.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Vectorizing SMS data" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# instantiate the vectorizer\n", "vect = CountVectorizer()" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "<4179x7456 sparse matrix of type ''\n", "\twith 55209 stored elements in Compressed Sparse Row format>" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# learn training data vocabulary, then create document-term matrix\n", "vect.fit(X_train)\n", "X_train_dtm = vect.transform(X_train)\n", "X_train_dtm" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "<4179x7456 sparse matrix of type ''\n", "\twith 55209 stored elements in Compressed Sparse Row format>" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# alternative: combine fit and transform into a 
 single step\n", "X_train_dtm = vect.fit_transform(X_train)\n", "X_train_dtm" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "<1393x7456 sparse matrix of type '<class 'numpy.int64'>'\n", "\twith 17604 stored elements in Compressed Sparse Row format>" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# transform testing data (using fitted vocabulary) into a document-term matrix\n", "X_test_dtm = vect.transform(X_test)\n", "X_test_dtm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Examining the tokens and their counts" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# store token names\n", "X_train_tokens = vect.get_feature_names()" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['00', '000', '008704050406', '0121', '01223585236', '01223585334', '0125698789', '02', '0207', '02072069400', '02073162414', '02085076972', '021', '03', '04', '0430', '05', '050703', '0578', '06', '07', '07008009200', '07090201529', '07090298926', '07123456789', '07732584351', '07734396839', '07742676969', '0776xxxxxxx', '07781482378', '07786200117', '078', '07801543489', '07808', '07808247860', '07808726822', '07815296484', '07821230901', '07880867867', '0789xxxxxxx', '07946746291', '0796xxxxxx', '07973788240', '07xxxxxxxxx', '08', '0800', '08000407165', '08000776320', '08000839402', '08000930705']\n" ] } ], "source": [ "# first 50 tokens\n", "print(X_train_tokens[:50])" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['yer', 'yes', 'yest', 'yesterday', 'yet', 'yetunde', 'yijue', 'ym', 'ymca', 'yo', 'yoga', 'yogasana', 'yor', 'yorge', 'you', 'youdoing', 'youi', 'youphone', 'your', 'youre', 'yourjob', 'yours', 'yourself', 'youwanna', 'yowifes', 'yoyyooo', 'yr', 'yrs', 'ything', 'yummmm', 'yummy', 'yun', 'yunny', 'yuo', 'yuou', 'yup', 'zac', 'zaher', 'zealand', 'zebra', 'zed', 'zeros', 'zhong', 'zindgi', 'zoe', 'zoom', 'zouk', 'zyada', 'èn', '〨ud']\n" ] } ], "source": [ "# last 50 tokens\n", "print(X_train_tokens[-50:])" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[0, 0, 0, ..., 0, 0, 0],\n", " [0, 0, 0, ..., 0, 0, 0],\n", " [0, 0, 0, ..., 0, 0, 0],\n", " ..., \n", " [0, 0, 0, ..., 0, 0, 0],\n", " [0, 0, 0, ..., 0, 0, 0],\n", " [0, 0, 0, ..., 0, 0, 0]], dtype=int64)" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# view X_train_dtm as a dense matrix\n", "X_train_dtm.toarray()" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([ 5, 23, 2, ..., 1, 1, 1], dtype=int64)" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# count how many times EACH token appears across ALL messages in X_train_dtm\n", "import numpy as np\n", "X_train_counts = np.sum(X_train_dtm.toarray(), axis=0)\n", "X_train_counts" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "(7456,)" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train_counts.shape" ] }, {
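"cell_type": "markdown", "metadata": {}, "source": [ "Note that `toarray()` densifies the document-term matrix, which can be costly for a larger corpus. The same counts can be computed directly on the sparse matrix; a small sketch we added:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# same counts computed without densifying the sparse matrix (sketch)\n", "X_train_counts_sparse = np.asarray(X_train_dtm.sum(axis=0)).ravel()\n", "print((X_train_counts_sparse == X_train_counts).all())" ] }, {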
"cell_type": "code", "execution_count": 45, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
counttoken
37271jules
41721mallika
41691malarky
41651makiing
41611maintaining
41581mails
41571mailed
41511magicalsongs
41501maggi
41491magazine
41461madodu
41431mad2
41421mad1
41401macs
41391macleran
41381mack
41741manage
41751manageable
41781manchester
41791manda
42011marking
42001marketing
41971marine
41961margin
41931marandratha
41921maraikara
41911maps
41361machi
41901mapquest
41871manual
.........
2290292do
7257293with
7120293we
6904297ur
1081298at
2995299get
3465302if
4778306or
1522332but
4647338not
6017344so
1574349can
1016358are
4662361now
4743390on
3235416have
1552443call
6539453that
4704460of
7424508your
2821518for
4489550my
3623568it
4238601me
3612679is
3502683in
929717and
65421004the
74201660you
66561670to
\n", "

7456 rows × 2 columns

\n", "
" ], "text/plain": [ " count token\n", "3727 1 jules\n", "4172 1 mallika\n", "4169 1 malarky\n", "4165 1 makiing\n", "4161 1 maintaining\n", "4158 1 mails\n", "4157 1 mailed\n", "4151 1 magicalsongs\n", "4150 1 maggi\n", "4149 1 magazine\n", "4146 1 madodu\n", "4143 1 mad2\n", "4142 1 mad1\n", "4140 1 macs\n", "4139 1 macleran\n", "4138 1 mack\n", "4174 1 manage\n", "4175 1 manageable\n", "4178 1 manchester\n", "4179 1 manda\n", "4201 1 marking\n", "4200 1 marketing\n", "4197 1 marine\n", "4196 1 margin\n", "4193 1 marandratha\n", "4192 1 maraikara\n", "4191 1 maps\n", "4136 1 machi\n", "4190 1 mapquest\n", "4187 1 manual\n", "... ... ...\n", "2290 292 do\n", "7257 293 with\n", "7120 293 we\n", "6904 297 ur\n", "1081 298 at\n", "2995 299 get\n", "3465 302 if\n", "4778 306 or\n", "1522 332 but\n", "4647 338 not\n", "6017 344 so\n", "1574 349 can\n", "1016 358 are\n", "4662 361 now\n", "4743 390 on\n", "3235 416 have\n", "1552 443 call\n", "6539 453 that\n", "4704 460 of\n", "7424 508 your\n", "2821 518 for\n", "4489 550 my\n", "3623 568 it\n", "4238 601 me\n", "3612 679 is\n", "3502 683 in\n", "929 717 and\n", "6542 1004 the\n", "7420 1660 you\n", "6656 1670 to\n", "\n", "[7456 rows x 2 columns]" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# create a DataFrame of tokens with their counts\n", "pd.DataFrame({'token':X_train_tokens, 'count':X_train_counts}).sort_values('count')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Building a Multinomial Naive Bayes model\n", "\n", "`MultinomialNB` implements the naive Bayes algorithm for multinomially\n", "distributed data, and is one of the two classic naive Bayes variants used in\n", "text classification (where the data are typically represented as word vector\n", "counts, although tf-idf vectors are also known to work well in practice).\n", "The distribution is parametrized by vectors\n", "$\\theta_y = (\\theta_{y1},\\ldots,\\theta_{yn})$\n", "for each class :math:`y`, where :math:`n` is the number of features\n", "(in text classification, the size of the vocabulary)\n", "and $\\theta_{yi}$ is the probability $P(x_i \\mid y)$\n", "of feature $i$ appearing in a sample belonging to class :math:`y`.\n", "\n", "The parameters $\\theta_y$ is estimated by a smoothed\n", "version of maximum likelihood, i.e. relative frequency counting:\n", "\n", "$$\n", " \\hat{\\theta}_{yi} = \\frac{ N_{yi} + \\alpha}{N_y + \\alpha n}\n", "$$\n", "\n", "where $N_{yi} = \\sum_{x \\in T} x_i$ is\n", "the number of times feature $i$ appears in a sample of class $y$\n", "in the training set $T$,\n", "and $N_{y} = \\sum_{i=1}^{|T|} N_{yi}$ is the total count of\n", "all features for class $y$.\n", "\n", "The smoothing priors $\\alpha \\ge 0$ accounts for\n", "features not present in the learning samples and prevents zero probabilities\n", "in further computations.\n", "Setting $\\alpha = 1$ is called Laplace smoothing,\n", "while $\\alpha < 1$ is called Lidstone smoothing." 
] }, { "cell_type": "code", "execution_count": 46, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True)" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# train a Naive Bayes model using X_train_dtm\n", "from sklearn.naive_bayes import MultinomialNB\n", "nb = MultinomialNB()\n", "nb.fit(X_train_dtm, y_train)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# make class predictions for X_test_dtm\n", "y_pred_class = nb.predict(X_test_dtm)" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.988513998564\n" ] } ], "source": [ "# calculate accuracy of class predictions\n", "from sklearn import metrics\n", "print(metrics.accuracy_score(y_test, y_pred_class))" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[1203 5]\n", " [ 11 174]]\n" ] } ], "source": [ "# confusion matrix\n", "print(metrics.confusion_matrix(y_test, y_pred_class))" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([ 2.87744864e-03, 1.83488846e-05, 2.07301295e-03, ...,\n", " 1.09026171e-06, 1.00000000e+00, 3.98279868e-09])" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# predict (poorly calibrated) probabilities\n", "y_pred_prob = nb.predict_proba(X_test_dtm)[:, 1]\n", "y_pred_prob" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.986643100054\n" ] } ], "source": [ "# calculate AUC\n", "print(metrics.roc_auc_score(y_test, y_pred_prob))" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "574 Waiting for your call.\n", "3375 Also andros ice etc etc\n", "45 No calls..messages..missed calls\n", "3415 No pic. Please re-send.\n", "1988 No calls..messages..missed calls\n", "Name: message, dtype: object" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# print message text for the false positives\n", "X_test[y_test < y_pred_class]" ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "3132 LookAtMe!: Thanks for your purchase of a video...\n", "5 FreeMsg Hey there darling it's been 3 week's n...\n", "3530 Xmas & New Years Eve tickets are now on sale f...\n", "684 Hi I'm sue. I am 20 years old and work as a la...\n", "1875 Would you like to see my XXX pics they are so ...\n", "1893 CALL 09090900040 & LISTEN TO EXTREME DIRTY LIV...\n", "4298 thesmszone.com lets you send free anonymous an...\n", "4949 Hi this is Amy, we will be sending you a free ...\n", "2821 INTERFLORA - “It's not too late to order Inter...\n", "2247 Hi ya babe x u 4goten bout me?' 
scammers getti...\n", "4514 Money i have won wining number 946 wot do i do...\n", "Name: message, dtype: object" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# print message text for the false negatives\n", "X_test[y_test > y_pred_class]" ] }, { "cell_type": "code", "execution_count": 54, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "\"LookAtMe!: Thanks for your purchase of a video clip from LookAtMe!, you've been charged 35p. Think you can do better? Why not send a video in a MMSto 32323.\"" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# what do you notice about the false negatives?\n", "X_test[3132]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Comparing Multinomial and Gaussian Naive Bayes\n", "\n", "scikit-learn documentation: [MultinomialNB](http://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.MultinomialNB.html) and [GaussianNB](http://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.GaussianNB.html)\n", "\n", "Dataset: [Pima Indians Diabetes](https://archive.ics.uci.edu/ml/datasets/Pima+Indians+Diabetes) from the UCI Machine Learning Repository" ] }, { "cell_type": "code", "execution_count": 55, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# read the data\n", "url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/pima-indians-diabetes/pima-indians-diabetes.data'\n", "col_names = ['pregnant', 'glucose', 'bp', 'skin', 'insulin', 'bmi', 'pedigree', 'age', 'label']\n", "pima = pd.read_csv(url, header=None, names=col_names)" ] }, { "cell_type": "code", "execution_count": 56, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pregnantglucosebpskininsulinbmipedigreeagelabel
061487235033.60.627501
11856629026.60.351310
28183640023.30.672321
318966239428.10.167210
40137403516843.12.288331
\n", "
" ], "text/plain": [ " pregnant glucose bp skin insulin bmi pedigree age label\n", "0 6 148 72 35 0 33.6 0.627 50 1\n", "1 1 85 66 29 0 26.6 0.351 31 0\n", "2 8 183 64 0 0 23.3 0.672 32 1\n", "3 1 89 66 23 94 28.1 0.167 21 0\n", "4 0 137 40 35 168 43.1 2.288 33 1" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# notice that all features are continuous\n", "pima.head()" ] }, { "cell_type": "code", "execution_count": 57, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# create X and y\n", "X = pima.drop('label', axis=1)\n", "y = pima.label\n", "# split into training and testing sets\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)" ] }, { "cell_type": "code", "execution_count": 58, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# import both Multinomial and Gaussian Naive Bayes\n", "from sklearn.naive_bayes import MultinomialNB, GaussianNB\n", "from sklearn import metrics" ] }, { "cell_type": "code", "execution_count": 59, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.54166666666666663" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# testing accuracy of Multinomial Naive Bayes\n", "mnb = MultinomialNB()\n", "mnb.fit(X_train, y_train)\n", "y_pred_class = mnb.predict(X_test)\n", "metrics.accuracy_score(y_test, y_pred_class)" ] }, { "cell_type": "code", "execution_count": 60, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.79166666666666663" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# testing accuracy of Gaussian Naive Bayes\n", "gnb = GaussianNB()\n", "gnb.fit(X_train, y_train)\n", "y_pred_class = gnb.predict(X_test)\n", "metrics.accuracy_score(y_test, y_pred_class)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Conclusion:** When applying Naive Bayes classification to a dataset with **continuous features**, it is better to use Gaussian Naive Bayes than Multinomial Naive Bayes. The latter is suitable for datasets containing **discrete features** (e.g., word counts).\n", "\n", "Wikipedia has a short [description](https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Gaussian_naive_Bayes) of Gaussian Naive Bayes, as well as an excellent [example](https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Sex_classification) of its usage." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Key takeaways\n", "\n", "- The **\"naive\" assumption** of Naive Bayes (that the features are conditionally independent) is critical to making these calculations simple.\n", "- The **normalization constant** (the denominator) can be ignored since it's the same for all classes.\n", "- The **prior probability** is much less relevant once you have a lot of features." 
] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "# Comparing Naive Bayes with other models\n", "\n", "Advantages of Naive Bayes:\n", "\n", "- Model training and prediction are very fast\n", "- Somewhat interpretable\n", "- Little tuning is required (essentially just the smoothing parameter $\alpha$)\n", "- Features don't need scaling\n", "- Insensitive to irrelevant features (with enough observations)\n", "- Performs better than logistic regression when the training set is very small\n", "\n", "Disadvantages of Naive Bayes:\n", "\n", "- Predicted probabilities are not well-calibrated\n", "- Correlated features can be problematic (due to the independence assumption)\n", "- Can't handle negative features (with Multinomial Naive Bayes)\n", "- Has a higher \"asymptotic error\" than logistic regression (it converges with less data, but to a higher error)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.1" }, "name": "_merged" }, "nbformat": 4, "nbformat_minor": 0 }