{ "metadata": { "name": "" }, "nbformat": 3, "nbformat_minor": 0, "worksheets": [ { "cells": [ { "cell_type": "heading", "level": 1, "metadata": {}, "source": [ "Supervised Learning: Classification of Handwritten Digits" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this section we'll apply scikit-learn to the classification of handwritten\n", "digits. This will go a bit beyond the iris classification we saw before: we'll\n", "discuss some of the metrics which can be used in evaluating the effectiveness\n", "of a classification model.\n", "\n", "We'll work with the handwritten digits dataset which we saw in an earlier\n", "section of the tutorial." ] }, { "cell_type": "code", "collapsed": false, "input": [ "from sklearn.datasets import load_digits\n", "digits = load_digits()" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll re-use some of our code from before to visualize the data and remind us what\n", "we're looking at:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "%matplotlib inline\n", "from matplotlib import pyplot as plt" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "# copied from notebook 02_representation_of_data.ipynb\n", "fig = plt.figure(figsize=(6, 6)) # figure size in inches\n", "fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)\n", "\n", "# plot the digits: each image is 8x8 pixels\n", "for i in range(64):\n", " ax = fig.add_subplot(8, 8, i + 1, xticks=[], yticks=[])\n", " ax.imshow(digits.images[i], cmap=plt.cm.binary, interpolation='nearest')\n", " \n", " # label the image with the target value\n", " ax.text(0, 7, str(digits.target[i]))" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Visualizing the Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A good 
first step for many problems is to visualize the data using one of the\n", "*Dimensionality Reduction* techniques we saw earlier. We'll start with the\n", "most straightforward one, Principal Component Analysis (PCA).\n", "\n", "PCA seeks orthogonal linear combinations of the features which show the greatest\n", "variance, and as such, can help give you a good idea of the structure of the\n", "data set. Here we'll use `PCA` with the randomized SVD solver\n", "(`svd_solver='randomized'`), because it's faster for large `N`." ] }, { "cell_type": "code", "collapsed": false, "input": [ "from sklearn.decomposition import PCA\n", "\n", "# project the 64-dimensional digit data onto its first two principal components\n", "pca = PCA(n_components=2, svd_solver='randomized')\n", "proj = pca.fit_transform(digits.data)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "# color each projected point by its true digit label\n", "plt.scatter(proj[:, 0], proj[:, 1], c=digits.target)\n", "plt.colorbar()" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we see that the digits do cluster fairly well, so we can expect even\n", "a fairly naive classification scheme to do a decent job separating them." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question: Given these projections of the data, which digits do you think\n", "a classifier might have trouble distinguishing?**" ] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Gaussian Naive Bayes Classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For most classification problems, it's nice to have a simple, fast, go-to\n", "method to provide a quick baseline classification. If the simple and fast\n", "method is sufficient, then we don't have to waste CPU cycles on more complex\n", "models. If not, we can use the results of the simple method to give us\n", "clues about our data.\n", "\n", "One good method to keep in mind is Gaussian Naive Bayes. 
It fits a Gaussian distribution to each feature independently for each class label in the training data, and uses these distributions to quickly give a rough classification. It is generally not sufficiently accurate for real-world data, but can perform surprisingly well, for instance on text data." ] }, { "cell_type": "code", "collapsed": false, "input": [ "from sklearn.naive_bayes import GaussianNB\n", "from sklearn.model_selection import train_test_split" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "# split the data into training and validation sets\n", "# (by default, 25% of the samples are held out for validation)\n", "X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target)\n", "\n", "# train the model\n", "clf = GaussianNB()\n", "clf.fit(X_train, y_train)\n", "\n", "# use the model to predict the labels of the test data\n", "predicted = clf.predict(X_test)\n", "expected = y_test" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question**: why did we split the data into training and validation sets?" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's plot the digits again with the predicted labels to get an idea of\n", "how well the classification is working:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "fig = plt.figure(figsize=(6, 6))  # figure size in inches\n", "fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)\n", "\n", "# plot the first 64 test digits: each image is 8x8 pixels\n", "for i in range(64):\n", "    ax = fig.add_subplot(8, 8, i + 1, xticks=[], yticks=[])\n", "    ax.imshow(X_test.reshape(-1, 8, 8)[i], cmap=plt.cm.binary,\n", "              interpolation='nearest')\n", "\n", "    # label the image with the predicted value:\n", "    # green if the prediction is correct, red if it is wrong\n", "    if predicted[i] == expected[i]:\n", "        ax.text(0, 7, str(predicted[i]), color='green')\n", "    else:\n", "        ax.text(0, 7, str(predicted[i]), color='red')" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Quantitative Measurement of Performance" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'd like to measure the performance of our estimator without having to resort\n", "to plotting examples. A simple approach is to count the number of predictions\n", "that match the true labels:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "# boolean array: True wherever the prediction equals the true label\n", "matches = (predicted == expected)\n", "print(matches.sum())\n", "print(len(matches))" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "# fraction of correct predictions (the accuracy)\n", "matches.sum() / float(len(matches))" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see that most of the 450 test predictions match the true labels. 
But there are other\n", "more sophisticated metrics that can be used to judge the performance of a classifier:\n", "several are available in the ``sklearn.metrics`` submodule.\n", "\n", "One of the most useful tools is ``classification_report``, which combines several\n", "measures and prints a table with the results:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "from sklearn import metrics\n", "print(metrics.classification_report(expected, predicted))" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another enlightening metric for this sort of multi-class classification\n", "is a *confusion matrix*: it helps us visualize which labels are\n", "being interchanged in the classification errors:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "# rows are the true labels, columns are the predicted labels\n", "print(metrics.confusion_matrix(expected, predicted))" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see here that, in particular, the digits 1, 2, 3, and 9 are often mislabeled as 8." ] } ], "metadata": {} } ] }