{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "> This is one of the 100 recipes of the [IPython Cookbook](http://ipython-books.github.io/), the definitive guide to high-performance scientific computing and data science in Python.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 8.3. Learning to recognize handwritten digits with a K-nearest neighbors classifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1. Let's do the traditional imports." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\n", "import sklearn\n", "import sklearn.datasets as ds\n", "# sklearn.cross_validation was removed in scikit-learn 0.20;\n", "# train_test_split now lives in sklearn.model_selection.\n", "import sklearn.model_selection as cv\n", "import sklearn.neighbors as nb\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "2. Let's load the digits dataset, part of the `datasets` module of scikit-learn. This dataset contains handwritten digits that have been manually labeled." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "digits = ds.load_digits()\n", "X = digits.data\n", "y = digits.target\n", "print((X.min(), X.max()))\n", "print(X.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the matrix `X`, each row contains the $8 \\times 8=64$ pixels of one image (in grayscale, values between 0 and 16). The pixels are stored in row-major order." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "3. Let's display some of the images." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "nrows, ncols = 2, 5\n", "plt.figure(figsize=(6,3));\n", "plt.gray()\n", "for i in range(ncols * nrows):\n", "    ax = plt.subplot(nrows, ncols, i + 1)\n", "    ax.matshow(digits.images[i,...])\n", "    plt.xticks([]); plt.yticks([]);\n", "    plt.title(digits.target[i]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "4. Now, let's split the data into a training set and a test set, and fit a K-nearest neighbors classifier on the training set." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "(X_train, X_test,\n", " y_train, y_test) = cv.train_test_split(X, y, test_size=.25)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "knc = nb.KNeighborsClassifier()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "knc.fit(X_train, y_train);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "5. Let's evaluate the score (the mean accuracy) of the trained classifier on the test dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "knc.score(X_test, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "6. Now, let's see if our classifier can recognize a \"hand-written\" digit!" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Let's draw a 1.\n", "one = np.zeros((8, 8))\n", "one[1:-1, 4] = 16  # The image values are in [0, 16].\n", "one[2, 3] = 16" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plt.figure(figsize=(2,2));\n", "plt.imshow(one, interpolation='none');\n", "plt.grid(False);\n", "# Passing empty lists hides the tick marks.\n", "plt.xticks([]); plt.yticks([]);\n", "plt.title(\"One\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# predict() expects a 2D array of shape (n_samples, n_features),\n", "# so we reshape the flattened image into a single row.\n", "knc.predict(one.reshape(1, -1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> You'll find all the explanations, figures, references, and much more in the book (to be released later this summer).\n", "\n", "> [IPython Cookbook](http://ipython-books.github.io/), by [Cyrille Rossant](http://cyrille.rossant.net), Packt Publishing, 2014 (500 pages)." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.4.2" } }, "nbformat": 4, "nbformat_minor": 0 }