{ "metadata": { "name": "" }, "nbformat": 3, "nbformat_minor": 0, "worksheets": [ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# k-Nearest Neighbor Methods and Image Data\n", "\n", "In this lab we will explore MNIST, a classic machine learning data set of images of handwritten digits (i.e., 0, 1, 2, 3, ...).\n", "In addition, we will investigate an intuitive yet powerful learning method called k-nearest neighbors (KNN).\n", "Even though the type of data is different from what we've worked with so far, we'll see how to apply familiar tools to it, namely scikit-learn for machine learning and matplotlib for plotting in Python.\n", "\n", "Don't forget to fill out the [response form](https://docs.google.com/a/berkeley.edu/forms/d/188vgp28CXJrJ_qRpyGGBNwG0RjPvAfzirH9zLeSTMKo/viewform).\n", "\n", "And if you haven't already, fetch the data by running the following code (it will download into the `DATA_PATH` directory):" ] }, { "cell_type": "code", "collapsed": false, "input": [ "%pylab inline\n", "import pylab\n", "from sklearn.datasets import fetch_mldata\n", "DATA_PATH = '~/data'\n", "mnist = fetch_mldata('MNIST original', data_home=DATA_PATH)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## MNIST\n", "\n", "
\n", " \n", "
\n", "\n", "The [MNIST database](http://yann.lecun.com/exdb/mnist/) of handwritten digits is a collection of labeled images that has been used to evaluate machine learning techniques since the 1990s.\n", "The core application of the MNIST data is to train computer vision systems to recognize handwritten text.\n", "The post office, for example, is a major user of such systems: addresses on letters and packages are all photographed, read, and routed digitally, with only a few ambiguous cases verified by a human.\n", "\n", "The MNIST data set has also become a reliable benchmark for learning methods.\n", "It's small, but not tiny, and the data dimensionality (28x28 = 784 pixels) is big enough to cause some \"curse of dimensionality\" issues.\n", "Also, the problem is highly non-linear, meaning that linear classification methods (like linear regression, but for predicting discrete categories) don't perform so well on the raw data.\n", "The MNIST [website](http://yann.lecun.com/exdb/mnist/) reports an extensive list of results obtained by different machine learning models, including neural nets, SVMs, nearest neighbors, and others.\n", "\n", "The data consists of 60,000 training images and 10,000 test images.\n", "Each image is a 28x28 pixel, grayscale picture of a digit written either by a high school student or an employee of the US Census Bureau.\n", "The images have all been preprocessed to be clean and regular: only one digit appears in each image, and it appears directly in the center of the image.\n", "\n", "The goal of the benchmark is to fit a model to the training set, and then use that model to predict which digit is in each of the test images.\n", "The best results achieve a classification error rate of less than half of one percent.\n", "This is often described as the \"human error rate,\" because if you ask people to classify the images, they will also find about 0.5% of them to be completely inscrutable." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### MNIST Data\n", "\n", "Let's take a look at the data.\n", "MNIST data is easy to load with the built-in scikit-learn `datasets` module.\n", "Make sure you've loaded the data set by running the code at the top of the lab.\n", "\n", "The image data itself is in `mnist.data`, and it's stored in a NumPy [n-dimensional array](http://docs.scipy.org/doc/numpy/reference/arrays.ndarray.html) (n=2 in this case).\n", "NumPy arrays are a vector/matrix data structure that provides high performance numerical computing.\n", "\n", "We can find out the dimensions of the array with its `shape` property:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "mnist.data.shape" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are 70,000 rows and 784 columns.\n", "So each row is an image (60,000 training images plus 10,000 test ones gives 70,000 total), and each column is a pixel value (784 = 28 * 28).\n", "\n", "Like Pandas DataFrames, NumPy arrays give us a simple interface to summary statistics and subsets of the data.\n", "\n", "NumPy arrays are indexed dimension-wise, with each dimension separated by a comma:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "row = mnist.data[0,:] # First row of the array\n", "col = mnist.data[:,0] # First column of the array\n", "\n", "print row.shape\n", "print col.shape" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this syntax, the \":\" means \"ALL\", as in standard Python indexing.\n", "All of the usual Python range indexing syntax works for each dimension of the array.\n", "We can compute summary statistics, too:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "print row.sum(), row.max(), row.min()\n", "print col.sum(), col.max(), col.min()\n", "\n", "print mnist.data[:10,:] # First ten rows\n", "print 
mnist.data[:,-10:] # Last ten columns" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "Let's divide the array into two sets, one for training images and one for test:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "train = mnist.data[:60000]\n", "test = mnist.data[60000:]" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that we can drop the trailing \",:\" when we want to index just the first dimension.\n", "\n", "### DIY\n", "\n", "* To start, we want to work with just a sample of the test data. Create a sample consisting of every 100th image in `test`.\n", "\n", "* Find the mean value of the 300th column in the sample data set." ] }, { "cell_type": "code", "collapsed": false, "input": [ "test_sample = None # Fix me" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualizing Image Data\n", "\n", "One of the nicest things about image data is that it is naturally visualized and understood.\n", "First, let's take a look at the raw data in the first image in the data set:\n", "\n" ] }, { "cell_type": "code", "collapsed": false, "input": [ "img = mnist.data[0]\n", "print img" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These are all the pixel values in the image.\n", "We can see some patterns (e.g., the edges are empty), but it's hard to interpret.\n", "In fact, we can get a much better view if we use the `imshow` function from matplotlib to display an image:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "pylab.imshow(img, cmap=\"Greys\")" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Annnnnd... 
it breaks.\n", "What went wrong?\n", "Since images are two dimensional objects, `imshow` expects a two dimensional array of data to plot, but the rows of our data array are flat vectors.\n", "Luckily, NumPy arrays provide a `reshape` method that lets us change the dimensions of our data (as long as we leave the total number of elements the same).\n", "Let's reshape our image data into a 28x28 pixel square and try again:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "pylab.imshow(img.reshape(28, 28), cmap=\"Greys\")" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's a zero!\n", "Now we're getting somewhere.\n", "\n", "### DIY\n", "\n", "* Use `imshow` to visualize a number of images from `test_sample`. What can you say about how the data set is ordered?" ] }, { "cell_type": "code", "collapsed": false, "input": [], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Unsupervised KNN\n", "\n", "Next, we're interested in uncovering more structure in the MNIST data.\n", "For example, we want to be able to answer questions like \"how similar is people's handwriting?\" and \"how distinct are the different digits?\"\n", "If we get a sense of the variance in the data, and of how tightly it is clustered, we can begin to see a good approach to modeling.\n", "\n", "In the spirit of doing the simplest possible thing that might work, we can look at nearest neighbors using simple Euclidean distance between pixels.\n", "The assumption is that different people's renderings of the same digit look alike, so they should have similar values in individual pixels.\n", "Let's find out if this assumption is a good one.\n", "\n", "We can use the [`sklearn.neighbors`](http://scikit-learn.org/stable/modules/neighbors.html) module to compute nearest neighbors, in particular with the 
[`NearestNeighbors`](http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NearestNeighbors.html#sklearn.neighbors.NearestNeighbors) class:\n" ] }, { "cell_type": "code", "collapsed": false, "input": [ "%%time\n", "from sklearn.neighbors import NearestNeighbors\n", "model = NearestNeighbors(algorithm='brute').fit(train)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note how fast we built a nearest neighbors model: fitting takes almost no time!\n", "This is because we're using the brute force implementation (`algorithm='brute'`), which simply stores the training data when building the model and does a full comparison against every training image at query time.\n", "\n", "Let's query our new model. We can fetch the *k* nearest neighbors with the `kneighbors` method:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "%%time\n", "query_img = test[0]\n", "_, result = model.kneighbors(query_img, n_neighbors=4)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that query time is significant, even for a single image.\n", "Also, notice that `kneighbors` returns two values.\n", "The first, which we will ignore, is the array of distances to the nearest neighbors.\n", "The second is an array of indices where we can look up the nearest neighbors in the training set.\n", "\n", "With the results, now we can see how we did." 
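, "\n", "\n",
"Before looking at the output, here is a minimal sketch of what a brute force query computes, written with plain NumPy on a made-up toy data set rather than MNIST. The names `toy_train`, `toy_query`, and `brute_force_kneighbors` are hypothetical and exist only for this illustration; scikit-learn's own implementation differs in its details, so treat this as a sketch of the idea (measure the distance to every stored row, keep the k smallest) and not as the library's code.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def brute_force_kneighbors(train_rows, query_row, k):\n",
"    # Euclidean distance from the query to every training row\n",
"    dists = np.sqrt(((train_rows - query_row) ** 2).sum(axis=1))\n",
"    # Indices of the k smallest distances, nearest first\n",
"    idx = np.argsort(dists)[:k]\n",
"    return dists[idx], idx\n",
"\n",
"# Hypothetical toy data: four points in a 2-dimensional space\n",
"toy_train = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [0.0, 2.0]])\n",
"toy_query = np.array([0.0, 0.0])\n",
"\n",
"distances, indices = brute_force_kneighbors(toy_train, toy_query, k=2)\n",
"```\n",
"\n",
"Like `kneighbors`, the sketch returns the distances first and the row indices second, so `toy_train[indices]` looks up the neighbors themselves."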
] }, { "cell_type": "code", "collapsed": false, "input": [ "print result" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are four results, as expected.\n", "Let's display them with the utility function below:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "# Display several images in a row\n", "def show(imgs, n=1):\n", "    fig = pylab.figure()\n", "    for i in xrange(0, n):\n", "        # Subplot positions are 1-based, so shift the loop index by one\n", "        fig.add_subplot(1, n, i + 1, xticklabels=[], yticklabels=[])\n", "        if n == 1:\n", "            img = imgs\n", "        else:\n", "            img = imgs[i]\n", "        pylab.imshow(img.reshape(28, 28), cmap=\"Greys\")" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "show(query_img)\n", "show(train[result[0],:], len(result[0]))" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The neighbors look pretty good!\n", "Importantly, they are all zeros.\n", "That means that to some extent, at least, our assumption about images of the same digit being \"close\" to one another in pixel-space is a good one.\n", "\n", "### DIY\n", "\n", "* Use the nearest neighbors model to inspect results for other images in the test set. Do all of the digits seem to perform as well as \"0\" does? You may want to increase \"k\" to find where things break down." 
] }, { "cell_type": "code", "collapsed": false, "input": [], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## KNN Classification\n", "\n", "We can validate our model in a more rigorous way by using it to predict digits.\n", "Scikit-learn provides a class for supervised nearest neighbors fitting called [`KNeighborsClassifier`](http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html#sklearn.neighbors.KNeighborsClassifier).\n", "It is very similar to the `NearestNeighbors` class, but it accepts labels when fitting a model, and it provides methods for making label predictions for test data.\n", "\n", "The MNIST labels are in `mnist.target`.\n", "Let's split them into training and test sets as we did with the image data:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "train_labels = mnist.target[:60000]\n", "test_labels = mnist.target[60000:]\n", "test_labels_sample = test_labels[::100]" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, as before, we fit a model to the training data:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "%%time\n", "from sklearn.neighbors import KNeighborsClassifier\n", "model = KNeighborsClassifier(n_neighbors=4, algorithm='brute').fit(train, train_labels)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### DIY\n", "\n", "* Use the `KNeighborsClassifier.score` method to measure the model's classification accuracy on `test_sample`." ] }, { "cell_type": "code", "collapsed": false, "input": [ "%%time\n", "# Score the model!" 
], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Next, visualize the nearest neighbors of cases where the model makes erroneous predictions" ] }, { "cell_type": "code", "collapsed": false, "input": [ "preds = model.predict(test_sample)\n", "errors = [i for i in xrange(0, len(test_sample)) if preds[i] != test_labels_sample[i]]\n", "\n", "for i in errors:\n", " pass # Visualize error image and its nearest neighbors" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Finally, let's get a more global view of classification errors. To do this, we can create a **confusion matix**, that counts how often each class is mistaken for each other class. Use [`sklearn.metric.confusion_matrix`](http://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html) to generate a confusion matrix between model predictions and test labels. Which pair of digits are confused most frequently?" ] }, { "cell_type": "code", "collapsed": false, "input": [ "test_sample = test[::10]\n", "test_labels_sample = test_labels[::10]\n", "\n", "def plot_cm(cm):\n", " pylab.matshow(np.log(cm))\n", "\n", "from sklearn.metrics import confusion_matrix\n", "# Compute and plot the confusion matrix for test_sample" ], "language": "python", "metadata": {}, "outputs": [] } ], "metadata": {} } ] }