{ "metadata": { "name": "04A_supervised_classification" }, "nbformat": 3, "nbformat_minor": 0, "worksheets": [ { "cells": [ { "cell_type": "heading", "level": 1, "metadata": {}, "source": [ "Supervised Learning: Classification of Handwritten Digits" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this section we'll apply scikit-learn to the classification of handwritten\n", "digits. This will go a bit beyond the iris classification we saw before: we'll\n", "discuss some of the metrics which can be used in evaluating the effectiveness\n", "of a classification model.\n", "\n", "We'll work with the handwritten digits dataset which we saw in an earlier\n", "section of the tutorial." ] }, { "cell_type": "code", "collapsed": false, "input": [ "from sklearn.datasets import load_digits\n", "digits = load_digits()" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll re-use some of our code from before to visualize the data and remind us what\n", "we're looking at:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "%pylab inline" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "# copied from notebook 02A_representation_of_data.ipynb\n", "fig = plt.figure(figsize=(6, 6)) # figure size in inches\n", "fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)\n", "\n", "# plot the digits: each image is 8x8 pixels\n", "for i in range(64):\n", " ax = fig.add_subplot(8, 8, i + 1, xticks=[], yticks=[])\n", " ax.imshow(digits.images[i], cmap=plt.cm.binary, interpolation='nearest')\n", " \n", " # label the image with the target value\n", " ax.text(0, 7, str(digits.target[i]))" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Visualizing the Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A good first-step for many problems is to visualize the data using one of the\n", "*Dimensionality Reduction* techniques we saw earlier. We'll start with the\n", "most straightforward one, Principal Component Analysis (PCA).\n", "\n", "PCA seeks orthogonal linear combinations of the features which show the greatest\n", "variance, and as such, can help give you a good idea of the structure of the\n", "data set. Here we'll use `RandomizedPCA`, because it's faster for large `N`." ] }, { "cell_type": "code", "collapsed": false, "input": [ "from sklearn.decomposition import RandomizedPCA\n", "pca = RandomizedPCA(n_components=2)\n", "proj = pca.fit_transform(digits.data)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "plt.scatter(proj[:, 0], proj[:, 1], c=digits.target)\n", "plt.colorbar()" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we see that the digits do cluster fairly well, so we can expect even\n", "a fairly naive classification scheme to do a decent job separating them.\n", "\n", "A weakness of PCA is that it produces a linear dimensionality reduction:\n", "this may miss some interesting relationships in the data. If we want to\n", "see a nonlinear mapping of the data, we can use one of the several\n", "methods in the `manifold` module. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "Here we see that the digits do cluster fairly well, so we can expect even\n", "a fairly naive classification scheme to do a decent job separating them.\n", "\n", "A weakness of PCA is that it produces a linear dimensionality reduction:\n", "this may miss some interesting relationships in the data. If we want to\n", "see a nonlinear mapping of the data, we can use one of the several\n", "methods in the `manifold` module. Here we'll use Isomap (a contraction\n", "of Isometric Mapping), which is a manifold learning method based on\n", "graph theory:" ] },
{ "cell_type": "code", "collapsed": false, "input": [ "from sklearn.manifold import Isomap\n", "iso = Isomap(n_neighbors=5, n_components=2)\n", "proj = iso.fit_transform(digits.data)" ], "language": "python", "metadata": {}, "outputs": [] },
{ "cell_type": "code", "collapsed": false, "input": [ "plt.scatter(proj[:, 0], proj[:, 1], c=digits.target)\n", "plt.colorbar()" ], "language": "python", "metadata": {}, "outputs": [] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "It can be fun to explore the various manifold learning methods available,\n", "and to see how the output depends on the parameters used to tune the\n", "projection.\n", "In any case, these visualizations show us that there is hope: even a simple\n", "classifier should be able to adequately identify the members of the various\n", "classes." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "**Question: Given these projections of the data, which numbers do you think\n", "a classifier might have trouble distinguishing?**" ] },
{ "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Gaussian Naive Bayes Classification" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "For most classification problems, it's nice to have a simple, fast, go-to\n", "method to provide a quick baseline classification. If the simple and fast\n", "method is sufficient, then we don't have to waste CPU cycles on more complex\n", "models. If not, we can use the results of the simple method to give us\n", "clues about our data.\n", "\n", "One good method to keep in mind is Gaussian Naive Bayes. It is a *generative*\n", "classifier which fits an axis-aligned multi-dimensional Gaussian distribution to\n", "each training label, and uses this to quickly give a rough classification. It\n", "is generally not accurate enough for real-world data, but it makes a fast and\n", "surprisingly strong baseline." ] },
{ "cell_type": "code", "collapsed": false, "input": [ "from sklearn.naive_bayes import GaussianNB\n", "from sklearn.cross_validation import train_test_split" ], "language": "python", "metadata": {}, "outputs": [] },
{ "cell_type": "code", "collapsed": false, "input": [ "# split the data into training and validation sets\n", "X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target)\n", "\n", "# train the model\n", "clf = GaussianNB()\n", "clf.fit(X_train, y_train)\n", "\n", "# use the model to predict the labels of the test data\n", "predicted = clf.predict(X_test)\n", "expected = y_test" ], "language": "python", "metadata": {}, "outputs": [] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "**Question**: Why did we split the data into training and validation sets?" ] },
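{ "cell_type": "markdown", "metadata": {}, "source": [ "One way to see the answer empirically (a quick sketch, using the estimator's\n", "built-in `score` method): compare the model's accuracy on the data it was\n", "trained on with its accuracy on the held-out data. The training score is\n", "optimistically biased, while the test score estimates how well the model\n", "generalizes to digits it has never seen:" ] },
{ "cell_type": "code", "collapsed": false, "input": [ "# accuracy on the training data (optimistically biased)\n", "print clf.score(X_train, y_train)\n", "\n", "# accuracy on the held-out test data (an estimate of generalization)\n", "print clf.score(X_test, y_test)" ], "language": "python", "metadata": {}, "outputs": [] },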
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's plot the digits again with the predicted labels to get an idea of\n", "how well the classification is working:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "fig = plt.figure(figsize=(6, 6)) # figure size in inches\n", "fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)\n", "\n", "# plot the digits: each image is 8x8 pixels\n", "for i in range(64):\n", " ax = fig.add_subplot(8, 8, i + 1, xticks=[], yticks=[])\n", " ax.imshow(X_test.reshape(-1, 8, 8)[i], cmap=plt.cm.binary,\n", " interpolation='nearest')\n", " \n", " # label the image with the target value\n", " if predicted[i] == expected[i]:\n", " ax.text(0, 7, str(predicted[i]), color='green')\n", " else:\n", " ax.text(0, 7, str(predicted[i]), color='red')" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Quantitative Measurement of Performance" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'd like to measure the performance of our estimator without having to resort\n", "to plotting examples. A simple method might be to simply compare the number of\n", "matches:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "matches = (predicted == expected)\n", "print matches.sum()\n", "print len(matches)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "matches.sum() / float(len(matches))" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see that nearly 1500 of the 1800 predictions match the input. But there are other\n", "more sophisticated metrics that can be used to judge the performance of a classifier:\n", "several are available in the ``sklearn.metrics`` submodule.\n", "\n", "One of the most useful metrics is the ``classification_report``, which combines several\n", "measures and prints a table with the results:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "from sklearn import metrics\n", "print metrics.classification_report(expected, predicted)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another enlightening metric for this sort of multi-label classification\n", "is a *confusion matrix*: it helps us visualize which labels are\n", "being interchanged in the classification errors:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "print metrics.confusion_matrix(expected, predicted)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see here that in particular, the numbers 1, 2, 3, and 9 are often being labeled 8." ] } ], "metadata": {} } ] }