{ "metadata": { "name": "10_digits_classification" }, "nbformat": 3, "nbformat_minor": 0, "worksheets": [ { "cells": [ { "cell_type": "heading", "level": 1, "metadata": {}, "source": [ "Classification of Handwritten Digits" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this section we'll apply scikit-learn to the classification of handwritten\n", "digits. This will go a bit beyond the iris classification we saw before: we'll\n", "discuss some of the metrics which can be used in evaluating the effectiveness\n", "of a classification model, see an example of K-fold cross-validation, and\n", "present a more involved exercise.\n", "\n", "We'll work with the handwritten digits dataset which we saw in an earlier\n", "section of the tutorial." ] }, { "cell_type": "code", "collapsed": false, "input": [ "from sklearn.datasets import load_digits\n", "digits = load_digits()" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll re-use some of our code from before to visualize the data and remind us what\n", "we're looking at:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "%pylab inline" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "# copied from notebook 02_sklearn_data.ipynb\n", "fig = plt.figure(figsize=(6, 6)) # figure size in inches\n", "fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)\n", "\n", "# plot the digits: each image is 8x8 pixels\n", "for i in range(64):\n", " ax = fig.add_subplot(8, 8, i + 1, xticks=[], yticks=[])\n", " ax.imshow(digits.images[i], cmap=plt.cm.binary)\n", " \n", " # label the image with the target value\n", " ax.text(0, 7, str(digits.target[i]))" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Visualizing the Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A good first-step for many problems is to visualize the data using one of the\n", "*Dimensionality Reduction* techniques we saw earlier. We'll start with the\n", "most straightforward one, Principal Component Analysis (PCA).\n", "\n", "PCA seeks orthogonal linear combinations of the features which show the greatest\n", "variance, and as such, can help give you a good idea of the structure of the\n", "data set. Here we'll use `RandomizedPCA`, because it's faster for large `N`." ] }, { "cell_type": "code", "collapsed": false, "input": [ "from sklearn.decomposition import RandomizedPCA\n", "pca = RandomizedPCA(n_components=2)\n", "proj = pca.fit_transform(digits.data)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "plt.scatter(proj[:, 0], proj[:, 1], c=digits.target)\n", "plt.colorbar()" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we see that the digits do cluster fairly well, so we can expect even\n", "a fairly naive classification scheme to do a decent job separating them.\n", "\n", "A weakness of PCA is that it produces a linear dimensionality reduction:\n", "this may miss some interesting relationships in the data. If we want to\n", "see a nonlinear mapping of the data, we can use one of the several\n", "methods in the `manifold` module. 
, { "cell_type": "markdown", "metadata": {}, "source": [ "In any case, these visualizations show us that there is hope: even a simple\n", "classifier should be able to adequately identify the members of the various\n", "classes." ] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Gaussian Naive Bayes Classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For most classification problems, it's nice to have a simple, fast, go-to\n", "method to provide a quick baseline classification. If the simple and fast\n", "method is sufficient, then we don't have to waste CPU cycles on more complex\n", "models. If not, we can use the results of the simple method to give us\n", "clues about our data.\n", "\n", "One good method to keep in mind is Gaussian Naive Bayes. It is a *generative*\n", "classifier which fits an axis-aligned multi-dimensional Gaussian distribution to\n", "each training label, and uses this to quickly give a rough classification. It\n", "is generally not accurate enough for real-world data, but can perform\n", "surprisingly well, especially in high dimensions." ] }, { "cell_type": "code", "collapsed": false, "input": [ "from sklearn.naive_bayes import GaussianNB\n", "from sklearn import cross_validation" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "# split the data into training and validation sets\n", "data_train, data_test, target_train, target_test = cross_validation.train_test_split(digits.data, digits.target)\n", "\n", "# train the model\n", "clf = GaussianNB()\n", "clf.fit(data_train, target_train)\n", "\n", "# predict the labels of the test data\n", "predicted = clf.predict(data_test)\n", "expected = target_test" ], "language": "python", "metadata": {}, "outputs": [] }
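, { "cell_type": "markdown", "metadata": {}, "source": [ "Because Gaussian Naive Bayes is a generative model, we can peek at what it has learned:\n", "the fitted per-class feature means are stored in the classifier's `theta_` attribute.\n", "The next cell is a quick sketch (assuming `clf` was fit as above) that reshapes each row of\n", "`clf.theta_` back into an 8x8 image, i.e. the 'average digit' the model associates with each class." ] }, { "cell_type": "code", "collapsed": false, "input": [ "# a rough sketch: visualize the per-class pixel means learned by GaussianNB\n", "fig = plt.figure(figsize=(10, 1.5))\n", "for i in range(len(clf.classes_)):\n", "    ax = fig.add_subplot(1, 10, i + 1, xticks=[], yticks=[])\n", "    ax.imshow(clf.theta_[i].reshape(8, 8), cmap=plt.cm.binary)\n", "    ax.set_title(str(clf.classes_[i]))" ], "language": "python", "metadata": {}, "outputs": [] }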
, { "cell_type": "markdown", "metadata": {}, "source": [ "Previously we've done something like\n", "\n", "    print (predicted == expected)\n", "\n", "as a rough evaluation of our model. Here we'll do something more sophisticated:\n", "scikit-learn includes a ``metrics`` module which contains several metrics for\n", "evaluating classifiers like this. One of the most useful combines several of\n", "the metrics and prints a table showing the results:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "from sklearn import metrics\n", "print metrics.classification_report(expected, predicted)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another enlightening metric for this sort of task is a\n", "*confusion matrix*: it helps us visualize which labels are\n", "being interchanged in the classification errors:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "print metrics.confusion_matrix(expected, predicted)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see here that in particular, a lot of twos are being mistakenly labeled as eights.\n", "\n", "Previously we mentioned cross-validation. Here, rather than having a single training\n", "and test set, we divide the data into `K` subsets, and perform `K` different classifications,\n", "each time training on `K - 1` of the subsets and validating on the one left out.\n", "\n", "The tools to accomplish this are also in the `cross_validation` submodule." ] }, { "cell_type": "code", "collapsed": false, "input": [ "# 5-fold cross-validation: every sample is used for validation exactly once\n", "cv = cross_validation.KFold(digits.data.shape[0], 5, shuffle=True, random_state=0)\n", "\n", "clf = GaussianNB()\n", "\n", "print cross_validation.cross_val_score(clf, digits.data, digits.target, cv=cv)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These metrics show us that the simplistic Gaussian Naive Bayes classifier is giving us correct\n", "classifications of about 4 out of 5 digits. This is probably not sufficient: imagine the chaos\n", "at the post office if their zipcode scanning software misread one out of five digits!\n", "\n", "We can do better... but how?" ] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Exercise" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we come to the first major exercise of the tutorial, which will take some thought and\n", "creativity. It will combine and extend several of the ideas we've worked through today.\n", "The question is this:\n", "\n", "- **Given the digits data and the Gaussian Naive Bayes classifier, would it be better for us to\n", "  invest effort in gathering more training samples, or seeking a more sophisticated\n", "  model?**\n", "\n", "- **What is the best possible classification score you could expect from Gaussian Naive Bayes?**\n", "\n", "Some things to keep in mind:\n", "\n", "- We can use the ideas of *learning curves* to answer this question.\n", "- Previously, we used learning curves with RMS error in a regression model.\n", "  Here, we'll have to decide on some suitable metric for classification.\n", "- We can compute the learning curves on classification **loss** (i.e. larger loss is worse)\n", "  or classification **score** (i.e. larger score is better).\n", "- Classification is complicated by the fact that we generally have one score per class.\n", "  Whether you use the (weighted) average of these, or take the min or max, depends on your\n", "  ultimate goal.\n", "\n", "A rough starting point for the learning-curve computation is sketched in the cell below.\n", "\n", "Once you're finished with this, you may wish to try out some other classifiers on the data\n", "and see how the results compare. Another extremely simple and fast classifier is\n", "``sklearn.tree.DecisionTreeClassifier``. Several more sophisticated (and therefore\n", "slower) classifiers can be found in the Support Vector Machines module, ``sklearn.svm``." ] }
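, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below is one possible starting point (certainly not the only way to approach the\n", "exercise): it builds a simple learning curve by hand, using plain accuracy as the score and\n", "re-using the `data_train` / `data_test` split from above. Treat it as scaffolding rather than\n", "a finished answer; you may want to swap in a different metric from `sklearn.metrics`,\n", "average over several random splits, or try a different classifier." ] }, { "cell_type": "code", "collapsed": false, "input": [ "# a rough starting point, not a finished solution: a hand-rolled learning curve\n", "# for GaussianNB, scored with simple accuracy on a fixed test set\n", "import numpy as np\n", "\n", "train_fractions = np.linspace(0.1, 1.0, 8)\n", "train_scores = []\n", "test_scores = []\n", "\n", "for frac in train_fractions:\n", "    n = int(frac * data_train.shape[0])\n", "    model = GaussianNB().fit(data_train[:n], target_train[:n])\n", "    train_scores.append(model.score(data_train[:n], target_train[:n]))\n", "    test_scores.append(model.score(data_test, target_test))\n", "\n", "plt.plot(train_fractions, train_scores, label='train score')\n", "plt.plot(train_fractions, test_scores, label='test score')\n", "plt.xlabel('fraction of training set used')\n", "plt.ylabel('accuracy')\n", "plt.legend(loc='best')" ], "language": "python", "metadata": {}, "outputs": [] }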
], "metadata": {} } ] }